ruff; treat private chat messages as mentioning the bot

Author: Windpicker-owo
Date: 2025-09-20 22:34:22 +08:00
parent 006f9130b9
commit 444f1ca315
76 changed files with 1066 additions and 882 deletions
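Most of the diff below is mechanical reformatting of the kind produced by `ruff format` (with `ruff check --fix` for the lint fixes): long calls are wrapped one argument per line, single quotes become double quotes, and trailing commas are added. A minimal before/after sketch of that pattern, loosely modeled on the `ConfigField` entries changed below (the call here is a plain `dict()` so the snippet runs on its own):

```python
# Before: one long call with single quotes (the style ruff's formatter rewrites)
server_cfg = dict(mode='reverse', description='连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)', choices=['reverse', 'forward'])

# After `ruff format`: wrapped arguments, double quotes, trailing comma
server_cfg = dict(
    mode="reverse",
    description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
    choices=["reverse", "forward"],
)
```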

View File

@@ -64,50 +64,50 @@ class AtAction(BaseAction):
# 使用回复器生成艾特回复,而不是直接发送命令
from src.chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import get_chat_manager
# 获取当前聊天流
chat_manager = get_chat_manager()
chat_stream = self.chat_stream or chat_manager.get_stream(self.chat_id)
if not chat_stream:
logger.error(f"找不到聊天流: {self.chat_stream}")
return False, "聊天流不存在"
# 创建回复器实例
replyer = DefaultReplyer(chat_stream)
# 构建回复对象,将艾特消息作为回复目标
reply_to = f"{user_name}:{at_message}"
extra_info = f"你需要艾特用户 {user_name} 并回复他们说: {at_message}"
# 使用回复器生成回复
success, llm_response, prompt = await replyer.generate_reply_with_context(
reply_to=reply_to,
extra_info=extra_info,
enable_tool=False, # 艾特回复通常不需要工具调用
from_plugin=False
from_plugin=False,
)
if success and llm_response:
# 获取生成的回复内容
reply_content = llm_response.get("content", "")
if reply_content:
# 获取用户QQ号发送真正的艾特消息
user_id = user_info.get("user_id")
# 发送真正的艾特命令,使用回复器生成的智能内容
await self.send_command(
"SEND_AT_MESSAGE",
args={"qq_id": user_id, "text": reply_content},
display_message=f"艾特用户 {user_name} 并发送智能回复: {reply_content}",
)
await self.store_action_info(
action_build_into_prompt=True,
action_prompt_display=f"执行了艾特用户动作:艾特用户 {user_name} 并发送智能回复: {reply_content}",
action_done=True,
)
logger.info(f"成功通过回复器生成智能内容并发送真正的艾特消息给 {user_name}: {reply_content}")
return True, "智能艾特消息发送成功"
else:
@@ -116,7 +116,7 @@ class AtAction(BaseAction):
else:
logger.error("回复器生成回复失败")
return False, "回复生成失败"
except Exception as e:
logger.error(f"执行艾特用户动作时发生异常: {e}", exc_info=True)
await self.store_action_info(

View File

@@ -70,7 +70,9 @@ class EmojiAction(BaseAction):
# 2. 获取所有有效的表情包对象
emoji_manager = get_emoji_manager()
all_emojis_obj: list[MaiEmoji] = [e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description]
all_emojis_obj: list[MaiEmoji] = [
e for e in emoji_manager.emoji_objects if not e.is_deleted and e.description
]
if not all_emojis_obj:
logger.warning(f"{self.log_prefix} 无法获取任何带有描述的有效表情包")
return False, "无法获取任何带有描述的有效表情包"
@@ -91,12 +93,12 @@ class EmojiAction(BaseAction):
# 4. 准备情感数据和后备列表
emotion_map = {}
all_emojis_data = []
for emoji in all_emojis_obj:
b64 = image_path_to_base64(emoji.full_path)
if not b64:
continue
desc = emoji.description
emotions = emoji.emotion
all_emojis_data.append((b64, desc))
@@ -168,16 +170,18 @@ class EmojiAction(BaseAction):
# 使用模糊匹配来查找最相关的情感标签
matched_key = next((key for key in emotion_map if chosen_emotion in key), None)
if matched_key:
emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
logger.info(f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}")
logger.info(
f"{self.log_prefix} 找到匹配情感 '{chosen_emotion}' (匹配到: '{matched_key}') 的表情包: {emoji_description}"
)
else:
logger.warning(
f"{self.log_prefix} LLM选择的情感 '{chosen_emotion}' 不在可用列表中, 将随机选择一个表情包"
)
emoji_base64, emoji_description = random.choice(all_emojis_data)
elif global_config.emoji.emoji_selection_mode == "description":
# --- 详细描述选择模式 ---
# 获取最近的5条消息内容用于判断
@@ -226,15 +230,23 @@ class EmojiAction(BaseAction):
logger.info(f"{self.log_prefix} LLM选择的描述: {chosen_description}")
# 简单关键词匹配
matched_emoji = next((item for item in all_emojis_data if chosen_description.lower() in item[1].lower() or item[1].lower() in chosen_description.lower()), None)
matched_emoji = next(
(
item
for item in all_emojis_data
if chosen_description.lower() in item[1].lower()
or item[1].lower() in chosen_description.lower()
),
None,
)
# 如果包含匹配失败,尝试关键词匹配
if not matched_emoji:
keywords = ['惊讶', '困惑', '呆滞', '震惊', '', '无语', '', '可爱']
keywords = ["惊讶", "困惑", "呆滞", "震惊", "", "无语", "", "可爱"]
for keyword in keywords:
if keyword in chosen_description:
for item in all_emojis_data:
if any(k in item[1] for k in ['', '', '', '困惑', '无语']):
if any(k in item[1] for k in ["", "", "", "困惑", "无语"]):
matched_emoji = item
break
if matched_emoji:
@@ -255,7 +267,9 @@ class EmojiAction(BaseAction):
if not success:
logger.error(f"{self.log_prefix} 表情包发送失败")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包,但失败了",action_done= False)
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包,但失败了", action_done=False
)
return False, "表情包发送失败"
# 发送成功后,记录到历史
@@ -263,8 +277,10 @@ class EmojiAction(BaseAction):
add_emoji_to_history(self.chat_id, emoji_description)
except Exception as e:
logger.error(f"{self.log_prefix} 添加表情到历史记录时出错: {e}")
await self.store_action_info(action_build_into_prompt = True,action_prompt_display =f"发送了一个表情包",action_done= True)
await self.store_action_info(
action_build_into_prompt=True, action_prompt_display=f"发送了一个表情包", action_done=True
)
return True, f"发送表情包: {emoji_description}"
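For context, the substring-based emotion matching that these reformatted lines preserve works roughly like the following standalone sketch; the `emotion_map` contents here are made up, and the fallback mirrors the `random.choice(all_emojis_data)` path above:

```python
import random

# Hypothetical data in the shape used above: emotion label -> list of (base64, description) tuples
emotion_map = {
    "开心/喜悦": [("<b64-1>", "开心地鼓掌")],
    "无语/呆滞": [("<b64-2>", "一脸无语")],
}

chosen_emotion = "无语"  # what the LLM picked

# Fuzzy match: the chosen label only needs to be a substring of a known key
matched_key = next((key for key in emotion_map if chosen_emotion in key), None)

if matched_key:
    emoji_base64, emoji_description = random.choice(emotion_map[matched_key])
else:
    # Fall back to a random emoji from the full pool (all_emojis_data in the plugin)
    emoji_base64, emoji_description = random.choice(
        [item for items in emotion_map.values() for item in items]
    )

print(matched_key, emoji_description)
```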

View File

@@ -1,4 +1,3 @@
from src.plugin_system import BaseEventHandler
from src.plugin_system.base.base_event import HandlerResult
@@ -1748,6 +1747,7 @@ class SetGroupSignHandler(BaseEventHandler):
logger.error("事件 napcat_set_group_sign 请求失败!")
return HandlerResult(False, False, {"status": "error"})
# ===PERSONAL===
class SetInputStatusHandler(BaseEventHandler):
handler_name: str = "napcat_set_input_status_handler"

View File

@@ -227,7 +227,7 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
await reassembler.start_cleanup_task()
logger.info("开始启动Napcat Adapter")
# 创建单独的异步任务,防止阻塞主线程
asyncio.create_task(self._start_maibot_connection())
asyncio.create_task(napcat_server(self.plugin_config))
@@ -238,10 +238,10 @@ class LauchNapcatAdapterHandler(BaseEventHandler):
"""非阻塞方式启动MaiBot连接等待主服务启动后再连接"""
# 等待一段时间让MaiBot主服务完全启动
await asyncio.sleep(5)
max_attempts = 10
attempt = 0
while attempt < max_attempts:
try:
logger.info(f"尝试连接MaiBot (第{attempt + 1}次)")
@@ -285,7 +285,7 @@ class NapcatAdapterPlugin(BasePlugin):
def enable_plugin(self) -> bool:
"""通过配置文件动态控制插件启用状态"""
# 如果已经通过配置加载了状态,使用配置中的值
if hasattr(self, '_is_enabled'):
if hasattr(self, "_is_enabled"):
return self._is_enabled
# 否则使用默认值(禁用状态)
return False
@@ -308,60 +308,107 @@ class NapcatAdapterPlugin(BasePlugin):
"nickname": ConfigField(type=str, default="", description="昵称配置(目前未使用)"),
},
"napcat_server": {
"mode": ConfigField(type=str, default="reverse", description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)", choices=["reverse", "forward"]),
"mode": ConfigField(
type=str,
default="reverse",
description="连接模式reverse=反向连接(作为服务器), forward=正向连接(作为客户端)",
choices=["reverse", "forward"],
),
"host": ConfigField(type=str, default="localhost", description="主机地址"),
"port": ConfigField(type=int, default=8095, description="端口号"),
"url": ConfigField(type=str, default="", description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)"),
"access_token": ConfigField(type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"),
"url": ConfigField(
type=str,
default="",
description="正向连接时的完整WebSocket URL如 ws://localhost:8080/ws (仅在forward模式下使用)",
),
"access_token": ConfigField(
type=str, default="", description="WebSocket 连接的访问令牌,用于身份验证(可选)"
),
"heartbeat_interval": ConfigField(type=int, default=30, description="心跳间隔时间(按秒计)"),
},
"maibot_server": {
"host": ConfigField(type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"),
"host": ConfigField(
type=str, default="localhost", description="麦麦在.env文件中设置的主机地址即HOST字段"
),
"port": ConfigField(type=int, default=8000, description="麦麦在.env文件中设置的端口即PORT字段"),
"platform_name": ConfigField(type=str, default="qq", description="平台名称,用于消息路由"),
},
"voice": {
"use_tts": ConfigField(type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"),
"use_tts": ConfigField(
type=bool, default=False, description="是否使用tts语音请确保你配置了tts并有对应的adapter"
),
},
"slicing": {
"max_frame_size": ConfigField(type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"),
"max_frame_size": ConfigField(
type=int, default=64, description="WebSocket帧的最大大小单位为字节默认64KB"
),
"delay_ms": ConfigField(type=int, default=10, description="切片发送间隔时间,单位为毫秒"),
},
"debug": {
"level": ConfigField(type=str, default="INFO", description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL", choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
"level": ConfigField(
type=str,
default="INFO",
description="日志等级DEBUG, INFO, WARNING, ERROR, CRITICAL",
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
),
},
"features": {
# 权限设置
"group_list_type": ConfigField(type=str, default="blacklist", description="群聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"group_list_type": ConfigField(
type=str,
default="blacklist",
description="群聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"group_list": ConfigField(type=list, default=[], description="群聊ID列表"),
"private_list_type": ConfigField(type=str, default="blacklist", description="私聊列表类型whitelist白名单或 blacklist黑名单", choices=["whitelist", "blacklist"]),
"private_list_type": ConfigField(
type=str,
default="blacklist",
description="私聊列表类型whitelist白名单或 blacklist黑名单",
choices=["whitelist", "blacklist"],
),
"private_list": ConfigField(type=list, default=[], description="用户ID列表"),
"ban_user_id": ConfigField(type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"),
"ban_user_id": ConfigField(
type=list, default=[], description="全局禁止用户ID列表这些用户无法在任何地方使用机器人"
),
"ban_qq_bot": ConfigField(type=bool, default=False, description="是否屏蔽QQ官方机器人消息"),
# 聊天功能设置
"enable_poke": ConfigField(type=bool, default=True, description="是否启用戳一戳功能"),
"ignore_non_self_poke": ConfigField(type=bool, default=False, description="是否无视不是针对自己的戳一戳"),
"poke_debounce_seconds": ConfigField(type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"),
"poke_debounce_seconds": ConfigField(
type=int, default=3, description="戳一戳防抖时间(秒),在指定时间内第二次针对机器人的戳一戳将被忽略"
),
"enable_reply_at": ConfigField(type=bool, default=True, description="是否启用引用回复时艾特用户的功能"),
"reply_at_rate": ConfigField(type=float, default=0.5, description="引用回复时艾特用户的几率 (0.0 ~ 1.0)"),
"enable_emoji_like": ConfigField(type=bool, default=True, description="是否启用群聊表情回复功能"),
# 视频处理设置
"enable_video_analysis": ConfigField(type=bool, default=True, description="是否启用视频识别功能"),
"max_video_size_mb": ConfigField(type=int, default=100, description="视频文件最大大小限制MB"),
"download_timeout": ConfigField(type=int, default=60, description="视频下载超时时间(秒)"),
"supported_formats": ConfigField(type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"),
"supported_formats": ConfigField(
type=list, default=["mp4", "avi", "mov", "mkv", "flv", "wmv", "webm"], description="支持的视频格式"
),
# 消息缓冲设置
"enable_message_buffer": ConfigField(type=bool, default=True, description="是否启用消息缓冲合并功能"),
"message_buffer_enable_group": ConfigField(type=bool, default=True, description="是否启用群聊消息缓冲合并"),
"message_buffer_enable_private": ConfigField(type=bool, default=True, description="是否启用私聊消息缓冲合并"),
"message_buffer_interval": ConfigField(type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"),
"message_buffer_initial_delay": ConfigField(type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"),
"message_buffer_max_components": ConfigField(type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"),
"message_buffer_block_prefixes": ConfigField(type=list, default=["/", "!", "", ".", "", "#", "%"], description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲"),
}
"message_buffer_enable_private": ConfigField(
type=bool, default=True, description="是否启用私聊消息缓冲合并"
),
"message_buffer_interval": ConfigField(
type=float, default=3.0, description="消息合并间隔时间(秒),在此时间内的连续消息将被合并"
),
"message_buffer_initial_delay": ConfigField(
type=float, default=0.5, description="消息缓冲初始延迟(秒),收到第一条消息后等待此时间开始合并"
),
"message_buffer_max_components": ConfigField(
type=int, default=50, description="单个会话最大缓冲消息组件数量,超过此数量将强制合并"
),
"message_buffer_block_prefixes": ConfigField(
type=list,
default=["/", "!", "", ".", "", "#", "%"],
description="消息缓冲屏蔽前缀,以这些前缀开头的消息不会被缓冲",
),
},
}
# 配置节描述
@@ -374,7 +421,7 @@ class NapcatAdapterPlugin(BasePlugin):
"voice": "发送语音设置",
"slicing": "WebSocket消息切片设置",
"debug": "调试设置",
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)"
"features": "功能设置(权限控制、聊天功能、视频处理、消息缓冲等)",
}
def register_events(self):
@@ -409,6 +456,7 @@ class NapcatAdapterPlugin(BasePlugin):
chunker.set_plugin_config(self.config)
# 设置response_pool的插件配置
from .src.response_pool import set_plugin_config as set_response_pool_config
set_response_pool_config(self.config)
# 设置send_handler的插件配置
send_handler.set_plugin_config(self.config)
@@ -418,4 +466,4 @@ class NapcatAdapterPlugin(BasePlugin):
notice_handler.set_plugin_config(self.config)
# 设置meta_event_handler的插件配置
meta_event_handler.set_plugin_config(self.config)
# 设置其他handler的插件配置现在由component_registry在注册时自动设置
# 设置其他handler的插件配置现在由component_registry在注册时自动设置
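The schema entries above are read back elsewhere in the adapter through `config_api.get_plugin_config`, as the message-buffer and notice handlers further down show. A minimal sketch of that read path inside the project (`plugin_config` is whatever dict the plugin loaded from its config file; key names and defaults are taken from the schema above):

```python
from src.plugin_system.apis import config_api


def read_buffer_settings(plugin_config: dict) -> dict:
    """Collect the message-buffer options defined in the schema above, using the same defaults."""
    return {
        "enabled": config_api.get_plugin_config(plugin_config, "features.enable_message_buffer", True),
        "interval": config_api.get_plugin_config(plugin_config, "features.message_buffer_interval", 3.0),
        "initial_delay": config_api.get_plugin_config(plugin_config, "features.message_buffer_initial_delay", 0.5),
        "max_components": config_api.get_plugin_config(plugin_config, "features.message_buffer_max_components", 50),
    }
```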

View File

@@ -102,7 +102,9 @@ class SimpleMessageBuffer:
return True
# 检查屏蔽前缀
block_prefixes = tuple(config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", []))
block_prefixes = tuple(
config_api.get_plugin_config(self.plugin_config, "features.message_buffer_block_prefixes", [])
)
text = text.strip()
if text.startswith(block_prefixes):
@@ -134,9 +136,13 @@ class SimpleMessageBuffer:
# 检查是否启用对应类型的缓冲
message_type = event_data.get("message_type", "")
if message_type == "group" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", False):
if message_type == "group" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", False
):
return False
elif message_type == "private" and not config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", False):
elif message_type == "private" and not config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", False
):
return False
# 提取文本
@@ -158,7 +164,9 @@ class SimpleMessageBuffer:
session = self.buffer_pool[session_id]
# 检查是否超过最大组件数量
if len(session.messages) >= config_api.get_plugin_config(self.plugin_config, "features.message_buffer_max_components", 5):
if len(session.messages) >= config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_max_components", 5
):
logger.debug(f"会话 {session_id} 消息数量达到上限,强制合并")
asyncio.create_task(self._force_merge_session(session_id))
self.buffer_pool[session_id] = BufferedSession(session_id=session_id, original_event=original_event)
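The prefix check that was re-wrapped above relies on `str.startswith` accepting a tuple of prefixes; a self-contained sketch of the same logic, with the prefix list hard-coded instead of read from `features.message_buffer_block_prefixes`:

```python
def is_blocked(text: str, block_prefixes: list[str]) -> bool:
    """Return True if the message starts with any configured block prefix (so it bypasses buffering)."""
    prefixes = tuple(block_prefixes)  # str.startswith accepts a tuple of prefixes
    return bool(prefixes) and text.strip().startswith(prefixes)


assert is_blocked("/help", ["/", "!", "#"]) is True
assert is_blocked("早上好", ["/", "!", "#"]) is False
```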

View File

@@ -14,7 +14,7 @@ def create_router(plugin_config: dict):
platform_name = config_api.get_plugin_config(plugin_config, "maibot_server.platform_name", "qq")
host = config_api.get_plugin_config(plugin_config, "maibot_server.host", "localhost")
port = config_api.get_plugin_config(plugin_config, "maibot_server.port", 8000)
route_config = RouteConfig(
route_config={
platform_name: TargetConfig(
@@ -32,7 +32,7 @@ async def mmc_start_com(plugin_config: dict = None):
logger.info("正在连接MaiBot")
if plugin_config:
create_router(plugin_config)
if router:
router.register_class_handler(send_handler.handle_message)
await router.run()

View File

@@ -32,7 +32,7 @@ class NoticeType: # 通知事件
group_recall = "group_recall" # 群聊消息撤回
notify = "notify"
group_ban = "group_ban" # 群禁言
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
group_msg_emoji_like = "group_msg_emoji_like" # 群聊表情回复
class Notify:
poke = "poke" # 戳一戳

View File

@@ -100,7 +100,7 @@ class MessageHandler:
# 检查群聊黑白名单
group_list_type = config_api.get_plugin_config(self.plugin_config, "features.group_list_type", "blacklist")
group_list = config_api.get_plugin_config(self.plugin_config, "features.group_list", [])
if group_list_type == "whitelist":
if group_id not in group_list:
logger.warning("群聊不在白名单中,消息被丢弃")
@@ -111,9 +111,11 @@ class MessageHandler:
return False
else:
# 检查私聊黑白名单
private_list_type = config_api.get_plugin_config(self.plugin_config, "features.private_list_type", "blacklist")
private_list_type = config_api.get_plugin_config(
self.plugin_config, "features.private_list_type", "blacklist"
)
private_list = config_api.get_plugin_config(self.plugin_config, "features.private_list", [])
if private_list_type == "whitelist":
if user_id not in private_list:
logger.warning("私聊不在白名单中,消息被丢弃")
@@ -156,21 +158,23 @@ class MessageHandler:
Parameters:
raw_message: dict: 原始消息
"""
# 添加原始消息调试日志特别关注message字段
logger.debug(f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}")
logger.debug(
f"收到原始消息: message_type={raw_message.get('message_type')}, message_id={raw_message.get('message_id')}"
)
logger.debug(f"原始消息内容: {raw_message.get('message', [])}")
# 检查是否包含@或video消息段
message_segments = raw_message.get('message', [])
message_segments = raw_message.get("message", [])
if message_segments:
for i, seg in enumerate(message_segments):
seg_type = seg.get('type')
if seg_type in ['at', 'video']:
seg_type = seg.get("type")
if seg_type in ["at", "video"]:
logger.info(f"检测到 {seg_type.upper()} 消息段 [{i}]: {seg}")
elif seg_type not in ['text', 'face', 'image']:
elif seg_type not in ["text", "face", "image"]:
logger.warning(f"检测到特殊消息段 [{i}]: type={seg_type}, data={seg.get('data', {})}")
message_type: str = raw_message.get("message_type")
message_id: int = raw_message.get("message_id")
# message_time: int = raw_message.get("time")
@@ -308,9 +312,13 @@ class MessageHandler:
message_type = raw_message.get("message_type")
should_use_buffer = False
if message_type == "group" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_group", True):
if message_type == "group" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_group", True
):
should_use_buffer = True
elif message_type == "private" and config_api.get_plugin_config(self.plugin_config, "features.message_buffer_enable_private", True):
elif message_type == "private" and config_api.get_plugin_config(
self.plugin_config, "features.message_buffer_enable_private", True
):
should_use_buffer = True
if should_use_buffer:
@@ -368,10 +376,10 @@ class MessageHandler:
for sub_message in real_message:
sub_message: dict
sub_message_type = sub_message.get("type")
# 添加详细的消息类型调试信息
logger.debug(f"处理消息段: type={sub_message_type}, data={sub_message.get('data', {})}")
# 特别关注 at 和 video 消息的识别
if sub_message_type == "at":
logger.debug(f"检测到@消息: {sub_message}")
@@ -379,7 +387,7 @@ class MessageHandler:
logger.debug(f"检测到VIDEO消息: {sub_message}")
elif sub_message_type not in ["text", "face", "image", "record"]:
logger.warning(f"检测到特殊消息类型: {sub_message_type}, 完整消息: {sub_message}")
match sub_message_type:
case RealMessageType.text:
ret_seg = await self.handle_text_message(sub_message)

View File

@@ -33,6 +33,7 @@ class MessageSending:
try:
# 重新导入router
from ..mmc_com_layer import router
self.maibot_router = router
if self.maibot_router is not None:
logger.info("MaiBot router重连成功")
@@ -73,14 +74,14 @@ class MessageSending:
# 获取对应的客户端并发送切片
platform = message_base.message_info.platform
# 再次检查router状态防止运行时被重置
if self.maibot_router is None or not hasattr(self.maibot_router, 'clients'):
if self.maibot_router is None or not hasattr(self.maibot_router, "clients"):
logger.warning("MaiBot router连接已断开尝试重新连接")
if not await self._attempt_reconnect():
logger.error("MaiBot router重连失败切片发送中止")
return False
if platform not in self.maibot_router.clients:
logger.error(f"平台 {platform} 未连接")
return False

View File

@@ -22,7 +22,9 @@ class MetaEventHandler:
"""设置插件配置"""
self.plugin_config = plugin_config
# 更新interval值
self.interval = config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
self.interval = (
config_api.get_plugin_config(self.plugin_config, "napcat_server.heartbeat_interval", 5000) / 1000
)
async def handle_meta_event(self, message: dict) -> None:
event_type = message.get("meta_event_type")

View File

@@ -116,9 +116,9 @@ class NoticeHandler:
sub_type = raw_message.get("sub_type")
match sub_type:
case NoticeType.Notify.poke:
if config_api.get_plugin_config(self.plugin_config, "features.enable_poke", True) and await message_handler.check_allow_to_chat(
user_id, group_id, False, False
):
if config_api.get_plugin_config(
self.plugin_config, "features.enable_poke", True
) and await message_handler.check_allow_to_chat(user_id, group_id, False, False):
logger.debug("处理戳一戳消息")
handled_message, user_info = await self.handle_poke_notify(raw_message, group_id, user_id)
else:
@@ -127,14 +127,18 @@ class NoticeHandler:
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
await event_manager.trigger_event(NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME)
await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.FRIEND_INPUT, permission_group=PLUGIN_NAME
)
case _:
logger.warning(f"不支持的notify类型: {notice_type}.{sub_type}")
case NoticeType.group_msg_emoji_like:
case NoticeType.group_msg_emoji_like:
# 该事件转移到 handle_group_emoji_like_notify函数内触发
if config_api.get_plugin_config(self.plugin_config, "features.enable_emoji_like", True):
logger.debug("处理群聊表情回复")
handled_message, user_info = await self.handle_group_emoji_like_notify(raw_message,group_id,user_id)
handled_message, user_info = await self.handle_group_emoji_like_notify(
raw_message, group_id, user_id
)
else:
logger.warning("群聊表情回复被禁用,取消群聊表情回复处理")
case NoticeType.group_ban:
@@ -294,7 +298,7 @@ class NoticeHandler:
async def handle_group_emoji_like_notify(self, raw_message: dict, group_id: int, user_id: int):
if not group_id:
logger.error("群ID不能为空无法处理群聊表情回复通知")
return None, None
return None, None
user_qq_info: dict = await get_member_info(self.get_server_connection(), group_id, user_id)
if user_qq_info:
@@ -304,37 +308,42 @@ class NoticeHandler:
user_name = "QQ用户"
user_cardname = "QQ用户"
logger.debug("无法获取表情回复对方的用户昵称")
from src.plugin_system.core.event_manager import event_manager
from ...event_types import NapcatEvent
target_message = await event_manager.trigger_event(NapcatEvent.MESSAGE.GET_MSG,message_id=raw_message.get("message_id",""))
target_message_text = target_message.get_message_result().get("data",{}).get("raw_message","")
target_message = await event_manager.trigger_event(
NapcatEvent.MESSAGE.GET_MSG, message_id=raw_message.get("message_id", "")
)
target_message_text = target_message.get_message_result().get("data", {}).get("raw_message", "")
if not target_message:
logger.error("未找到对应消息")
return None, None
if len(target_message_text) > 15:
target_message_text = target_message_text[:15] + "..."
user_info: UserInfo = UserInfo(
platform=config_api.get_plugin_config(self.plugin_config, "maibot_server.platform_name", "qq"),
user_id=user_id,
user_nickname=user_name,
user_cardname=user_cardname,
)
like_emoji_id = raw_message.get("likes")[0].get("emoji_id")
await event_manager.trigger_event(
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME,
group_id=group_id,
user_id=user_id,
message_id=raw_message.get("message_id",""),
emoji_id=like_emoji_id
)
seg_data = Seg(type="text",data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id,"")}回复了你的消息[{target_message_text}]")
NapcatEvent.ON_RECEIVED.EMOJI_LIEK,
permission_group=PLUGIN_NAME,
group_id=group_id,
user_id=user_id,
message_id=raw_message.get("message_id", ""),
emoji_id=like_emoji_id,
)
seg_data = Seg(
type="text",
data=f"{user_name}使用Emoji表情{QQ_FACE.get(like_emoji_id, '')}回复了你的消息[{target_message_text}]",
)
return seg_data, user_info
async def handle_ban_notify(self, raw_message: dict, group_id: int) -> Tuple[Seg, UserInfo] | Tuple[None, None]:
if not group_id:
logger.error("群ID不能为空无法处理禁言通知")

View File

@@ -45,12 +45,12 @@ async def check_timeout_response() -> None:
while True:
cleaned_message_count: int = 0
now_time = time.time()
# 获取心跳间隔配置
heartbeat_interval = 30 # 默认值
if plugin_config:
heartbeat_interval = config_api.get_plugin_config(plugin_config, "napcat_server.heartbeat_interval", 30)
for echo_id, response_time in list(response_time_dict.items()):
if now_time - response_time > heartbeat_interval:
cleaned_message_count += 1

View File

@@ -297,9 +297,9 @@ class SendHandler:
try:
# 检查是否为缓冲消息ID格式buffered-{original_id}-{timestamp}
if id.startswith('buffered-'):
if id.startswith("buffered-"):
# 从缓冲消息ID中提取原始消息ID
original_id = id.split('-')[1]
original_id = id.split("-")[1]
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(original_id)})
else:
msg_info_response = await self.send_message_to_napcat("get_msg", {"message_id": int(id)})
@@ -363,7 +363,7 @@ class SendHandler:
use_tts = False
if self.plugin_config:
use_tts = config_api.get_plugin_config(self.plugin_config, "voice.use_tts", False)
if not use_tts:
logger.warning("未启用语音消息处理")
return {}

View File

@@ -18,7 +18,9 @@ class WebSocketManager:
self.max_reconnect_attempts = 10 # 最大重连次数
self.plugin_config = None
async def start_connection(self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict) -> None:
async def start_connection(
self, message_handler: Callable[[Server.ServerConnection], Any], plugin_config: dict
) -> None:
"""根据配置启动 WebSocket 连接"""
self.plugin_config = plugin_config
mode = config_api.get_plugin_config(plugin_config, "napcat_server.mode")
@@ -72,9 +74,7 @@ class WebSocketManager:
# 如果配置了访问令牌,添加到请求头
access_token = config_api.get_plugin_config(self.plugin_config, "napcat_server.access_token")
if access_token:
connect_kwargs["additional_headers"] = {
"Authorization": f"Bearer {access_token}"
}
connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
logger.info("已添加访问令牌到连接请求头")
async with Server.connect(url, **connect_kwargs) as websocket:
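For reference, the forward-connection path reformatted above boils down to something like the sketch below. `Server` in the adapter is assumed to be the `websockets` client module, and `additional_headers` is the header-injection keyword the original code passes (newer `websockets` releases accept it; the URL and token here are placeholders):

```python
import asyncio

import websockets  # imported as "Server" in the adapter


async def connect_forward(url: str, access_token: str | None = None) -> None:
    connect_kwargs = {}
    if access_token:
        # Same header the adapter injects when napcat_server.access_token is set
        connect_kwargs["additional_headers"] = {"Authorization": f"Bearer {access_token}"}
    async with websockets.connect(url, **connect_kwargs) as websocket:
        await websocket.send("ping")


# asyncio.run(connect_forward("ws://localhost:8080/ws", access_token="secret"))
```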

View File

@@ -1,6 +1,7 @@
"""
Base search engine interface
"""
from abc import ABC, abstractmethod
from typing import Dict, List, Any
@@ -9,20 +10,20 @@ class BaseSearchEngine(ABC):
"""
搜索引擎基类
"""
@abstractmethod
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""
执行搜索
Args:
args: 搜索参数,包含 query、num_results、time_range 等
Returns:
搜索结果列表,每个结果包含 title、url、snippet、provider 字段
"""
pass
@abstractmethod
def is_available(self) -> bool:
"""

View File

@@ -1,6 +1,7 @@
"""
Bing search engine implementation
"""
import asyncio
import functools
import random
@@ -58,21 +59,21 @@ class BingSearchEngine(BaseSearchEngine):
"""
Bing搜索引擎实现
"""
def __init__(self):
self.session = requests.Session()
self.session.headers = HEADERS
def is_available(self) -> bool:
"""检查Bing搜索引擎是否可用"""
return True # Bing是免费搜索引擎总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Bing搜索"""
query = args["query"]
num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any")
try:
loop = asyncio.get_running_loop()
func = functools.partial(self._search_sync, query, num_results, time_range)
@@ -81,17 +82,17 @@ class BingSearchEngine(BaseSearchEngine):
except Exception as e:
logger.error(f"Bing 搜索失败: {e}")
return []
def _search_sync(self, keyword: str, num_results: int, time_range: str) -> List[Dict[str, Any]]:
"""同步执行Bing搜索"""
if not keyword:
return []
list_result = []
# 构建搜索URL
search_url = bing_search_url + keyword
# 如果指定了时间范围,添加时间过滤参数
if time_range == "week":
search_url += "&qft=+filterui:date-range-7"
@@ -181,34 +182,29 @@ class BingSearchEngine(BaseSearchEngine):
# 尝试提取搜索结果
# 方法1: 查找标准的搜索结果容器
results = root.select("ol#b_results li.b_algo")
if results:
for _rank, result in enumerate(results, 1):
# 提取标题和链接
title_link = result.select_one("h2 a")
if not title_link:
continue
title = title_link.get_text().strip()
url = title_link.get("href", "")
# 提取摘要
abstract = ""
abstract_elem = result.select_one("div.b_caption p")
if abstract_elem:
abstract = abstract_elem.get_text().strip()
# 限制摘要长度
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({
"title": title,
"url": url,
"snippet": abstract,
"provider": "Bing"
})
list_data.append({"title": title, "url": url, "snippet": abstract, "provider": "Bing"})
if len(list_data) >= 10: # 限制结果数量
break
@@ -216,22 +212,34 @@ class BingSearchEngine(BaseSearchEngine):
if not list_data:
# 查找所有可能的搜索结果链接
all_links = root.find_all("a")
for link in all_links:
href = link.get("href", "")
text = link.get_text().strip()
# 过滤有效的搜索结果链接
if (href and text and len(text) > 10
if (
href
and text
and len(text) > 10
and not href.startswith("javascript:")
and not href.startswith("#")
and "http" in href
and not any(x in href for x in [
"bing.com/search", "bing.com/images", "bing.com/videos",
"bing.com/maps", "bing.com/news", "login", "account",
"microsoft", "javascript"
])):
and not any(
x in href
for x in [
"bing.com/search",
"bing.com/images",
"bing.com/videos",
"bing.com/maps",
"bing.com/news",
"login",
"account",
"microsoft",
"javascript",
]
)
):
# 尝试获取摘要
abstract = ""
parent = link.parent
@@ -239,18 +247,13 @@ class BingSearchEngine(BaseSearchEngine):
full_text = parent.get_text().strip()
if len(full_text) > len(text):
abstract = full_text.replace(text, "", 1).strip()
# 限制摘要长度
if ABSTRACT_MAX_LENGTH and len(abstract) > ABSTRACT_MAX_LENGTH:
abstract = abstract[:ABSTRACT_MAX_LENGTH] + "..."
list_data.append({
"title": text,
"url": href,
"snippet": abstract,
"provider": "Bing"
})
list_data.append({"title": text, "url": href, "snippet": abstract, "provider": "Bing"})
if len(list_data) >= 10:
break

View File

@@ -1,6 +1,7 @@
"""
DuckDuckGo search engine implementation
"""
from typing import Dict, List, Any
from asyncddgs import aDDGS
@@ -14,27 +15,22 @@ class DDGSearchEngine(BaseSearchEngine):
"""
DuckDuckGo搜索引擎实现
"""
def is_available(self) -> bool:
"""检查DuckDuckGo搜索引擎是否可用"""
return True # DuckDuckGo不需要API密钥总是可用
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行DuckDuckGo搜索"""
query = args["query"]
num_results = args.get("num_results", 3)
try:
async with aDDGS() as ddgs:
search_response = await ddgs.text(query, max_results=num_results)
return [
{
"title": r.get("title"),
"url": r.get("href"),
"snippet": r.get("body"),
"provider": "DuckDuckGo"
}
{"title": r.get("title"), "url": r.get("href"), "snippet": r.get("body"), "provider": "DuckDuckGo"}
for r in search_response
]
except Exception as e:

View File

@@ -1,6 +1,7 @@
"""
Exa search engine implementation
"""
import asyncio
import functools
from datetime import datetime, timedelta
@@ -19,31 +20,27 @@ class ExaSearchEngine(BaseSearchEngine):
"""
Exa搜索引擎实现
"""
def __init__(self):
self._initialize_clients()
def _initialize_clients(self):
"""初始化Exa客户端"""
# 从主配置文件读取API密钥
exa_api_keys = config_api.get_global_config("web_search.exa_api_keys", None)
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa"
)
self.api_manager = create_api_key_manager_from_config(exa_api_keys, lambda key: Exa(api_key=key), "Exa")
def is_available(self) -> bool:
"""检查Exa搜索引擎是否可用"""
return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Exa搜索"""
if not self.is_available():
return []
query = args["query"]
num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any")
@@ -52,7 +49,7 @@ class ExaSearchEngine(BaseSearchEngine):
if time_range != "any":
today = datetime.now()
start_date = today - timedelta(days=7 if time_range == "week" else 30)
exa_args["start_published_date"] = start_date.strftime('%Y-%m-%d')
exa_args["start_published_date"] = start_date.strftime("%Y-%m-%d")
try:
# 使用API密钥管理器获取下一个客户端
@@ -60,17 +57,17 @@ class ExaSearchEngine(BaseSearchEngine):
if not exa_client:
logger.error("无法获取Exa客户端")
return []
loop = asyncio.get_running_loop()
func = functools.partial(exa_client.search_and_contents, query, **exa_args)
search_response = await loop.run_in_executor(None, func)
return [
{
"title": res.title,
"url": res.url,
"snippet": " ".join(getattr(res, 'highlights', [])) or (getattr(res, 'text', '')[:250] + '...'),
"provider": "Exa"
"snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
"provider": "Exa",
}
for res in search_response.results
]

View File

@@ -1,6 +1,7 @@
"""
Tavily search engine implementation
"""
import asyncio
import functools
from typing import Dict, List, Any
@@ -18,31 +19,29 @@ class TavilySearchEngine(BaseSearchEngine):
"""
Tavily搜索引擎实现
"""
def __init__(self):
self._initialize_clients()
def _initialize_clients(self):
"""初始化Tavily客户端"""
# 从主配置文件读取API密钥
tavily_api_keys = config_api.get_global_config("web_search.tavily_api_keys", None)
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
tavily_api_keys,
lambda key: TavilyClient(api_key=key),
"Tavily"
tavily_api_keys, lambda key: TavilyClient(api_key=key), "Tavily"
)
def is_available(self) -> bool:
"""检查Tavily搜索引擎是否可用"""
return self.api_manager.is_available()
async def search(self, args: Dict[str, Any]) -> List[Dict[str, Any]]:
"""执行Tavily搜索"""
if not self.is_available():
return []
query = args["query"]
num_results = args.get("num_results", 3)
time_range = args.get("time_range", "any")
@@ -53,38 +52,40 @@ class TavilySearchEngine(BaseSearchEngine):
if not tavily_client:
logger.error("无法获取Tavily客户端")
return []
# 构建Tavily搜索参数
search_params = {
"query": query,
"max_results": num_results,
"search_depth": "basic",
"include_answer": False,
"include_raw_content": False
"include_raw_content": False,
}
# 根据时间范围调整搜索参数
if time_range == "week":
search_params["days"] = 7
elif time_range == "month":
search_params["days"] = 30
loop = asyncio.get_running_loop()
func = functools.partial(tavily_client.search, **search_params)
search_response = await loop.run_in_executor(None, func)
results = []
if search_response and "results" in search_response:
for res in search_response["results"]:
results.append({
"title": res.get("title", "无标题"),
"url": res.get("url", ""),
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
"provider": "Tavily"
})
results.append(
{
"title": res.get("title", "无标题"),
"url": res.get("url", ""),
"snippet": res.get("content", "")[:300] + "..." if res.get("content") else "无摘要",
"provider": "Tavily",
}
)
return results
except Exception as e:
logger.error(f"Tavily 搜索失败: {e}")
return []

View File

@@ -3,15 +3,10 @@ Web Search Tool Plugin
一个功能强大的网络搜索和URL解析插件支持多种搜索引擎和解析策略。
"""
from typing import List, Tuple, Type
from src.plugin_system import (
BasePlugin,
register_plugin,
ComponentInfo,
ConfigField,
PythonDependency
)
from src.plugin_system import BasePlugin, register_plugin, ComponentInfo, ConfigField, PythonDependency
from src.plugin_system.apis import config_api
from src.common.logger import get_logger
@@ -25,7 +20,7 @@ logger = get_logger("web_search_plugin")
class WEBSEARCHPLUGIN(BasePlugin):
"""
网络搜索工具插件
提供网络搜索和URL解析功能支持多种搜索引擎
- Exa (需要API密钥)
- Tavily (需要API密钥)
@@ -37,11 +32,11 @@ class WEBSEARCHPLUGIN(BasePlugin):
plugin_name: str = "web_search_tool" # 内部标识符
enable_plugin: bool = True
dependencies: List[str] = [] # 插件依赖列表
def __init__(self, *args, **kwargs):
"""初始化插件,立即加载所有搜索引擎"""
super().__init__(*args, **kwargs)
# 立即初始化所有搜索引擎触发API密钥管理器的日志输出
logger.info("🚀 正在初始化所有搜索引擎...")
try:
@@ -49,65 +44,58 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.tavily_engine import TavilySearchEngine
from .engines.ddg_engine import DDGSearchEngine
from .engines.bing_engine import BingSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine()
tavily_engine = TavilySearchEngine()
ddg_engine = DDGSearchEngine()
bing_engine = BingSearchEngine()
# 报告每个引擎的状态
engines_status = {
"Exa": exa_engine.is_available(),
"Tavily": tavily_engine.is_available(),
"DuckDuckGo": ddg_engine.is_available(),
"Bing": bing_engine.is_available()
"Bing": bing_engine.is_available(),
}
available_engines = [name for name, available in engines_status.items() if available]
unavailable_engines = [name for name, available in engines_status.items() if not available]
if available_engines:
logger.info(f"✅ 可用搜索引擎: {', '.join(available_engines)}")
if unavailable_engines:
logger.info(f"❌ 不可用搜索引擎: {', '.join(unavailable_engines)}")
except Exception as e:
logger.error(f"❌ 搜索引擎初始化失败: {e}", exc_info=True)
# Python包依赖列表
python_dependencies: List[PythonDependency] = [
PythonDependency(
package_name="asyncddgs",
description="异步DuckDuckGo搜索库",
optional=False
),
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency(
package_name="exa_py",
description="Exa搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的
optional=True, # 如果没有API密钥这个是可选的
),
PythonDependency(
package_name="tavily",
install_name="tavily-python", # 安装时使用这个名称
description="Tavily搜索API客户端库",
optional=True # 如果没有API密钥这个是可选的
optional=True, # 如果没有API密钥这个是可选的
),
PythonDependency(
package_name="httpx",
version=">=0.20.0",
install_name="httpx[socks]", # 安装时使用这个名称(包含可选依赖)
description="支持SOCKS代理的HTTP客户端库",
optional=False
)
optional=False,
),
]
config_file_name: str = "config.toml" # 配置文件名
# 配置节描述
config_section_descriptions = {
"plugin": "插件基本信息",
"proxy": "链接本地解析代理配置"
}
config_section_descriptions = {"plugin": "插件基本信息", "proxy": "链接本地解析代理配置"}
# 配置Schema定义
# 注意EXA配置和组件设置已迁移到主配置文件(bot_config.toml)的[exa]和[web_search]部分
@@ -119,42 +107,32 @@ class WEBSEARCHPLUGIN(BasePlugin):
},
"proxy": {
"http_proxy": ConfigField(
type=str,
default=None,
description="HTTP代理地址格式如: http://proxy.example.com:8080"
type=str, default=None, description="HTTP代理地址格式如: http://proxy.example.com:8080"
),
"https_proxy": ConfigField(
type=str,
default=None,
description="HTTPS代理地址格式如: http://proxy.example.com:8080"
type=str, default=None, description="HTTPS代理地址格式如: http://proxy.example.com:8080"
),
"socks5_proxy": ConfigField(
type=str,
default=None,
description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
type=str, default=None, description="SOCKS5代理地址格式如: socks5://proxy.example.com:1080"
),
"enable_proxy": ConfigField(
type=bool,
default=False,
description="是否启用代理"
)
"enable_proxy": ConfigField(type=bool, default=False, description="是否启用代理"),
},
}
def get_plugin_components(self) -> List[Tuple[ComponentInfo, Type]]:
"""
获取插件组件列表
Returns:
组件信息和类型的元组列表
"""
enable_tool = []
# 从主配置文件读取组件启用配置
if config_api.get_global_config("web_search.enable_web_search_tool", True):
enable_tool.append((WebSurfingTool.get_tool_info(), WebSurfingTool))
if config_api.get_global_config("web_search.enable_url_tool", True):
enable_tool.append((URLParserTool.get_tool_info(), URLParserTool))
return enable_tool

View File

@@ -1,6 +1,7 @@
"""
URL parser tool implementation
"""
import asyncio
import functools
from typing import Any, Dict
@@ -24,17 +25,18 @@ class URLParserTool(BaseTool):
"""
一个用于解析和总结一个或多个网页URL内容的工具。
"""
name: str = "parse_url"
description: str = "当需要理解一个或多个特定网页链接的内容时,使用此工具。例如:'这些网页讲了什么?[https://example.com, https://example2.com]''帮我总结一下这些文章'"
available_for_llm: bool = True
parameters = [
("urls", ToolParamType.STRING, "要理解的网站", True, None),
]
def __init__(self, plugin_config=None):
super().__init__(plugin_config)
self._initialize_exa_clients()
def _initialize_exa_clients(self):
"""初始化Exa客户端"""
# 优先从主配置文件读取,如果没有则从插件配置文件读取
@@ -42,12 +44,10 @@ class URLParserTool(BaseTool):
if exa_api_keys is None:
# 从插件配置文件读取
exa_api_keys = self.get_config("exa.api_keys", [])
# 创建API密钥管理器
self.api_manager = create_api_key_manager_from_config(
exa_api_keys,
lambda key: Exa(api_key=key),
"Exa URL Parser"
exa_api_keys, lambda key: Exa(api_key=key), "Exa URL Parser"
)
async def _local_parse_and_summarize(self, url: str) -> Dict[str, Any]:
@@ -58,12 +58,12 @@ class URLParserTool(BaseTool):
# 读取代理配置
enable_proxy = self.get_config("proxy.enable_proxy", False)
proxies = None
if enable_proxy:
socks5_proxy = self.get_config("proxy.socks5_proxy", None)
http_proxy = self.get_config("proxy.http_proxy", None)
https_proxy = self.get_config("proxy.https_proxy", None)
# 优先使用SOCKS5代理全协议代理
if socks5_proxy:
proxies = socks5_proxy
@@ -75,17 +75,17 @@ class URLParserTool(BaseTool):
if https_proxy:
proxies["https://"] = https_proxy
logger.info(f"使用HTTP/HTTPS代理配置: {proxies}")
client_kwargs = {"timeout": 15.0, "follow_redirects": True}
if proxies:
client_kwargs["proxies"] = proxies
async with httpx.AsyncClient(**client_kwargs) as client:
response = await client.get(url)
response.raise_for_status()
soup = BeautifulSoup(response.text, "html.parser")
title = soup.title.string if soup.title else "无标题"
for script in soup(["script", "style"]):
script.extract()
@@ -104,12 +104,12 @@ class URLParserTool(BaseTool):
return {"error": "未配置LLM模型"}
success, summary, reasoning, model_name = await llm_api.generate_with_model(
prompt=summary_prompt,
model_config=model_config,
request_type="story.generate",
temperature=0.3,
max_tokens=1000
)
prompt=summary_prompt,
model_config=model_config,
request_type="story.generate",
temperature=0.3,
max_tokens=1000,
)
if not success:
logger.info(f"生成摘要失败: {summary}")
@@ -117,12 +117,7 @@ class URLParserTool(BaseTool):
logger.info(f"成功生成摘要内容:'{summary}'")
return {
"title": title,
"url": url,
"snippet": summary,
"source": "local"
}
return {"title": title, "url": url, "snippet": summary, "source": "local"}
except httpx.HTTPStatusError as e:
logger.warning(f"本地解析URL '{url}' 失败 (HTTP {e.response.status_code})")
@@ -137,6 +132,7 @@ class URLParserTool(BaseTool):
"""
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
@@ -144,7 +140,7 @@ class URLParserTool(BaseTool):
if cached_result:
logger.info(f"缓存命中: {self.name} -> {function_args}")
return cached_result
urls_input = function_args.get("urls")
if not urls_input:
return {"error": "URL列表不能为空。"}
@@ -158,14 +154,14 @@ class URLParserTool(BaseTool):
valid_urls = validate_urls(urls)
if not valid_urls:
return {"error": "未找到有效的URL。"}
urls = valid_urls
logger.info(f"准备解析 {len(urls)} 个URL: {urls}")
successful_results = []
error_messages = []
urls_to_retry_locally = []
# 步骤 1: 尝试使用 Exa API 进行解析
contents_response = None
if self.api_manager.is_available():
@@ -182,41 +178,45 @@ class URLParserTool(BaseTool):
contents_response = await loop.run_in_executor(None, func)
except Exception as e:
logger.error(f"执行 Exa URL解析时发生严重异常: {e}", exc_info=True)
contents_response = None # 确保异常后为None
contents_response = None # 确保异常后为None
# 步骤 2: 处理Exa的响应
if contents_response and hasattr(contents_response, 'statuses'):
results_map = {res.url: res for res in contents_response.results} if hasattr(contents_response, 'results') else {}
if contents_response and hasattr(contents_response, "statuses"):
results_map = (
{res.url: res for res in contents_response.results} if hasattr(contents_response, "results") else {}
)
if contents_response.statuses:
for status in contents_response.statuses:
if status.status == 'success':
if status.status == "success":
res = results_map.get(status.id)
if res:
summary = getattr(res, 'summary', '')
highlights = " ".join(getattr(res, 'highlights', []))
text_snippet = (getattr(res, 'text', '')[:300] + '...') if getattr(res, 'text', '') else ''
snippet = summary or highlights or text_snippet or '无摘要'
successful_results.append({
"title": getattr(res, 'title', '无标题'),
"url": getattr(res, 'url', status.id),
"snippet": snippet,
"source": "exa"
})
summary = getattr(res, "summary", "")
highlights = " ".join(getattr(res, "highlights", []))
text_snippet = (getattr(res, "text", "")[:300] + "...") if getattr(res, "text", "") else ""
snippet = summary or highlights or text_snippet or "无摘要"
successful_results.append(
{
"title": getattr(res, "title", "无标题"),
"url": getattr(res, "url", status.id),
"snippet": snippet,
"source": "exa",
}
)
else:
error_tag = getattr(status, 'error', '未知错误')
error_tag = getattr(status, "error", "未知错误")
logger.warning(f"Exa解析URL '{status.id}' 失败: {error_tag}。准备本地重试。")
urls_to_retry_locally.append(status.id)
else:
# 如果Exa未配置、API调用失败或返回无效响应则所有URL都进入本地重试
urls_to_retry_locally.extend(url for url in urls if url not in [res['url'] for res in successful_results])
urls_to_retry_locally.extend(url for url in urls if url not in [res["url"] for res in successful_results])
# 步骤 3: 对失败的URL进行本地解析
if urls_to_retry_locally:
logger.info(f"开始本地解析以下URL: {urls_to_retry_locally}")
local_tasks = [self._local_parse_and_summarize(url) for url in urls_to_retry_locally]
local_results = await asyncio.gather(*local_tasks)
for i, res in enumerate(local_results):
url = urls_to_retry_locally[i]
if "error" in res:
@@ -228,13 +228,9 @@ class URLParserTool(BaseTool):
return {"error": "无法从所有给定的URL获取内容。", "details": error_messages}
formatted_content = format_url_parse_results(successful_results)
result = {
"type": "url_parse_result",
"content": formatted_content,
"errors": error_messages
}
result = {"type": "url_parse_result", "content": formatted_content, "errors": error_messages}
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result)

View File

@@ -1,6 +1,7 @@
"""
Web search tool implementation
"""
import asyncio
from typing import Any, Dict, List
@@ -22,14 +23,23 @@ class WebSurfingTool(BaseTool):
"""
网络搜索工具
"""
name: str = "web_search"
description: str = "用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
description: str = (
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
)
available_for_llm: bool = True
parameters = [
("query", ToolParamType.STRING, "要搜索的关键词或问题。", True, None),
("num_results", ToolParamType.INTEGER, "期望每个搜索引擎返回的搜索结果数量默认为5。", False, None),
("time_range", ToolParamType.STRING, "指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'", False, ["any", "week", "month"])
] # type: ignore
(
"time_range",
ToolParamType.STRING,
"指定搜索的时间范围,可以是 'any', 'week', 'month'。默认为 'any'",
False,
["any", "week", "month"],
),
] # type: ignore
def __init__(self, plugin_config=None):
super().__init__(plugin_config)
@@ -38,7 +48,7 @@ class WebSurfingTool(BaseTool):
"exa": ExaSearchEngine(),
"tavily": TavilySearchEngine(),
"ddg": DDGSearchEngine(),
"bing": BingSearchEngine()
"bing": BingSearchEngine(),
}
async def execute(self, function_args: Dict[str, Any]) -> Dict[str, Any]:
@@ -48,6 +58,7 @@ class WebSurfingTool(BaseTool):
# 获取当前文件路径用于缓存键
import os
current_file_path = os.path.abspath(__file__)
# 检查缓存
@@ -59,7 +70,7 @@ class WebSurfingTool(BaseTool):
# 读取搜索配置
enabled_engines = config_api.get_global_config("web_search.enabled_engines", ["ddg"])
search_strategy = config_api.get_global_config("web_search.search_strategy", "single")
logger.info(f"开始搜索,策略: {search_strategy}, 启用引擎: {enabled_engines}, 参数: '{function_args}'")
# 根据策略执行搜索
@@ -69,17 +80,19 @@ class WebSurfingTool(BaseTool):
result = await self._execute_fallback_search(function_args, enabled_engines)
else: # single
result = await self._execute_single_search(function_args, enabled_engines)
# 保存到缓存
if "error" not in result:
await tool_cache.set(self.name, function_args, current_file_path, result, semantic_query=query)
return result
async def _execute_parallel_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
async def _execute_parallel_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""并行搜索策略:同时使用所有启用的搜索引擎"""
search_tasks = []
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
if engine and engine.is_available():
@@ -92,7 +105,7 @@ class WebSurfingTool(BaseTool):
try:
search_results_lists = await asyncio.gather(*search_tasks, return_exceptions=True)
all_results = []
for result in search_results_lists:
if isinstance(result, list):
@@ -103,7 +116,7 @@ class WebSurfingTool(BaseTool):
# 去重并格式化
unique_results = deduplicate_results(all_results)
formatted_content = format_search_results(unique_results)
return {
"type": "web_search_result",
"content": formatted_content,
@@ -113,30 +126,32 @@ class WebSurfingTool(BaseTool):
logger.error(f"执行并行网络搜索时发生异常: {e}", exc_info=True)
return {"error": f"执行网络搜索时发生严重错误: {str(e)}"}
async def _execute_fallback_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
async def _execute_fallback_search(
self, function_args: Dict[str, Any], enabled_engines: List[str]
) -> Dict[str, Any]:
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
for engine_name in enabled_engines:
engine = self.engines.get(engine_name)
if not engine or not engine.is_available():
continue
try:
custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args)
if results: # 如果有结果,直接返回
formatted_content = format_search_results(results)
return {
"type": "web_search_result",
"content": formatted_content,
}
except Exception as e:
logger.warning(f"{engine_name} 搜索失败,尝试下一个引擎: {e}")
continue
return {"error": "所有搜索引擎都失败了。"}
async def _execute_single_search(self, function_args: Dict[str, Any], enabled_engines: List[str]) -> Dict[str, Any]:
@@ -145,20 +160,20 @@ class WebSurfingTool(BaseTool):
engine = self.engines.get(engine_name)
if not engine or not engine.is_available():
continue
try:
custom_args = function_args.copy()
custom_args["num_results"] = custom_args.get("num_results", 5)
results = await engine.search(custom_args)
formatted_content = format_search_results(results)
return {
"type": "web_search_result",
"content": formatted_content,
}
except Exception as e:
logger.error(f"{engine_name} 搜索失败: {e}")
return {"error": f"{engine_name} 搜索失败: {str(e)}"}
return {"error": "没有可用的搜索引擎。"}

View File

@@ -1,24 +1,25 @@
"""
API密钥管理器提供轮询机制
"""
import itertools
from typing import List, Optional, TypeVar, Generic, Callable
from src.common.logger import get_logger
logger = get_logger("api_key_manager")
T = TypeVar('T')
T = TypeVar("T")
class APIKeyManager(Generic[T]):
"""
API密钥管理器支持轮询机制
"""
def __init__(self, api_keys: List[str], client_factory: Callable[[str], T], service_name: str = "Unknown"):
"""
初始化API密钥管理器
Args:
api_keys: API密钥列表
client_factory: 客户端工厂函数接受API密钥参数并返回客户端实例
@@ -27,14 +28,14 @@ class APIKeyManager(Generic[T]):
self.service_name = service_name
self.clients: List[T] = []
self.client_cycle: Optional[itertools.cycle] = None
if api_keys:
# 过滤有效的API密钥排除None、空字符串、"None"字符串等
valid_keys = []
for key in api_keys:
if isinstance(key, str) and key.strip() and key.strip().lower() not in ("none", "null", ""):
valid_keys.append(key.strip())
if valid_keys:
try:
self.clients = [client_factory(key) for key in valid_keys]
@@ -48,35 +49,33 @@ class APIKeyManager(Generic[T]):
logger.warning(f"⚠️ {service_name} API Keys 配置无效包含None或空值{service_name} 功能将不可用")
else:
logger.warning(f"⚠️ {service_name} API Keys 未配置,{service_name} 功能将不可用")
def is_available(self) -> bool:
"""检查是否有可用的客户端"""
return bool(self.clients and self.client_cycle)
def get_next_client(self) -> Optional[T]:
"""获取下一个客户端(轮询)"""
if not self.is_available():
return None
return next(self.client_cycle)
def get_client_count(self) -> int:
"""获取可用客户端数量"""
return len(self.clients)
def create_api_key_manager_from_config(
config_keys: Optional[List[str]],
client_factory: Callable[[str], T],
service_name: str
config_keys: Optional[List[str]], client_factory: Callable[[str], T], service_name: str
) -> APIKeyManager[T]:
"""
从配置创建API密钥管理器的便捷函数
Args:
config_keys: 从配置读取的API密钥列表
client_factory: 客户端工厂函数
service_name: 服务名称
Returns:
API密钥管理器实例
"""

View File

@@ -1,6 +1,7 @@
"""
Formatters for web search results
"""
from typing import List, Dict, Any
@@ -13,15 +14,15 @@ def format_search_results(results: List[Dict[str, Any]]) -> str:
formatted_string = "根据网络搜索结果:\n\n"
for i, res in enumerate(results, 1):
title = res.get("title", '无标题')
url = res.get("url", '#')
snippet = res.get("snippet", '无摘要')
title = res.get("title", "无标题")
url = res.get("url", "#")
snippet = res.get("snippet", "无摘要")
provider = res.get("provider", "未知来源")
formatted_string += f"{i}. **{title}** (来自: {provider})\n"
formatted_string += f" - 摘要: {snippet}\n"
formatted_string += f" - 来源: {url}\n\n"
return formatted_string
@@ -31,10 +32,10 @@ def format_url_parse_results(results: List[Dict[str, Any]]) -> str:
"""
formatted_parts = []
for res in results:
title = res.get('title', '无标题')
url = res.get('url', '#')
snippet = res.get('snippet', '无摘要')
source = res.get('source', '未知')
title = res.get("title", "无标题")
url = res.get("url", "#")
snippet = res.get("snippet", "无摘要")
source = res.get("source", "未知")
formatted_string = f"**{title}**\n"
formatted_string += f"**内容摘要**:\n{snippet}\n"

View File

@@ -1,6 +1,7 @@
"""
URL processing utilities
"""
import re
from typing import List
@@ -12,11 +13,11 @@ def parse_urls_from_input(urls_input) -> List[str]:
if isinstance(urls_input, str):
# 如果是字符串尝试解析为URL列表
# 提取所有HTTP/HTTPS URL
url_pattern = r'https?://[^\s\],]+'
url_pattern = r"https?://[^\s\],]+"
urls = re.findall(url_pattern, urls_input)
if not urls:
# 如果没有找到标准URL将整个字符串作为单个URL
if urls_input.strip().startswith(('http://', 'https://')):
if urls_input.strip().startswith(("http://", "https://")):
urls = [urls_input.strip()]
else:
return []
@@ -24,7 +25,7 @@ def parse_urls_from_input(urls_input) -> List[str]:
urls = [url.strip() for url in urls_input if isinstance(url, str) and url.strip()]
else:
return []
return urls
@@ -34,6 +35,6 @@ def validate_urls(urls: List[str]) -> List[str]:
"""
valid_urls = []
for url in urls:
if url.startswith(('http://', 'https://')):
if url.startswith(("http://", "https://")):
valid_urls.append(url)
return valid_urls
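Put together, the two helpers reformatted above behave roughly as in this standalone sketch (re-implemented here for illustration; the real `parse_urls_from_input` also handles list input and the single-URL fallback):

```python
import re
from typing import List

URL_PATTERN = r"https?://[^\s\],]+"  # same pattern as parse_urls_from_input above


def extract_valid_urls(urls_input: str) -> List[str]:
    """Pull http(s) URLs out of free text, then keep only well-formed ones."""
    urls = re.findall(URL_PATTERN, urls_input)
    return [url for url in urls if url.startswith(("http://", "https://"))]


print(extract_valid_urls("这些网页讲了什么?[https://example.com, https://example2.com]"))
# -> ['https://example.com', 'https://example2.com']
```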

View File

@@ -21,8 +21,18 @@ logger = get_logger(__name__)
# ============================ AsyncTask ============================
class ReminderTask(AsyncTask):
def __init__(self, delay: float, stream_id: str, is_group: bool, target_user_id: str, target_user_name: str, event_details: str, creator_name: str):
def __init__(
self,
delay: float,
stream_id: str,
is_group: bool,
target_user_id: str,
target_user_name: str,
event_details: str,
creator_name: str,
):
super().__init__(task_name=f"ReminderTask_{target_user_id}_{datetime.now().timestamp()}")
self.delay = delay
self.stream_id = stream_id
@@ -37,22 +47,22 @@ class ReminderTask(AsyncTask):
if self.delay > 0:
logger.info(f"等待 {self.delay:.2f} 秒后执行提醒...")
await asyncio.sleep(self.delay)
logger.info(f"执行提醒任务: 给 {self.target_user_name} 发送关于 '{self.event_details}' 的提醒")
reminder_text = f"叮咚!这是 {self.creator_name} 让我准时提醒你的事情:\n\n{self.event_details}"
if self.is_group:
# 在群聊中,构造 @ 消息段并发送
group_id = self.stream_id.split('_')[-1] if '_' in self.stream_id else self.stream_id
group_id = self.stream_id.split("_")[-1] if "_" in self.stream_id else self.stream_id
message_payload = [
{"type": "at", "data": {"qq": self.target_user_id}},
{"type": "text", "data": {"text": f" {reminder_text}"}}
{"type": "text", "data": {"text": f" {reminder_text}"}},
]
await send_api.adapter_command_to_stream(
action="send_group_msg",
params={"group_id": group_id, "message": message_payload},
stream_id=self.stream_id
stream_id=self.stream_id,
)
else:
# 在私聊中,直接发送文本
@@ -66,6 +76,7 @@ class ReminderTask(AsyncTask):
# =============================== Actions ===============================
class RemindAction(BaseAction):
"""一个能从对话中智能识别并设置定时提醒的动作。"""
@@ -95,12 +106,12 @@ class RemindAction(BaseAction):
action_parameters = {
"user_name": "需要被提醒的人的称呼或名字,如果没有明确指定给某人,则默认为'自己'",
"remind_time": "描述提醒时间的自然语言字符串,例如'十分钟后''明天下午3点'",
"event_details": "需要提醒的具体事件内容"
"event_details": "需要提醒的具体事件内容",
}
action_require = [
"当用户请求在未来的某个时间点提醒他/她或别人某件事时使用",
"适用于包含明确时间信息和事件描述的对话",
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'"
"例如:'10分钟后提醒我收快递''明天早上九点喊一下李四参加晨会'",
]
async def execute(self) -> Tuple[bool, str]:
@@ -110,7 +121,15 @@ class RemindAction(BaseAction):
event_details = self.action_data.get("event_details")
if not all([user_name, remind_time_str, event_details]):
missing_params = [p for p, v in {"user_name": user_name, "remind_time": remind_time_str, "event_details": event_details}.items() if not v]
missing_params = [
p
for p, v in {
"user_name": user_name,
"remind_time": remind_time_str,
"event_details": event_details,
}.items()
if not v
]
error_msg = f"缺少必要的提醒参数: {', '.join(missing_params)}"
logger.warning(f"[ReminderPlugin] LLM未能提取完整参数: {error_msg}")
return False, error_msg
@@ -135,9 +154,9 @@ class RemindAction(BaseAction):
person_manager = get_person_info_manager()
user_id_to_remind = None
user_name_to_remind = ""
assert isinstance(user_name, str)
if user_name.strip() in ["自己", "", "me"]:
user_id_to_remind = self.user_id
user_name_to_remind = self.user_nickname
@@ -154,7 +173,7 @@ class RemindAction(BaseAction):
try:
assert user_id_to_remind is not None
assert event_details is not None
reminder_task = ReminderTask(
delay=delay_seconds,
stream_id=self.chat_id,
@@ -162,14 +181,14 @@ class RemindAction(BaseAction):
target_user_id=str(user_id_to_remind),
target_user_name=str(user_name_to_remind),
event_details=str(event_details),
creator_name=str(self.user_nickname)
creator_name=str(self.user_nickname),
)
await async_task_manager.add_task(reminder_task)
# 4. 发送确认消息
confirm_message = f"好的,我记下了。\n将在 {target_time.strftime('%Y-%m-%d %H:%M:%S')} 提醒 {user_name_to_remind}\n{event_details}"
await self.send_text(confirm_message)
return True, "提醒设置成功"
except Exception as e:
logger.error(f"[ReminderPlugin] 创建提醒任务时出错: {e}", exc_info=True)
@@ -179,6 +198,7 @@ class RemindAction(BaseAction):
# =============================== Plugin ===============================
@register_plugin
class ReminderPlugin(BasePlugin):
"""一个能从对话中智能识别并设置定时提醒的插件。"""
@@ -193,6 +213,4 @@ class ReminderPlugin(BasePlugin):
def get_plugin_components(self) -> List[Tuple[ActionInfo, Type[BaseAction]]]:
"""注册插件的所有功能组件。"""
return [
(RemindAction.get_action_info(), RemindAction)
]
return [(RemindAction.get_action_info(), RemindAction)]
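For context on the group-chat branch of `ReminderTask` above: the reminder is delivered as a raw adapter command with an @-segment followed by a text segment. A hedged sketch of that call, mirroring the payload in the diff (the `send_api` import path is assumed to match the other plugin_system API imports, and the IDs are placeholders):

```python
from src.plugin_system.apis import send_api  # import path assumed; mirrors config_api above


async def send_group_reminder(stream_id: str, group_id: str, target_user_id: str, reminder_text: str) -> None:
    """Send an @-mention followed by the reminder text to a group, as ReminderTask.run does."""
    message_payload = [
        {"type": "at", "data": {"qq": target_user_id}},
        {"type": "text", "data": {"text": f" {reminder_text}"}},
    ]
    await send_api.adapter_command_to_stream(
        action="send_group_msg",
        params={"group_id": group_id, "message": message_payload},
        stream_id=stream_id,
    )
```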