ruff,私聊视为提及了bot
This commit is contained in:
@@ -64,50 +64,50 @@ class AtAction(BaseAction):
|
||||
# 使用回复器生成艾特回复,而不是直接发送命令
|
||||
from src.chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
|
||||
|
||||
# 获取当前聊天流
|
||||
chat_manager = get_chat_manager()
|
||||
chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id)
|
||||
|
||||
|
||||
if not chat_stream:
|
||||
logger.error(f"找不到聊天流: {self.chat_stream}")
|
||||
return False, "聊天流不存在"
|
||||
|
||||
|
||||
# 创建回复器实例
|
||||
replyer = DefaultReplyer(chat_stream)
|
||||
|
||||
|
||||
# 构建回复对象,将艾特消息作为回复目标
|
||||
reply_to = f"{user_name}:{at_message}"
|
||||
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
|
||||
|
||||
|
||||
# 使用回复器生成回复
|
||||
success, llm_response, prompt = await replyer.generate_reply_with_context(
|
||||
reply_to=reply_to,
|
||||
extra_info=extra_info,
|
||||
enable_tool=False, # 艾特回复通常不需要工具调用
|
||||
from_plugin=False
|
||||
from_plugin=False,
|
||||
)
|
||||
|
||||
|
||||
if success and llm_response:
|
||||
# 获取生成的回复内容
|
||||
reply_content = llm_response.get("content", "")
|
||||
if reply_content:
|
||||
# 获取用户QQ号,发送真正的艾特消息
|
||||
user_id = user_info.get("user_id")
|
||||
|
||||
|
||||
# 发送真正的艾特命令,使用回复器生成的智能内容
|
||||
await self.send_command(
|
||||
"SEND_AT_MESSAGE",
|
||||
args={"qq_id": user_id, "text": reply_content},
|
||||
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
)
|
||||
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True,
|
||||
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
|
||||
action_done=True,
|
||||
)
|
||||
|
||||
|
||||
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
|
||||
return True, "智能艾特消息发送成功"
|
||||
else:
|
||||
@@ -116,7 +116,7 @@ class AtAction(BaseAction):
|
||||
else:
|
||||
logger.error("回复器生成回复失败")
|
||||
return False, "回复生成失败"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
|
||||
await self.store_action_info(
|
||||
|
||||
@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 2. 获取所有有效的表情包对象
|
||||
emoji_manager = get_emoji_manager()
|
||||
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description]
|
||||
all_emojis_obj: list[MaiEmoji] = [
|
||||
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
|
||||
]
|
||||
if not all_emojis_obj:
|
||||
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
|
||||
return False, "无法获取任何带有描述的有效表情包"
|
||||
@@ -91,12 +93,12 @@ class EmojiAction(BaseAction):
|
||||
# 4. 准备情感数据和后备列表
|
||||
emotion_map = {}
|
||||
all_emojis_data = []
|
||||
|
||||
|
||||
for emoji in all_emojis_obj:
|
||||
b64 = image_path_to_base64(emoji.full_path)
|
||||
if not b64:
|
||||
continue
|
||||
|
||||
|
||||
desc = emoji.description
|
||||
emotions = emoji.emotion
|
||||
all_emojis_data.append((b64, desc))
|
||||
@@ -168,16 +170,18 @@ class EmojiAction(BaseAction):
|
||||
|
||||
# 使用模糊匹配来查找最相关的情感标签
|
||||
matched_key = next((key for key in emotion_map if chosen_emotion in key), None)
|
||||
|
||||
|
||||
if matched_key:
|
||||
emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
|
||||
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}")
|
||||
logger.info(
|
||||
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
|
||||
)
|
||||
emoji_base64, emoji_description = random.choice(all_emojis_data)
|
||||
|
||||
|
||||
elif global_config.emoji.emoji_selection_mode == "description":
|
||||
# --- 详细描述选择模式 ---
|
||||
# 获取最近的5条消息内容用于判断
|
||||
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
|
||||
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
|
||||
|
||||
# 简单关键词匹配
|
||||
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None)
|
||||
|
||||
matched_emoji = next(
|
||||
(
|
||||
item
|
||||
for item in all_emojis_data
|
||||
if chosen_description.lower() in item[1].lower()
|
||||
or item[1].lower() in chosen_description.lower()
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
# 如果包含匹配失败,尝试关键词匹配
|
||||
if not matched_emoji:
|
||||
keywords = ['惊讶', '困惑', '呆滞', '震惊', '懵', '无语', '萌', '可爱']
|
||||
keywords = ["惊讶", "困惑", "呆滞", "震惊", "懵", "无语", "萌", "可爱"]
|
||||
for keyword in keywords:
|
||||
if keyword in chosen_description:
|
||||
for item in all_emojis_data:
|
||||
if any(k in item[1] for k in ['呆', '萌', '惊', '困惑', '无语']):
|
||||
if any(k in item[1] for k in ["呆", "萌", "惊", "困惑", "无语"]):
|
||||
matched_emoji = item
|
||||
break
|
||||
if matched_emoji:
|
||||
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
|
||||
|
||||
if not success:
|
||||
logger.error(f"{self.log_prefix} 表情包发送失败")
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
|
||||
)
|
||||
return False, "表情包发送失败"
|
||||
|
||||
# 发送成功后,记录到历史
|
||||
@@ -263,8 +277,10 @@ class EmojiAction(BaseAction):
|
||||
add_emoji_to_history(self.chat_id, emoji_description)
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
|
||||
|
||||
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
|
||||
|
||||
await self.store_action_info(
|
||||
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
|
||||
)
|
||||
|
||||
return True, f"发送表情包: {emoji_description}"
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from src.plugin_system import BaseEventHandler
|
||||
from src.plugin_system.base.base_event import HandlerResult
|
||||
|
||||
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
|
||||
logger.error("事件 napcat_set_group_sign 请求失败!")
|
||||
return HandlerResult(False, False, {"status": "error"})
|
||||
|
||||
|
||||
# ===PERSONAL===
|
||||
class SetInputStatusHandler(BaseEventHandler):
|
||||
handler_name: str = "napcat_set_input_status_handler"
|
||||
|
||||
@@ -227,7 +227,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
await reassembler.start_cleanup_task()
|
||||
|
||||
logger.info("开始启动Napcat Adapter")
|
||||
|
||||
|
||||
# 创建单独的异步任务,防止阻塞主线程
|
||||
asyncio.create_task(self._start_maibot_connection())
|
||||
asyncio.create_task(napcat_server(self.plugin_config))
|
||||
@@ -238,10 +238,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
|
||||
"""非阻塞方式启动MaiBot连接,等待主服务启动后再连接"""
|
||||
# 等待一段时间让MaiBot主服务完全启动
|
||||
await asyncio.sleep(5)
|
||||
|
||||
|
||||
max_attempts = 10
|
||||
attempt = 0
|
||||
|
||||
|
||||
while attempt < max_attempts:
|
||||
try:
|
||||
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
|
||||
@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
def enable_plugin(self) -> bool:
|
||||
"""通过配置文件动态控制插件启用状态"""
|
||||
# 如果已经通过配置加载了状态,使用配置中的值
|
||||
if hasattr(self, '_is_enabled'):
|
||||
if hasattr(self, "_is_enabled"):
|
||||
return self._is_enabled
|
||||
# 否则使用默认值(禁用状态)
|
||||
return False
|
||||
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
|
||||
},
|
||||
"napcat_server": {
|
||||
"mode": ConfigField(type=str, default="reverse", description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
|
||||
"mode": ConfigField(
|
||||
type=str,
|
||||
default="reverse",
|
||||
description="连接模式:reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
|
||||
choices=["reverse", "forward"],
|
||||
),
|
||||
"host": ConfigField(type=str, default="localhost", description="主机地址"),
|
||||
"port": ConfigField(type=int, default=8095, description="端口号"),
|
||||
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)"),
|
||||
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
|
||||
"url": ConfigField(
|
||||
type=str,
|
||||
default="",
|
||||
description="正向连接时的完整WebSocket URL,如 ws://localhost:8080/ws (仅在forward模式下使用)",
|
||||
),
|
||||
"access_token": ConfigField(
|
||||
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
|
||||
),
|
||||
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
|
||||
},
|
||||
"maibot_server": {
|
||||
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"),
|
||||
"host": ConfigField(
|
||||
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址,即HOST字段"
|
||||
),
|
||||
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口,即PORT字段"),
|
||||
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
|
||||
},
|
||||
"voice": {
|
||||
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"),
|
||||
"use_tts": ConfigField(
|
||||
type=bool, default=False, description="是否使用tts语音(请确保你配置了tts并有对应的adapter)"
|
||||
),
|
||||
},
|
||||
"slicing": {
|
||||
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"),
|
||||
"max_frame_size": ConfigField(
|
||||
type=int, default=64, description="WebSocket帧的最大大小,单位为字节,默认64KB"
|
||||
),
|
||||
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
|
||||
},
|
||||
"debug": {
|
||||
"level": ConfigField(type=str, default="INFO", description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
|
||||
"level": ConfigField(
|
||||
type=str,
|
||||
default="INFO",
|
||||
description="日志等级(DEBUG, INFO, WARNING, ERROR, CRITICAL)",
|
||||
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
|
||||
),
|
||||
},
|
||||
"features": {
|
||||
# 权限设置
|
||||
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"group_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="群聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
|
||||
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)", choices=["whitelist", "blacklist"]),
|
||||
"private_list_type": ConfigField(
|
||||
type=str,
|
||||
default="blacklist",
|
||||
description="私聊列表类型:whitelist(白名单)或 blacklist(黑名单)",
|
||||
choices=["whitelist", "blacklist"],
|
||||
),
|
||||
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
|
||||
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"),
|
||||
"ban_user_id": ConfigField(
|
||||
type=list, default=[], description="全局禁止用户ID列表,这些用户无法在任何地方使用机器人"
|
||||
),
|
||||
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
|
||||
|
||||
# 聊天功能设置
|
||||
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
|
||||
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
|
||||
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
|
||||
"poke_debounce_seconds": ConfigField(
|
||||
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
|
||||
),
|
||||
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
|
||||
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
|
||||
|
||||
# 视频处理设置
|
||||
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
|
||||
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制(MB)"),
|
||||
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
|
||||
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
|
||||
|
||||
"supported_formats": ConfigField(
|
||||
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
|
||||
),
|
||||
# 消息缓冲设置
|
||||
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
|
||||
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
|
||||
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
|
||||
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
|
||||
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
|
||||
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
|
||||
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "!", ".", "。", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
|
||||
}
|
||||
"message_buffer_enable_private": ConfigField(
|
||||
type=bool, default=True, description="是否启用私聊消息缓冲合并"
|
||||
),
|
||||
"message_buffer_interval": ConfigField(
|
||||
type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
|
||||
),
|
||||
"message_buffer_initial_delay": ConfigField(
|
||||
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
|
||||
),
|
||||
"message_buffer_max_components": ConfigField(
|
||||
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
|
||||
),
|
||||
"message_buffer_block_prefixes": ConfigField(
|
||||
type=list,
|
||||
default=["/", "!", "!", ".", "。", "#", "%"],
|
||||
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# 配置节描述
|
||||
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"voice": "发送语音设置",
|
||||
"slicing": "WebSocket消息切片设置",
|
||||
"debug": "调试设置",
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
|
||||
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
|
||||
}
|
||||
|
||||
def register_events(self):
|
||||
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
chunker.set_plugin_config(self.config)
|
||||
# 设置response_pool的插件配置
|
||||
from .src.response_pool import set_plugin_config as set_response_pool_config
|
||||
|
||||
set_response_pool_config(self.config)
|
||||
# 设置send_handler的插件配置
|
||||
send_handler.set_plugin_config(self.config)
|
||||
@@ -418,4 +466,4 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
notice_handler.set_plugin_config(self.config)
|
||||
# 设置meta_event_handler的插件配置
|
||||
meta_event_handler.set_plugin_config(self.config)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
# 设置其他handler的插件配置(现在由component_registry在注册时自动设置)
|
||||
|
||||
@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
|
||||
return True
|
||||
|
||||
# 检查屏蔽前缀
|
||||
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
|
||||
block_prefixes = tuple(
|
||||
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
|
||||
)
|
||||
|
||||
text = text.strip()
|
||||
if text.startswith(block_prefixes):
|
||||
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
|
||||
|
||||
# 检查是否启用对应类型的缓冲
|
||||
message_type = event_data.get("message_type", "")
|
||||
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
|
||||
if message_type == "group" and not config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_group", False
|
||||
):
|
||||
return False
|
||||
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
|
||||
elif message_type == "private" and not config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_private", False
|
||||
):
|
||||
return False
|
||||
|
||||
# 提取文本
|
||||
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
|
||||
session = self.buffer_pool[session_id]
|
||||
|
||||
# 检查是否超过最大组件数量
|
||||
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
|
||||
if len(session.messages) >= config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_max_components", 5
|
||||
):
|
||||
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
|
||||
asyncio.create_task(self._force_merge_session(session_id))
|
||||
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
|
||||
|
||||
@@ -14,7 +14,7 @@ def create_router(plugin_config: dict):
|
||||
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
|
||||
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
|
||||
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
|
||||
|
||||
|
||||
route_config = RouteConfig(
|
||||
route_config={
|
||||
platform_name: TargetConfig(
|
||||
@@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None):
|
||||
logger.info("正在连接MaiBot")
|
||||
if plugin_config:
|
||||
create_router(plugin_config)
|
||||
|
||||
|
||||
if router:
|
||||
router.register_class_handler(send_handler.handle_message)
|
||||
await router.run()
|
||||
|
||||
@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
|
||||
group_recall = "group_recall" # 群聊消息撤回
|
||||
notify = "notify"
|
||||
group_ban = "group_ban" # 群禁言
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
|
||||
|
||||
class Notify:
|
||||
poke = "poke" # 戳一戳
|
||||
|
||||
@@ -100,7 +100,7 @@ class MessageHandler:
|
||||
# 检查群聊黑白名单
|
||||
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
|
||||
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
|
||||
|
||||
|
||||
if group_list_type == "whitelist":
|
||||
if group_id not in group_list:
|
||||
logger.warning("群聊不在白名单中,消息被丢弃")
|
||||
@@ -111,9 +111,11 @@ class MessageHandler:
|
||||
return False
|
||||
else:
|
||||
# 检查私聊黑白名单
|
||||
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
|
||||
private_list_type = config_api.get_plugin_config(
|
||||
self.plugin_config, "features.private_list_type", "blacklist"
|
||||
)
|
||||
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
|
||||
|
||||
|
||||
if private_list_type == "whitelist":
|
||||
if user_id not in private_list:
|
||||
logger.warning("私聊不在白名单中,消息被丢弃")
|
||||
@@ -156,21 +158,23 @@ class MessageHandler:
|
||||
Parameters:
|
||||
raw_message: dict: 原始消息
|
||||
"""
|
||||
|
||||
|
||||
# 添加原始消息调试日志,特别关注message字段
|
||||
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
|
||||
logger.debug(
|
||||
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
|
||||
)
|
||||
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
|
||||
|
||||
|
||||
# 检查是否包含@或video消息段
|
||||
message_segments = raw_message.get('message', [])
|
||||
message_segments = raw_message.get("message", [])
|
||||
if message_segments:
|
||||
for i, seg in enumerate(message_segments):
|
||||
seg_type = seg.get('type')
|
||||
if seg_type in ['at', 'video']:
|
||||
seg_type = seg.get("type")
|
||||
if seg_type in ["at", "video"]:
|
||||
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
|
||||
elif seg_type not in ['text', 'face', 'image']:
|
||||
elif seg_type not in ["text", "face", "image"]:
|
||||
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
|
||||
|
||||
|
||||
message_type: str = raw_message.get("message_type")
|
||||
message_id: int = raw_message.get("message_id")
|
||||
# message_time: int = raw_message.get("time")
|
||||
@@ -308,9 +312,13 @@ class MessageHandler:
|
||||
message_type = raw_message.get("message_type")
|
||||
should_use_buffer = False
|
||||
|
||||
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
|
||||
if message_type == "group" and config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_group", True
|
||||
):
|
||||
should_use_buffer = True
|
||||
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
|
||||
elif message_type == "private" and config_api.get_plugin_config(
|
||||
self.plugin_config, "features.message_buffer_enable_private", True
|
||||
):
|
||||
should_use_buffer = True
|
||||
|
||||
if should_use_buffer:
|
||||
@@ -368,10 +376,10 @@ class MessageHandler:
|
||||
for sub_message in real_message:
|
||||
sub_message: dict
|
||||
sub_message_type = sub_message.get("type")
|
||||
|
||||
|
||||
# 添加详细的消息类型调试信息
|
||||
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
|
||||
|
||||
|
||||
# 特别关注 at 和 video 消息的识别
|
||||
if sub_message_type == "at":
|
||||
logger.debug(f"检测到@消息: {sub_message}")
|
||||
@@ -379,7 +387,7 @@ class MessageHandler:
|
||||
logger.debug(f"检测到VIDEO消息: {sub_message}")
|
||||
elif sub_message_type not in ["text", "face", "image", "record"]:
|
||||
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
|
||||
|
||||
|
||||
match sub_message_type:
|
||||
case RealMessageType.text:
|
||||
ret_seg = await self.handle_text_message(sub_message)
|
||||
|
||||
@@ -33,6 +33,7 @@ class MessageSending:
|
||||
try:
|
||||
# 重新导入router
|
||||
from ..mmc_com_layer import router
|
||||
|
||||
self.maibot_router = router
|
||||
if self.maibot_router is not None:
|
||||
logger.info("MaiBot router重连成功")
|
||||
@@ -73,14 +74,14 @@ class MessageSending:
|
||||
|
||||
# 获取对应的客户端并发送切片
|
||||
platform = message_base.message_info.platform
|
||||
|
||||
|
||||
# 再次检查router状态(防止运行时被重置)
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
|
||||
if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
|
||||
logger.warning("MaiBot router连接已断开,尝试重新连接")
|
||||
if not await self._attempt_reconnect():
|
||||
logger.error("MaiBot router重连失败,切片发送中止")
|
||||
return False
|
||||
|
||||
|
||||
if platform not in self.maibot_router.clients:
|
||||
logger.error(f"平台 {platform} 未连接")
|
||||
return False
|
||||
|
||||
@@ -22,7 +22,9 @@ class MetaEventHandler:
|
||||
"""设置插件配置"""
|
||||
self.plugin_config = plugin_config
|
||||
# 更新interval值
|
||||
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
self.interval = (
|
||||
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
|
||||
)
|
||||
|
||||
async def handle_meta_event(self, message: dict) -> None:
|
||||
event_type = message.get("meta_event_type")
|
||||
|
||||
@@ -116,9 +116,9 @@ class NoticeHandler:
|
||||
sub_type = raw_message.get("sub_type")
|
||||
match sub_type:
|
||||
case NoticeType.Notify.poke:
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
|
||||
user_id, group_id, False, False
|
||||
):
|
||||
if config_api.get_plugin_config(
|
||||
self.plugin_config, "features.enable_poke", True
|
||||
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
|
||||
logger.debug("处理戳一戳消息")
|
||||
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
|
||||
else:
|
||||
@@ -127,14 +127,18 @@ class NoticeHandler:
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
|
||||
)
|
||||
case _:
|
||||
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
case NoticeType.group_msg_emoji_like:
|
||||
# 该事件转移到 handle_group_emoji_like_notify函数内触发
|
||||
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
|
||||
logger.debug("处理群聊表情回复")
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
|
||||
handled_message, user_info = await self.handle_group_emoji_like_notify(
|
||||
raw_message, group_id, user_id
|
||||
)
|
||||
else:
|
||||
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
|
||||
case NoticeType.group_ban:
|
||||
@@ -294,7 +298,7 @@ class NoticeHandler:
|
||||
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理群聊表情回复通知")
|
||||
return None, None
|
||||
return None, None
|
||||
|
||||
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
|
||||
if user_qq_info:
|
||||
@@ -304,37 +308,42 @@ class NoticeHandler:
|
||||
user_name = "QQ用户"
|
||||
user_cardname = "QQ用户"
|
||||
logger.debug("无法获取表情回复对方的用户昵称")
|
||||
|
||||
|
||||
from src.plugin_system.core.event_manager import event_manager
|
||||
from ...event_types import NapcatEvent
|
||||
|
||||
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
|
||||
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","")
|
||||
target_message = await event_manager.trigger_event(
|
||||
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
|
||||
)
|
||||
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
|
||||
if not target_message:
|
||||
logger.error("未找到对应消息")
|
||||
return None, None
|
||||
if len(target_message_text) > 15:
|
||||
target_message_text = target_message_text[:15] + "..."
|
||||
|
||||
|
||||
user_info: UserInfo = UserInfo(
|
||||
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
|
||||
user_id=user_id,
|
||||
user_nickname=user_name,
|
||||
user_cardname=user_cardname,
|
||||
)
|
||||
|
||||
|
||||
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
|
||||
await event_manager.trigger_event(
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id",""),
|
||||
emoji_id=like_emoji_id
|
||||
)
|
||||
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]")
|
||||
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
|
||||
permission_group=PLUGIN_NAME,
|
||||
group_id=group_id,
|
||||
user_id=user_id,
|
||||
message_id=raw_message.get("message_id", ""),
|
||||
emoji_id=like_emoji_id,
|
||||
)
|
||||
seg_data = Seg(
|
||||
type="text",
|
||||
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
|
||||
)
|
||||
return seg_data, user_info
|
||||
|
||||
|
||||
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
|
||||
if not group_id:
|
||||
logger.error("群ID不能为空,无法处理禁言通知")
|
||||
|
||||
@@ -45,12 +45,12 @@ async def check_timeout_response() -> None:
|
||||
while True:
|
||||
cleaned_message_count: int = 0
|
||||
now_time = time.time()
|
||||
|
||||
|
||||
# 获取心跳间隔配置
|
||||
heartbeat_interval = 30 # 默认值
|
||||
if plugin_config:
|
||||
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
|
||||
|
||||
|
||||
for echo_id, response_time in list(response_time_dict.items()):
|
||||
if now_time - response_time > heartbeat_interval:
|
||||
cleaned_message_count += 1
|
||||
|
||||
@@ -297,9 +297,9 @@ class SendHandler:
|
||||
|
||||
try:
|
||||
# 检查是否为缓冲消息ID(格式:buffered-{original_id}-{timestamp})
|
||||
if id.startswith('buffered-'):
|
||||
if id.startswith("buffered-"):
|
||||
# 从缓冲消息ID中提取原始消息ID
|
||||
original_id = id.split('-')[1]
|
||||
original_id = id.split("-")[1]
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
|
||||
else:
|
||||
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
|
||||
@@ -363,7 +363,7 @@ class SendHandler:
|
||||
use_tts = False
|
||||
if self.plugin_config:
|
||||
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
|
||||
|
||||
|
||||
if not use_tts:
|
||||
logger.warning("未启用语音消息处理")
|
||||
return {}
|
||||
|
||||
@@ -18,7 +18,9 @@ class WebSocketManager:
|
||||
self.max_reconnect_attempts = 10 # 最大重连次数
|
||||
self.plugin_config = None
|
||||
|
||||
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
|
||||
async def start_connection(
|
||||
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
|
||||
) -> None:
|
||||
"""根据配置启动 WebSocket 连接"""
|
||||
self.plugin_config = plugin_config
|
||||
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
|
||||
@@ -72,9 +74,7 @@ class WebSocketManager:
|
||||
# 如果配置了访问令牌,添加到请求头
|
||||
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
|
||||
if access_token:
|
||||
connect_kwargs["additional_headers"] = {
|
||||
"Authorization": f"Bearer {access_token}"
|
||||
}
|
||||
connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
|
||||
logger.info("已添加访问令牌到连接请求头")
|
||||
|
||||
async with Server.connect(url, **connect_kwargs) as websocket:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Base search engine interface
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Any
|
||||
|
||||
@@ -9,20 +10,20 @@ class BaseSearchEngine(ABC):
|
||||
"""
|
||||
搜索引擎基类
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
执行搜索
|
||||
|
||||
|
||||
Args:
|
||||
args: 搜索参数,包含 query、num_results、time_range 等
|
||||
|
||||
|
||||
Returns:
|
||||
搜索结果列表,每个结果包含 title、url、snippet、provider 字段
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def is_available(self) -> bool:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Bing search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
import random
|
||||
@@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Bing搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self.session = requests.Session()
|
||||
self.session.headers = HEADERS
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Bing搜索引擎是否可用"""
|
||||
return True # Bing是免费搜索引擎,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Bing搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
|
||||
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(self._search_sync, query, num_results, time_range)
|
||||
@@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
except Exception as e:
|
||||
logger.error(f"Bing 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
|
||||
"""同步执行Bing搜索"""
|
||||
if not keyword:
|
||||
return []
|
||||
|
||||
list_result = []
|
||||
|
||||
|
||||
# 构建搜索URL
|
||||
search_url = bing_search_url + keyword
|
||||
|
||||
|
||||
# 如果指定了时间范围,添加时间过滤参数
|
||||
if time_range == "week":
|
||||
search_url += "&qft=+filterui:date-range-7"
|
||||
@@ -181,34 +182,29 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
# 尝试提取搜索结果
|
||||
# 方法1: 查找标准的搜索结果容器
|
||||
results = root.select("ol#b_results li.b_algo")
|
||||
|
||||
|
||||
if results:
|
||||
for _rank, result in enumerate(results, 1):
|
||||
# 提取标题和链接
|
||||
title_link = result.select_one("h2 a")
|
||||
if not title_link:
|
||||
continue
|
||||
|
||||
|
||||
title = title_link.get_text().strip()
|
||||
url = title_link.get("href", "")
|
||||
|
||||
|
||||
# 提取摘要
|
||||
abstract = ""
|
||||
abstract_elem = result.select_one("div.b_caption p")
|
||||
if abstract_elem:
|
||||
abstract = abstract_elem.get_text().strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10: # 限制结果数量
|
||||
break
|
||||
|
||||
@@ -216,22 +212,34 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
if not list_data:
|
||||
# 查找所有可能的搜索结果链接
|
||||
all_links = root.find_all("a")
|
||||
|
||||
|
||||
for link in all_links:
|
||||
href = link.get("href", "")
|
||||
text = link.get_text().strip()
|
||||
|
||||
|
||||
# 过滤有效的搜索结果链接
|
||||
if (href and text and len(text) > 10
|
||||
if (
|
||||
href
|
||||
and text
|
||||
and len(text) > 10
|
||||
and not href.startswith("javascript:")
|
||||
and not href.startswith("#")
|
||||
and "http" in href
|
||||
and not any(x in href for x in [
|
||||
"bing.com/search", "bing.com/images", "bing.com/videos",
|
||||
"bing.com/maps", "bing.com/news", "login", "account",
|
||||
"microsoft", "javascript"
|
||||
])):
|
||||
|
||||
and not any(
|
||||
x in href
|
||||
for x in [
|
||||
"bing.com/search",
|
||||
"bing.com/images",
|
||||
"bing.com/videos",
|
||||
"bing.com/maps",
|
||||
"bing.com/news",
|
||||
"login",
|
||||
"account",
|
||||
"microsoft",
|
||||
"javascript",
|
||||
]
|
||||
)
|
||||
):
|
||||
# 尝试获取摘要
|
||||
abstract = ""
|
||||
parent = link.parent
|
||||
@@ -239,18 +247,13 @@ class BingSearchEngine(BaseSearchEngine):
|
||||
full_text = parent.get_text().strip()
|
||||
if len(full_text) > len(text):
|
||||
abstract = full_text.replace(text, "", 1).strip()
|
||||
|
||||
|
||||
# 限制摘要长度
|
||||
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
|
||||
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
|
||||
|
||||
list_data.append({
|
||||
"title": text,
|
||||
"url": href,
|
||||
"snippet": abstract,
|
||||
"provider": "Bing"
|
||||
})
|
||||
|
||||
|
||||
list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
|
||||
|
||||
if len(list_data) >= 10:
|
||||
break
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DuckDuckGo search engine implementation
|
||||
"""
|
||||
|
||||
from typing import Dict, List, Any
|
||||
from asyncddgs import aDDGS
|
||||
|
||||
@@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
DuckDuckGo搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查DuckDuckGo搜索引擎是否可用"""
|
||||
return True # DuckDuckGo不需要API密钥,总是可用
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行DuckDuckGo搜索"""
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
|
||||
|
||||
try:
|
||||
async with aDDGS() as ddgs:
|
||||
search_response = await ddgs.text(query, max_results=num_results)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": r.get("title"),
|
||||
"url": r.get("href"),
|
||||
"snippet": r.get("body"),
|
||||
"provider": "DuckDuckGo"
|
||||
}
|
||||
{"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
|
||||
for r in search_response
|
||||
]
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Exa search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from datetime import datetime, timedelta
|
||||
@@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Exa搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa"
|
||||
)
|
||||
|
||||
self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Exa搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Exa搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if time_range != "any":
|
||||
today = datetime.now()
|
||||
start_date = today - timedelta(days=7 if time_range == "week" else 30)
|
||||
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d')
|
||||
exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
|
||||
|
||||
try:
|
||||
# 使用API密钥管理器获取下一个客户端
|
||||
@@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
if not exa_client:
|
||||
logger.error("无法获取Exa客户端")
|
||||
return []
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
return [
|
||||
{
|
||||
"title": res.title,
|
||||
"url": res.url,
|
||||
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'),
|
||||
"provider": "Exa"
|
||||
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
|
||||
"provider": "Exa",
|
||||
}
|
||||
for res in search_response.results
|
||||
]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Tavily search engine implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Dict, List, Any
|
||||
@@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
"""
|
||||
Tavily搜索引擎实现
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self):
|
||||
self._initialize_clients()
|
||||
|
||||
|
||||
def _initialize_clients(self):
|
||||
"""初始化Tavily客户端"""
|
||||
# 从主配置文件读取API密钥
|
||||
tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None)
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
tavily_api_keys,
|
||||
lambda key: TavilyClient(api_key=key),
|
||||
"Tavily"
|
||||
tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
|
||||
)
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查Tavily搜索引擎是否可用"""
|
||||
return self.api_manager.is_available()
|
||||
|
||||
|
||||
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
"""执行Tavily搜索"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
|
||||
query = args["query"]
|
||||
num_results = args.get("num_results", 3)
|
||||
time_range = args.get("time_range", "any")
|
||||
@@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine):
|
||||
if not tavily_client:
|
||||
logger.error("无法获取Tavily客户端")
|
||||
return []
|
||||
|
||||
|
||||
# 构建Tavily搜索参数
|
||||
search_params = {
|
||||
"query": query,
|
||||
"max_results": num_results,
|
||||
"search_depth": "basic",
|
||||
"include_answer": False,
|
||||
"include_raw_content": False
|
||||
"include_raw_content": False,
|
||||
}
|
||||
|
||||
|
||||
# 根据时间范围调整搜索参数
|
||||
if time_range == "week":
|
||||
search_params["days"] = 7
|
||||
elif time_range == "month":
|
||||
search_params["days"] = 30
|
||||
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(tavily_client.search, **search_params)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
|
||||
results = []
|
||||
if search_response and "results" in search_response:
|
||||
for res in search_response["results"]:
|
||||
results.append({
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily"
|
||||
})
|
||||
|
||||
results.append(
|
||||
{
|
||||
"title": res.get("title", "无标题"),
|
||||
"url": res.get("url", ""),
|
||||
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
|
||||
"provider": "Tavily",
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tavily 搜索失败: {e}")
|
||||
return []
|
||||
|
||||
@@ -3,15 +3,10 @@ Web Search Tool Plugin
|
||||
|
||||
一个功能强大的网络搜索和URL解析插件,支持多种搜索引擎和解析策略。
|
||||
"""
|
||||
|
||||
from typing import List, Tuple, Type
|
||||
|
||||
from src.plugin_system import (
|
||||
BasePlugin,
|
||||
register_plugin,
|
||||
ComponentInfo,
|
||||
ConfigField,
|
||||
PythonDependency
|
||||
)
|
||||
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
|
||||
from src.plugin_system.apis import config_api
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin")
|
||||
class WEBSEARCHPLUGIN(BasePlugin):
|
||||
"""
|
||||
网络搜索工具插件
|
||||
|
||||
|
||||
提供网络搜索和URL解析功能,支持多种搜索引擎:
|
||||
- Exa (需要API密钥)
|
||||
- Tavily (需要API密钥)
|
||||
@@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
plugin_name: str = "web_search_tool" # 内部标识符
|
||||
enable_plugin: bool = True
|
||||
dependencies: List[str] = [] # 插件依赖列表
|
||||
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""初始化插件,立即加载所有搜索引擎"""
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
|
||||
# 立即初始化所有搜索引擎,触发API密钥管理器的日志输出
|
||||
logger.info("🚀 正在初始化所有搜索引擎...")
|
||||
try:
|
||||
@@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
from .engines.tavily_engine import TavilySearchEngine
|
||||
from .engines.ddg_engine import DDGSearchEngine
|
||||
from .engines.bing_engine import BingSearchEngine
|
||||
|
||||
|
||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||
exa_engine = ExaSearchEngine()
|
||||
tavily_engine = TavilySearchEngine()
|
||||
ddg_engine = DDGSearchEngine()
|
||||
bing_engine = BingSearchEngine()
|
||||
|
||||
|
||||
# 报告每个引擎的状态
|
||||
engines_status = {
|
||||
"Exa": exa_engine.is_available(),
|
||||
"Tavily": tavily_engine.is_available(),
|
||||
"DuckDuckGo": ddg_engine.is_available(),
|
||||
"Bing": bing_engine.is_available()
|
||||
"Bing": bing_engine.is_available(),
|
||||
}
|
||||
|
||||
|
||||
available_engines = [name for name, available in engines_status.items() if available]
|
||||
unavailable_engines = [name for name, available in engines_status.items() if not available]
|
||||
|
||||
|
||||
if available_engines:
|
||||
logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}")
|
||||
if unavailable_engines:
|
||||
logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
|
||||
|
||||
|
||||
# Python包依赖列表
|
||||
python_dependencies: List[PythonDependency] = [
|
||||
PythonDependency(
|
||||
package_name="asyncddgs",
|
||||
description="异步DuckDuckGo搜索库",
|
||||
optional=False
|
||||
),
|
||||
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
|
||||
PythonDependency(
|
||||
package_name="exa_py",
|
||||
description="Exa搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="tavily",
|
||||
install_name="tavily-python", # 安装时使用这个名称
|
||||
description="Tavily搜索API客户端库",
|
||||
optional=True # 如果没有API密钥,这个是可选的
|
||||
optional=True, # 如果没有API密钥,这个是可选的
|
||||
),
|
||||
PythonDependency(
|
||||
package_name="httpx",
|
||||
version=">=0.20.0",
|
||||
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
|
||||
description="支持SOCKS代理的HTTP客户端库",
|
||||
optional=False
|
||||
)
|
||||
optional=False,
|
||||
),
|
||||
]
|
||||
config_file_name: str = "config.toml" # 配置文件名
|
||||
|
||||
# 配置节描述
|
||||
config_section_descriptions = {
|
||||
"plugin": "插件基本信息",
|
||||
"proxy": "链接本地解析代理配置"
|
||||
}
|
||||
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
|
||||
|
||||
# 配置Schema定义
|
||||
# 注意:EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
|
||||
@@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
||||
},
|
||||
"proxy": {
|
||||
"http_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTP代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"https_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
type=str, default=None, description="HTTPS代理地址,格式如: http://proxy.example.com:8080"
|
||||
),
|
||||
"socks5_proxy": ConfigField(
|
||||
type=str,
|
||||
default=None,
|
||||
description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
type=str, default=None, description="SOCKS5代理地址,格式如: socks5://proxy.example.com:1080"
|
||||
),
|
||||
"enable_proxy": ConfigField(
|
||||
type=bool,
|
||||
default=False,
|
||||
description="是否启用代理"
|
||||
)
|
||||
"enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
|
||||
"""
|
||||
获取插件组件列表
|
||||
|
||||
|
||||
Returns:
|
||||
组件信息和类型的元组列表
|
||||
"""
|
||||
enable_tool = []
|
||||
|
||||
|
||||
# 从主配置文件读取组件启用配置
|
||||
if config_api.get_global_config("web_search.enable_web_search_tool", True):
|
||||
enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool))
|
||||
|
||||
|
||||
if config_api.get_global_config("web_search.enable_url_tool", True):
|
||||
enable_tool.append((URLParserTool.get_tool_info(), URLParserTool))
|
||||
|
||||
|
||||
return enable_tool
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL parser tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import functools
|
||||
from typing import Any, Dict
|
||||
@@ -24,17 +25,18 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
一个用于解析和总结一个或多个网页URL内容的工具。
|
||||
"""
|
||||
|
||||
name: str = "parse_url"
|
||||
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]' 或 '帮我总结一下这些文章'"
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("urls", ToolParamType.STRING, "要理解的网站", True, None),
|
||||
]
|
||||
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
self._initialize_exa_clients()
|
||||
|
||||
|
||||
def _initialize_exa_clients(self):
|
||||
"""初始化Exa客户端"""
|
||||
# 优先从主配置文件读取,如果没有则从插件配置文件读取
|
||||
@@ -42,12 +44,10 @@ class URLParserTool(BaseTool):
|
||||
if exa_api_keys is None:
|
||||
# 从插件配置文件读取
|
||||
exa_api_keys = self.get_config("exa.api_keys", [])
|
||||
|
||||
|
||||
# 创建API密钥管理器
|
||||
self.api_manager = create_api_key_manager_from_config(
|
||||
exa_api_keys,
|
||||
lambda key: Exa(api_key=key),
|
||||
"Exa URL Parser"
|
||||
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
|
||||
)
|
||||
|
||||
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
|
||||
@@ -58,12 +58,12 @@ class URLParserTool(BaseTool):
|
||||
# 读取代理配置
|
||||
enable_proxy = self.get_config("proxy.enable_proxy", False)
|
||||
proxies = None
|
||||
|
||||
|
||||
if enable_proxy:
|
||||
socks5_proxy = self.get_config("proxy.socks5_proxy", None)
|
||||
http_proxy = self.get_config("proxy.http_proxy", None)
|
||||
https_proxy = self.get_config("proxy.https_proxy", None)
|
||||
|
||||
|
||||
# 优先使用SOCKS5代理(全协议代理)
|
||||
if socks5_proxy:
|
||||
proxies = socks5_proxy
|
||||
@@ -75,17 +75,17 @@ class URLParserTool(BaseTool):
|
||||
if https_proxy:
|
||||
proxies["https://"] = https_proxy
|
||||
logger.info(f"使用HTTP/HTTPS代理配置: {proxies}")
|
||||
|
||||
|
||||
client_kwargs = {"timeout": 15.0, "follow_redirects": True}
|
||||
if proxies:
|
||||
client_kwargs["proxies"] = proxies
|
||||
|
||||
|
||||
async with httpx.AsyncClient(**client_kwargs) as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.text, "html.parser")
|
||||
|
||||
|
||||
title = soup.title.string if soup.title else "无标题"
|
||||
for script in soup(["script", "style"]):
|
||||
script.extract()
|
||||
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "未配置LLM模型"}
|
||||
|
||||
success, summary, reasoning, model_name = await llm_api.generate_with_model(
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000
|
||||
)
|
||||
prompt=summary_prompt,
|
||||
model_config=model_config,
|
||||
request_type="story.generate",
|
||||
temperature=0.3,
|
||||
max_tokens=1000,
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.info(f"生成摘要失败: {summary}")
|
||||
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
|
||||
|
||||
logger.info(f"成功生成摘要内容:'{summary}'")
|
||||
|
||||
return {
|
||||
"title": title,
|
||||
"url": url,
|
||||
"snippet": summary,
|
||||
"source": "local"
|
||||
}
|
||||
return {"title": title, "url": url, "snippet": summary, "source": "local"}
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
|
||||
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
|
||||
"""
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -144,7 +140,7 @@ class URLParserTool(BaseTool):
|
||||
if cached_result:
|
||||
logger.info(f"缓存命中: {self.name} -> {function_args}")
|
||||
return cached_result
|
||||
|
||||
|
||||
urls_input = function_args.get("urls")
|
||||
if not urls_input:
|
||||
return {"error": "URL列表不能为空。"}
|
||||
@@ -158,14 +154,14 @@ class URLParserTool(BaseTool):
|
||||
valid_urls = validate_urls(urls)
|
||||
if not valid_urls:
|
||||
return {"error": "未找到有效的URL。"}
|
||||
|
||||
|
||||
urls = valid_urls
|
||||
logger.info(f"准备解析 {len(urls)} 个URL: {urls}")
|
||||
|
||||
successful_results = []
|
||||
error_messages = []
|
||||
urls_to_retry_locally = []
|
||||
|
||||
|
||||
# 步骤 1: 尝试使用 Exa API 进行解析
|
||||
contents_response = None
|
||||
if self.api_manager.is_available():
|
||||
@@ -182,41 +178,45 @@ class URLParserTool(BaseTool):
|
||||
contents_response = await loop.run_in_executor(None, func)
|
||||
except Exception as e:
|
||||
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
|
||||
contents_response = None # 确保异常后为None
|
||||
contents_response = None # 确保异常后为None
|
||||
|
||||
# 步骤 2: 处理Exa的响应
|
||||
if contents_response and hasattr(contents_response, 'statuses'):
|
||||
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {}
|
||||
if contents_response and hasattr(contents_response, "statuses"):
|
||||
results_map = (
|
||||
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
|
||||
)
|
||||
if contents_response.statuses:
|
||||
for status in contents_response.statuses:
|
||||
if status.status == 'success':
|
||||
if status.status == "success":
|
||||
res = results_map.get(status.id)
|
||||
if res:
|
||||
summary = getattr(res, 'summary', '')
|
||||
highlights = " ".join(getattr(res, 'highlights', []))
|
||||
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else ''
|
||||
snippet = summary or highlights or text_snippet or '无摘要'
|
||||
|
||||
successful_results.append({
|
||||
"title": getattr(res, 'title', '无标题'),
|
||||
"url": getattr(res, 'url', status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa"
|
||||
})
|
||||
summary = getattr(res, "summary", "")
|
||||
highlights = " ".join(getattr(res, "highlights", []))
|
||||
text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
|
||||
snippet = summary or highlights or text_snippet or "无摘要"
|
||||
|
||||
successful_results.append(
|
||||
{
|
||||
"title": getattr(res, "title", "无标题"),
|
||||
"url": getattr(res, "url", status.id),
|
||||
"snippet": snippet,
|
||||
"source": "exa",
|
||||
}
|
||||
)
|
||||
else:
|
||||
error_tag = getattr(status, 'error', '未知错误')
|
||||
error_tag = getattr(status, "error", "未知错误")
|
||||
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
|
||||
urls_to_retry_locally.append(status.id)
|
||||
else:
|
||||
# 如果Exa未配置、API调用失败或返回无效响应,则所有URL都进入本地重试
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results])
|
||||
urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
|
||||
|
||||
# 步骤 3: 对失败的URL进行本地解析
|
||||
if urls_to_retry_locally:
|
||||
logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}")
|
||||
local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally]
|
||||
local_results = await asyncio.gather(*local_tasks)
|
||||
|
||||
|
||||
for i, res in enumerate(local_results):
|
||||
url = urls_to_retry_locally[i]
|
||||
if "error" in res:
|
||||
@@ -228,13 +228,9 @@ class URLParserTool(BaseTool):
|
||||
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
|
||||
|
||||
formatted_content = format_url_parse_results(successful_results)
|
||||
|
||||
result = {
|
||||
"type": "url_parse_result",
|
||||
"content": formatted_content,
|
||||
"errors": error_messages
|
||||
}
|
||||
|
||||
|
||||
result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Web search tool implementation
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, List
|
||||
|
||||
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
|
||||
"""
|
||||
网络搜索工具
|
||||
"""
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
description: str = (
|
||||
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
)
|
||||
available_for_llm: bool = True
|
||||
parameters = [
|
||||
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
|
||||
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量,默认为5。", False, None),
|
||||
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。", False, ["any", "week", "month"])
|
||||
] # type: ignore
|
||||
(
|
||||
"time_range",
|
||||
ToolParamType.STRING,
|
||||
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'。",
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, plugin_config=None):
|
||||
super().__init__(plugin_config)
|
||||
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
|
||||
"exa": ExaSearchEngine(),
|
||||
"tavily": TavilySearchEngine(),
|
||||
"ddg": DDGSearchEngine(),
|
||||
"bing": BingSearchEngine()
|
||||
"bing": BingSearchEngine(),
|
||||
}
|
||||
|
||||
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
# 获取当前文件路径用于缓存键
|
||||
import os
|
||||
|
||||
current_file_path = os.path.abspath(__file__)
|
||||
|
||||
# 检查缓存
|
||||
@@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 读取搜索配置
|
||||
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
|
||||
search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
|
||||
|
||||
|
||||
logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'")
|
||||
|
||||
# 根据策略执行搜索
|
||||
@@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool):
|
||||
result = await self._execute_fallback_search(function_args, enabled_engines)
|
||||
else: # single
|
||||
result = await self._execute_single_search(function_args, enabled_engines)
|
||||
|
||||
|
||||
# 保存到缓存
|
||||
if "error" not in result:
|
||||
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_parallel_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if engine and engine.is_available():
|
||||
@@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
try:
|
||||
search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True)
|
||||
|
||||
|
||||
all_results = []
|
||||
for result in search_results_lists:
|
||||
if isinstance(result, list):
|
||||
@@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool):
|
||||
# 去重并格式化
|
||||
unique_results = deduplicate_results(all_results)
|
||||
formatted_content = format_search_results(unique_results)
|
||||
|
||||
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
@@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool):
|
||||
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
|
||||
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
|
||||
|
||||
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
async def _execute_fallback_search(
|
||||
self, function_args: Dict[str, Any], enabled_engines: List[str]
|
||||
) -> Dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
|
||||
if results: # 如果有结果,直接返回
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}")
|
||||
continue
|
||||
|
||||
|
||||
return {"error": "所有搜索引擎都失败了。"}
|
||||
|
||||
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
|
||||
@@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool):
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
continue
|
||||
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
|
||||
|
||||
|
||||
return {"error": "没有可用的搜索引擎。"}
|
||||
|
||||
@@ -1,24 +1,25 @@
|
||||
"""
|
||||
API密钥管理器,提供轮询机制
|
||||
"""
|
||||
|
||||
import itertools
|
||||
from typing import List, Optional, TypeVar, Generic, Callable
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("api_key_manager")
|
||||
|
||||
T = TypeVar('T')
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class APIKeyManager(Generic[T]):
|
||||
"""
|
||||
API密钥管理器,支持轮询机制
|
||||
"""
|
||||
|
||||
|
||||
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
|
||||
"""
|
||||
初始化API密钥管理器
|
||||
|
||||
|
||||
Args:
|
||||
api_keys: API密钥列表
|
||||
client_factory: 客户端工厂函数,接受API密钥参数并返回客户端实例
|
||||
@@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]):
|
||||
self.service_name = service_name
|
||||
self.clients: List[T] = []
|
||||
self.client_cycle: Optional[itertools.cycle] = None
|
||||
|
||||
|
||||
if api_keys:
|
||||
# 过滤有效的API密钥,排除None、空字符串、"None"字符串等
|
||||
valid_keys = []
|
||||
for key in api_keys:
|
||||
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
|
||||
valid_keys.append(key.strip())
|
||||
|
||||
|
||||
if valid_keys:
|
||||
try:
|
||||
self.clients = [client_factory(key) for key in valid_keys]
|
||||
@@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]):
|
||||
logger.warning(f"⚠️ {service_name} API Keys 配置无效(包含None或空值),{service_name} 功能将不可用")
|
||||
else:
|
||||
logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用")
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查是否有可用的客户端"""
|
||||
return bool(self.clients and self.client_cycle)
|
||||
|
||||
|
||||
def get_next_client(self) -> Optional[T]:
|
||||
"""获取下一个客户端(轮询)"""
|
||||
if not self.is_available():
|
||||
return None
|
||||
return next(self.client_cycle)
|
||||
|
||||
|
||||
def get_client_count(self) -> int:
|
||||
"""获取可用客户端数量"""
|
||||
return len(self.clients)
|
||||
|
||||
|
||||
def create_api_key_manager_from_config(
|
||||
config_keys: Optional[List[str]],
|
||||
client_factory: Callable[[str], T],
|
||||
service_name: str
|
||||
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
|
||||
) -> APIKeyManager[T]:
|
||||
"""
|
||||
从配置创建API密钥管理器的便捷函数
|
||||
|
||||
|
||||
Args:
|
||||
config_keys: 从配置读取的API密钥列表
|
||||
client_factory: 客户端工厂函数
|
||||
service_name: 服务名称
|
||||
|
||||
|
||||
Returns:
|
||||
API密钥管理器实例
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Formatters for web search results
|
||||
"""
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
|
||||
@@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
|
||||
|
||||
formatted_string = "根据网络搜索结果:\n\n"
|
||||
for i, res in enumerate(results, 1):
|
||||
title = res.get("title", '无标题')
|
||||
url = res.get("url", '#')
|
||||
snippet = res.get("snippet", '无摘要')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
provider = res.get("provider", "未知来源")
|
||||
|
||||
|
||||
formatted_string += f"{i}. **{title}** (来自: {provider})\n"
|
||||
formatted_string += f" - 摘要: {snippet}\n"
|
||||
formatted_string += f" - 来源: {url}\n\n"
|
||||
|
||||
|
||||
return formatted_string
|
||||
|
||||
|
||||
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
|
||||
"""
|
||||
formatted_parts = []
|
||||
for res in results:
|
||||
title = res.get('title', '无标题')
|
||||
url = res.get('url', '#')
|
||||
snippet = res.get('snippet', '无摘要')
|
||||
source = res.get('source', '未知')
|
||||
title = res.get("title", "无标题")
|
||||
url = res.get("url", "#")
|
||||
snippet = res.get("snippet", "无摘要")
|
||||
source = res.get("source", "未知")
|
||||
|
||||
formatted_string = f"**{title}**\n"
|
||||
formatted_string += f"**内容摘要**:\n{snippet}\n"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
URL processing utilities
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import List
|
||||
|
||||
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
if isinstance(urls_input, str):
|
||||
# 如果是字符串,尝试解析为URL列表
|
||||
# 提取所有HTTP/HTTPS URL
|
||||
url_pattern = r'https?://[^\s\],]+'
|
||||
url_pattern = r"https?://[^\s\],]+"
|
||||
urls = re.findall(url_pattern, urls_input)
|
||||
if not urls:
|
||||
# 如果没有找到标准URL,将整个字符串作为单个URL
|
||||
if urls_input.strip().startswith(('http://', 'https://')):
|
||||
if urls_input.strip().startswith(("http://", "https://")):
|
||||
urls = [urls_input.strip()]
|
||||
else:
|
||||
return []
|
||||
@@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
|
||||
urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()]
|
||||
else:
|
||||
return []
|
||||
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
|
||||
"""
|
||||
valid_urls = []
|
||||
for url in urls:
|
||||
if url.startswith(('http://', 'https://')):
|
||||
if url.startswith(("http://", "https://")):
|
||||
valid_urls.append(url)
|
||||
return valid_urls
|
||||
|
||||
@@ -21,8 +21,18 @@ logger = get_logger(__name__)
|
||||
|
||||
# ============================ AsyncTask ============================
|
||||
|
||||
|
||||
class ReminderTask(AsyncTask):
|
||||
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str):
|
||||
def __init__(
|
||||
self,
|
||||
delay: float,
|
||||
stream_id: str,
|
||||
is_group: bool,
|
||||
target_user_id: str,
|
||||
target_user_name: str,
|
||||
event_details: str,
|
||||
creator_name: str,
|
||||
):
|
||||
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
|
||||
self.delay = delay
|
||||
self.stream_id = stream_id
|
||||
@@ -37,22 +47,22 @@ class ReminderTask(AsyncTask):
|
||||
if self.delay > 0:
|
||||
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
|
||||
await asyncio.sleep(self.delay)
|
||||
|
||||
|
||||
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
|
||||
|
||||
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
|
||||
|
||||
if self.is_group:
|
||||
# 在群聊中,构造 @ 消息段并发送
|
||||
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id
|
||||
group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
|
||||
message_payload = [
|
||||
{"type": "at", "data": {"qq": self.target_user_id}},
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}}
|
||||
{"type": "text", "data": {"text": f" {reminder_text}"}},
|
||||
]
|
||||
await send_api.adapter_command_to_stream(
|
||||
action="send_group_msg",
|
||||
params={"group_id": group_id, "message": message_payload},
|
||||
stream_id=self.stream_id
|
||||
stream_id=self.stream_id,
|
||||
)
|
||||
else:
|
||||
# 在私聊中,直接发送文本
|
||||
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
|
||||
|
||||
# =============================== Actions ===============================
|
||||
|
||||
|
||||
class RemindAction(BaseAction):
|
||||
"""一个能从对话中智能识别并设置定时提醒的动作。"""
|
||||
|
||||
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
|
||||
action_parameters = {
|
||||
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
|
||||
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后'或'明天下午3点'",
|
||||
"event_details": "需要提醒的具体事件内容"
|
||||
"event_details": "需要提醒的具体事件内容",
|
||||
}
|
||||
action_require = [
|
||||
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
|
||||
"适用于包含明确时间信息和事件描述的对话",
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'"
|
||||
"例如:'10分钟后提醒我收快递'、'明天早上九点喊一下李四参加晨会'",
|
||||
]
|
||||
|
||||
async def execute(self) -> Tuple[bool, str]:
|
||||
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
|
||||
event_details = self.action_data.get("event_details")
|
||||
|
||||
if not all([user_name, remind_time_str, event_details]):
|
||||
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
|
||||
missing_params = [
|
||||
p
|
||||
for p, v in {
|
||||
"user_name": user_name,
|
||||
"remind_time": remind_time_str,
|
||||
"event_details": event_details,
|
||||
}.items()
|
||||
if not v
|
||||
]
|
||||
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
|
||||
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
|
||||
return False, error_msg
|
||||
@@ -135,9 +154,9 @@ class RemindAction(BaseAction):
|
||||
person_manager = get_person_info_manager()
|
||||
user_id_to_remind = None
|
||||
user_name_to_remind = ""
|
||||
|
||||
|
||||
assert isinstance(user_name, str)
|
||||
|
||||
|
||||
if user_name.strip() in ["自己", "我", "me"]:
|
||||
user_id_to_remind = self.user_id
|
||||
user_name_to_remind = self.user_nickname
|
||||
@@ -154,7 +173,7 @@ class RemindAction(BaseAction):
|
||||
try:
|
||||
assert user_id_to_remind is not None
|
||||
assert event_details is not None
|
||||
|
||||
|
||||
reminder_task = ReminderTask(
|
||||
delay=delay_seconds,
|
||||
stream_id=self.chat_id,
|
||||
@@ -162,14 +181,14 @@ class RemindAction(BaseAction):
|
||||
target_user_id=str(user_id_to_remind),
|
||||
target_user_name=str(user_name_to_remind),
|
||||
event_details=str(event_details),
|
||||
creator_name=str(self.user_nickname)
|
||||
creator_name=str(self.user_nickname),
|
||||
)
|
||||
await async_task_manager.add_task(reminder_task)
|
||||
|
||||
|
||||
# 4. 发送确认消息
|
||||
confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}:\n{event_details}"
|
||||
await self.send_text(confirm_message)
|
||||
|
||||
|
||||
return True, "提醒设置成功"
|
||||
except Exception as e:
|
||||
logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
|
||||
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
|
||||
|
||||
# =============================== Plugin ===============================
|
||||
|
||||
|
||||
@register_plugin
|
||||
class ReminderPlugin(BasePlugin):
|
||||
"""一个能从对话中智能识别并设置定时提醒的插件。"""
|
||||
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
|
||||
|
||||
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
|
||||
"""注册插件的所有功能组件。"""
|
||||
return [
|
||||
(RemindAction.get_action_info(), RemindAction)
|
||||
]
|
||||
return [(RemindAction.get_action_info(), RemindAction)]
|
||||
|
||||
Reference in New Issue
Block a user