Merge branch 'dev' into dev
@@ -639,18 +639,20 @@ class ChatterPlanFilter:
         else:
             keywords.append("晚上")

-        # 使用新的统一记忆系统检索记忆
+        # 使用记忆图系统检索记忆
         try:
-            from src.chat.memory_system import get_memory_system
+            from src.memory_graph.manager_singleton import get_memory_manager

-            memory_system = get_memory_system()
+            memory_manager = get_memory_manager()
+            if not memory_manager:
+                return "记忆系统未初始化。"

             # 将关键词转换为查询字符串
             query = " ".join(keywords)
-            enhanced_memories = await memory_system.retrieve_relevant_memories(
-                query_text=query,
-                user_id="system", # 系统查询
-                scope_id="system",
-                limit=5,
+            enhanced_memories = await memory_manager.search_memories(
+                query=query,
+                top_k=5,
+                use_multi_query=False, # 直接使用关键词查询
             )

             if not enhanced_memories:
@@ -658,9 +660,14 @@ class ChatterPlanFilter:

             # 转换格式以兼容现有代码
             retrieved_memories = []
-            for memory_chunk in enhanced_memories:
-                content = memory_chunk.display or memory_chunk.text_content or ""
-                memory_type = memory_chunk.memory_type.value if memory_chunk.memory_type else "unknown"
+            for memory in enhanced_memories:
+                # 从记忆图的节点中提取内容
+                content_parts = []
+                for node in memory.nodes:
+                    if node.content:
+                        content_parts.append(node.content)
+                content = " ".join(content_parts) if content_parts else "无内容"
+                memory_type = memory.memory_type.value
                 retrieved_memories.append((memory_type, content))

             memory_statements = [

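For reference, the conversion step above can be read as a small standalone helper. This is only an illustrative sketch; it assumes the memory-graph objects expose exactly the attributes the patch itself uses (nodes with a content field and a memory_type enum), not an API taken from the library documentation.

def to_memory_tuples(enhanced_memories) -> list[tuple[str, str]]:
    """Convert memory-graph search results into (memory_type, content) tuples."""
    retrieved = []
    for memory in enhanced_memories:
        parts = [node.content for node in memory.nodes if node.content]
        content = " ".join(parts) if parts else "无内容"
        retrieved.append((memory.memory_type.value, content))
    return retrieved
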
@@ -3,7 +3,7 @@
 当定时任务触发时,负责搜集信息、调用LLM决策、并根据决策生成回复
 """

-import json
+import orjson
 from datetime import datetime
 from typing import Any, Literal

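The json-to-orjson swaps that follow all hinge on a few behavioural differences. The snippet below is an illustrative comparison, not part of the patch:

import json

import orjson

data = {"用户": "晚上好", 1: "non-str key"}

# json: returns str; ensure_ascii=False keeps UTF-8 characters readable
json_text = json.dumps(data, ensure_ascii=False, indent=2)

# orjson: returns bytes, emits UTF-8 without \u escapes by default,
# and non-str dict keys need an explicit opt-in
orjson_bytes = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
orjson_text = orjson_bytes.decode("utf-8")  # only needed where a str is required

# loads accepts str or bytes; orjson.JSONDecodeError subclasses ValueError
assert orjson.loads(orjson_bytes) == orjson.loads(orjson_text)
assert issubclass(orjson.JSONDecodeError, ValueError)
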
@@ -3,7 +3,7 @@
 负责记录和管理已回复过的评论ID,避免重复回复
 """

-import json
+import orjson
 import time
 from pathlib import Path
 from typing import Any
@@ -71,7 +71,7 @@ class ReplyTrackerService:
                 self.replied_comments = {}
                 return

-            data = json.loads(file_content)
+            data = orjson.loads(file_content)
             if self._validate_data(data):
                 self.replied_comments = data
                 logger.info(
@@ -81,7 +81,7 @@ class ReplyTrackerService:
             else:
                 logger.error("加载的数据格式无效,将创建新的记录")
                 self.replied_comments = {}
-        except json.JSONDecodeError as e:
+        except orjson.JSONDecodeError as e:
             logger.error(f"解析回复记录文件失败: {e}")
             self._backup_corrupted_file()
             self.replied_comments = {}
@@ -118,7 +118,7 @@ class ReplyTrackerService:

             # 先写入临时文件
             with open(temp_file, "w", encoding="utf-8") as f:
-                json.dump(self.replied_comments, f, ensure_ascii=False, indent=2)
+                f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))

             # 如果写入成功,重命名为正式文件
             if temp_file.stat().st_size > 0: # 确保写入成功

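Because orjson.dumps already returns bytes, the temporary file could also be opened in binary mode and skip the decode round-trip entirely. A minimal sketch of that variant (save_atomically and its argument names are hypothetical, not taken from the service):

import orjson
from pathlib import Path

def save_atomically(data: dict, target: Path) -> None:
    temp_file = target.with_suffix(".tmp")
    payload = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)  # bytes
    with open(temp_file, "wb") as f:
        f.write(payload)
    if temp_file.stat().st_size > 0:  # same check as the original: confirm the write before replacing
        temp_file.replace(target)
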
@@ -1,6 +1,6 @@
 import asyncio
 import inspect
-import json
+import orjson
 from typing import ClassVar, List

 import websockets as Server
@@ -44,10 +44,10 @@ async def message_recv(server_connection: Server.ServerConnection):
     # 只在debug模式下记录原始消息
     if logger.level <= 10: # DEBUG level
         logger.debug(f"{raw_message[:1500]}..." if (len(raw_message) > 1500) else raw_message)
-    decoded_raw_message: dict = json.loads(raw_message)
+    decoded_raw_message: dict = orjson.loads(raw_message)
     try:
         # 首先尝试解析原始消息
-        decoded_raw_message: dict = json.loads(raw_message)
+        decoded_raw_message: dict = orjson.loads(raw_message)

         # 检查是否是切片消息 (来自 MMC)
         if chunker.is_chunk_message(decoded_raw_message):
@@ -71,7 +71,7 @@ async def message_recv(server_connection: Server.ServerConnection):
         elif post_type is None:
             await put_response(decoded_raw_message)

-    except json.JSONDecodeError as e:
+    except orjson.JSONDecodeError as e:
         logger.error(f"消息解析失败: {e}")
         logger.debug(f"原始消息: {raw_message[:500]}...")
     except Exception as e:

@@ -5,7 +5,7 @@
 """

 import asyncio
-import json
+import orjson
 import time
 import uuid
 from typing import Any, Dict, List, Optional, Union
@@ -34,7 +34,7 @@ class MessageChunker:
         """判断消息是否需要切片"""
         try:
             if isinstance(message, dict):
-                message_str = json.dumps(message, ensure_ascii=False)
+                message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
             else:
                 message_str = message
             return len(message_str.encode("utf-8")) > self.max_chunk_size
@@ -58,7 +58,7 @@ class MessageChunker:
         try:
             # 统一转换为字符串
             if isinstance(message, dict):
-                message_str = json.dumps(message, ensure_ascii=False)
+                message_str = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')
             else:
                 message_str = message

@@ -116,7 +116,7 @@ class MessageChunker:
         """判断是否是切片消息"""
         try:
             if isinstance(message, str):
-                data = json.loads(message)
+                data = orjson.loads(message)
             else:
                 data = message

@@ -126,7 +126,7 @@ class MessageChunker:
                 and "__mmc_chunk_data__" in data
                 and "__mmc_is_chunked__" in data
             )
-        except (json.JSONDecodeError, TypeError):
+        except (orjson.JSONDecodeError, TypeError):
            return False


@@ -187,7 +187,7 @@ class MessageReassembler:
         try:
             # 统一转换为字典
             if isinstance(message, str):
-                chunk_data = json.loads(message)
+                chunk_data = orjson.loads(message)
             else:
                 chunk_data = message

@@ -197,8 +197,8 @@ class MessageReassembler:
             if "_original_message" in chunk_data:
                 # 这是一个被包装的非切片消息,解包返回
                 try:
-                    return json.loads(chunk_data["_original_message"])
-                except json.JSONDecodeError:
+                    return orjson.loads(chunk_data["_original_message"])
+                except orjson.JSONDecodeError:
                     return {"text_message": chunk_data["_original_message"]}
             else:
                 return chunk_data
@@ -251,14 +251,14 @@ class MessageReassembler:

                 # 尝试反序列化重组后的消息
                 try:
-                    return json.loads(reassembled_message)
-                except json.JSONDecodeError:
+                    return orjson.loads(reassembled_message)
+                except orjson.JSONDecodeError:
                     # 如果不能反序列化为JSON,则作为文本消息返回
                     return {"text_message": reassembled_message}

             return None

-        except (json.JSONDecodeError, KeyError, TypeError) as e:
+        except (orjson.JSONDecodeError, KeyError, TypeError) as e:
             logger.error(f"处理切片消息时出错: {e}")
             return None


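The size check above encodes the dict to bytes, decodes to str, then re-encodes to measure its length. Since orjson.dumps already returns bytes, the byte length could be measured directly. A sketch of that alternative, with MAX_CHUNK_SIZE standing in for the instance attribute MessageChunker actually uses:

import orjson
from typing import Union

MAX_CHUNK_SIZE = 60 * 1024  # assumed limit; the real value lives on the chunker instance

def needs_chunking(message: Union[dict, str]) -> bool:
    """Compare the byte length of the serialized message against the chunk limit."""
    if isinstance(message, dict):
        encoded = orjson.dumps(message, option=orjson.OPT_NON_STR_KEYS)  # already bytes
    else:
        encoded = message.encode("utf-8")
    return len(encoded) > MAX_CHUNK_SIZE
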
@@ -1,5 +1,5 @@
 import base64
-import json
+import orjson
 import time
 import uuid
 from pathlib import Path
@@ -783,11 +783,11 @@ class MessageHandler:
         # 检查JSON消息格式
         if not message_data or "data" not in message_data:
             logger.warning("JSON消息格式不正确")
-            return Seg(type="json", data=json.dumps(message_data))
+            return Seg(type="json", data=orjson.dumps(message_data).decode('utf-8'))

         try:
             # 尝试将json_data解析为Python对象
-            nested_data = json.loads(json_data)
+            nested_data = orjson.loads(json_data)

             # 检查是否是机器人自己上传文件的回声
             if self._is_file_upload_echo(nested_data):
@@ -912,7 +912,7 @@ class MessageHandler:
             # 如果没有提取到关键信息,返回None
             return None

-        except json.JSONDecodeError:
+        except orjson.JSONDecodeError:
             # 如果解析失败,我们假设它不是我们关心的任何一种结构化JSON,
             # 而是普通的文本或者无法解析的格式。
             logger.debug(f"无法将data字段解析为JSON: {json_data}")
@@ -1146,13 +1146,13 @@ class MessageHandler:
             return None
         forward_message_id = forward_message_data.get("id")
         request_uuid = str(uuid.uuid4())
-        payload = json.dumps(
+        payload = orjson.dumps(
             {
                 "action": "get_forward_msg",
                 "params": {"message_id": forward_message_id},
                 "echo": request_uuid,
             }
-        )
+        ).decode('utf-8')
         try:
             connection = self.get_server_connection()
             if not connection:
@@ -1167,9 +1167,9 @@ class MessageHandler:
             logger.error(f"获取转发消息失败: {str(e)}")
             return None
         logger.debug(
-            f"转发消息原始格式:{json.dumps(response)[:80]}..."
-            if len(json.dumps(response)) > 80
-            else json.dumps(response)
+            f"转发消息原始格式:{orjson.dumps(response).decode('utf-8')[:80]}..."
+            if len(orjson.dumps(response).decode('utf-8')) > 80
+            else orjson.dumps(response).decode('utf-8')
         )
         response_data: Dict = response.get("data")
         if not response_data:

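In the debug log above, the response is serialized up to three times inside one expression. Serializing once keeps the behaviour while avoiding the repeated work; a small sketch (the helper and its names are illustrative, not from the handler):

import orjson

def log_forward_response(logger, response: dict) -> None:
    # Serialize once, then decide how much of the payload to log.
    raw = orjson.dumps(response).decode("utf-8")
    logger.debug(f"转发消息原始格式:{raw[:80]}..." if len(raw) > 80 else raw)
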
@@ -1,5 +1,5 @@
 import asyncio
-import json
+import orjson
 import time
 from typing import ClassVar, Optional, Tuple

@@ -241,7 +241,7 @@ class NoticeHandler:
         message_base: MessageBase = MessageBase(
             message_info=message_info,
             message_segment=handled_message,
-            raw_message=json.dumps(raw_message),
+            raw_message=orjson.dumps(raw_message).decode('utf-8'),
         )

         if system_notice:
@@ -602,7 +602,7 @@ class NoticeHandler:
         message_base: MessageBase = MessageBase(
             message_info=message_info,
             message_segment=seg_message,
-            raw_message=json.dumps(
+            raw_message=orjson.dumps(
                 {
                     "post_type": "notice",
                     "notice_type": "group_ban",
@@ -611,7 +611,7 @@ class NoticeHandler:
                     "user_id": user_id,
                     "operator_id": None, # 自然解除禁言没有操作者
                 }
-            ),
+            ).decode('utf-8'),
         )

         await self.put_notice(message_base)

@@ -1,4 +1,5 @@
-import json
+import orjson
 import random
 import time
+import random
 import websockets as Server
@@ -603,7 +604,7 @@ class SendHandler:

     async def send_message_to_napcat(self, action: str, params: dict, timeout: float = 20.0) -> dict:
         request_uuid = str(uuid.uuid4())
-        payload = json.dumps({"action": action, "params": params, "echo": request_uuid})
+        payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode('utf-8')

         # 获取当前连接
         connection = self.get_server_connection()

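The .decode('utf-8') calls in this adapter keep the payload a str, so websockets still sends a text frame exactly as the json.dumps version did; passing bytes would produce a binary frame instead. A small sketch of the pattern (send_action is a hypothetical helper, not a function from the adapter):

import uuid

import orjson

async def send_action(connection, action: str, params: dict) -> str:
    """Build a NapCat-style payload and send it as a websocket text frame."""
    echo = str(uuid.uuid4())
    payload = orjson.dumps({"action": action, "params": params, "echo": echo})  # bytes
    await connection.send(payload.decode("utf-8"))  # str -> text frame, matching the old behaviour
    return echo
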
@@ -1,6 +1,6 @@
 import base64
 import io
-import json
+import orjson
 import ssl
 import uuid
 from typing import List, Optional, Tuple, Union
@@ -34,7 +34,7 @@ async def get_group_info(websocket: Server.ServerConnection, group_id: int) -> d
     """
     logger.debug("获取群聊信息中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_group_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -56,7 +56,7 @@ async def get_group_detail_info(websocket: Server.ServerConnection, group_id: in
     """
     logger.debug("获取群详细信息中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_group_detail_info", "params": {"group_id": group_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -78,13 +78,13 @@ async def get_member_info(websocket: Server.ServerConnection, group_id: int, use
     """
     logger.debug("获取群成员信息中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps(
+    payload = orjson.dumps(
         {
             "action": "get_group_member_info",
             "params": {"group_id": group_id, "user_id": user_id, "no_cache": True},
             "echo": request_uuid,
         }
-    )
+    ).decode('utf-8')
     try:
         await websocket.send(payload)
         socket_response: dict = await get_response(request_uuid)
@@ -146,7 +146,7 @@ async def get_self_info(websocket: Server.ServerConnection) -> dict | None:
     """
     logger.debug("获取自身信息中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_login_info", "params": {}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid)
@@ -183,7 +183,7 @@ async def get_stranger_info(websocket: Server.ServerConnection, user_id: int) ->
     """
     logger.debug("获取陌生人信息中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_stranger_info", "params": {"user_id": user_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid)
@@ -208,7 +208,7 @@ async def get_message_detail(websocket: Server.ServerConnection, message_id: Uni
     """
     logger.debug("获取消息详情中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid})
+    payload = orjson.dumps({"action": "get_msg", "params": {"message_id": message_id}, "echo": request_uuid}).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒
@@ -236,13 +236,13 @@ async def get_record_detail(
     """
     logger.debug("获取语音消息详情中")
     request_uuid = str(uuid.uuid4())
-    payload = json.dumps(
+    payload = orjson.dumps(
         {
             "action": "get_record",
             "params": {"file": file, "file_id": file_id, "out_format": "wav"},
             "echo": request_uuid,
         }
-    )
+    ).decode('utf-8')
     try:
         await websocket.send(payload)
         response: dict = await get_response(request_uuid, 30) # 增加超时时间到30秒

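Every function in this module now repeats the same dumps(...).decode('utf-8') pattern; a tiny helper would keep the payload format in one place. This is a refactoring sketch only (build_payload is hypothetical, not part of the module):

import uuid

import orjson

def build_payload(action: str, params: dict) -> tuple[str, str]:
    """Serialize a NapCat API request once and return (payload, echo id)."""
    request_uuid = str(uuid.uuid4())
    payload = orjson.dumps({"action": action, "params": params, "echo": request_uuid}).decode("utf-8")
    return payload, request_uuid

# usage sketch:
# payload, request_uuid = build_payload("get_group_info", {"group_id": group_id})
# await websocket.send(payload)
# response = await get_response(request_uuid)
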
@@ -39,15 +39,23 @@ class ExaSearchEngine(BaseSearchEngine):
         return self.api_manager.is_available()

     async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
-        """执行Exa搜索"""
+        """执行优化的Exa搜索(使用answer模式)"""
         if not self.is_available():
             return []

         query = args["query"]
-        num_results = args.get("num_results", 3)
+        num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个
         time_range = args.get("time_range", "any")

-        exa_args = {"num_results": num_results, "text": True, "highlights": True}
+        # 优化的搜索参数 - 更注重答案质量
+        exa_args = {
+            "num_results": num_results,
+            "text": True,
+            "highlights": True,
+            "summary": True, # 启用自动摘要
+        }

        # 时间范围过滤
        if time_range != "any":
            today = datetime.now()
            start_date = today - timedelta(days=7 if time_range == "week" else 30)
@@ -61,18 +69,89 @@ class ExaSearchEngine(BaseSearchEngine):
                return []

            loop = asyncio.get_running_loop()
+            # 使用search_and_contents获取完整内容,优化为answer模式
            func = functools.partial(exa_client.search_and_contents, query, **exa_args)
            search_response = await loop.run_in_executor(None, func)

-            return [
-                {
+            # 优化结果处理 - 更注重答案质量
+            results = []
+            for res in search_response.results:
+                # 获取最佳内容片段
+                highlights = getattr(res, "highlights", [])
+                summary = getattr(res, "summary", "")
+                text = getattr(res, "text", "")
+
+                # 智能内容选择:摘要 > 高亮 > 文本开头
+                if summary and len(summary) > 50:
+                    snippet = summary.strip()
+                elif highlights:
+                    snippet = " ".join(highlights).strip()
+                elif text:
+                    snippet = text[:300] + "..." if len(text) > 300 else text
+                else:
+                    snippet = "内容获取失败"
+
+                # 只保留有意义的摘要
+                if len(snippet) < 30:
+                    snippet = text[:200] + "..." if text and len(text) > 200 else snippet
+
+                results.append({
                    "title": res.title,
                    "url": res.url,
-                    "snippet": " ".join(getattr(res, "highlights", [])) or (getattr(res, "text", "")[:250] + "..."),
+                    "snippet": snippet,
                    "provider": "Exa",
-                }
-                for res in search_response.results
-            ]
+                    "answer_focused": True, # 标记为答案导向的搜索
+                })

+            return results
        except Exception as e:
-            logger.error(f"Exa 搜索失败: {e}")
+            logger.error(f"Exa answer模式搜索失败: {e}")
            return []
+
+    async def answer_search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
+        """执行Exa快速答案搜索 - 最精简的搜索模式"""
+        if not self.is_available():
+            return []
+
+        query = args["query"]
+        num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
+
+        # 精简的搜索参数 - 专注快速答案
+        exa_args = {
+            "num_results": num_results,
+            "text": False, # 不需要全文
+            "highlights": True, # 只要关键高亮
+            "summary": True, # 优先摘要
+        }
+
+        try:
+            exa_client = self.api_manager.get_next_client()
+            if not exa_client:
+                return []
+
+            loop = asyncio.get_running_loop()
+            func = functools.partial(exa_client.search_and_contents, query, **exa_args)
+            search_response = await loop.run_in_executor(None, func)
+
+            # 极简结果处理 - 只保留最核心信息
+            results = []
+            for res in search_response.results:
+                summary = getattr(res, "summary", "")
+                highlights = getattr(res, "highlights", [])
+
+                # 优先使用摘要,否则使用高亮
+                answer_text = summary.strip() if summary and len(summary) > 30 else " ".join(highlights).strip()
+
+                if answer_text and len(answer_text) > 20:
+                    results.append({
+                        "title": res.title,
+                        "url": res.url,
+                        "snippet": answer_text[:400] + "..." if len(answer_text) > 400 else answer_text,
+                        "provider": "Exa-Answer",
+                        "answer_mode": True # 标记为纯答案模式
+                    })
+
+            return results
+        except Exception as e:
+            logger.error(f"Exa快速答案搜索失败: {e}")
+            return []

||||
"""
|
||||
Metaso Search Engine (Chat Completions Mode)
|
||||
"""
|
||||
import json
|
||||
import orjson
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
@@ -43,12 +43,12 @@ class MetasoClient:
|
||||
if data_str == "[DONE]":
|
||||
break
|
||||
try:
|
||||
data = json.loads(data_str)
|
||||
data = orjson.loads(data_str)
|
||||
delta = data.get("choices", [{}])[0].get("delta", {})
|
||||
content_chunk = delta.get("content")
|
||||
if content_chunk:
|
||||
full_response_content += content_chunk
|
||||
except json.JSONDecodeError:
|
||||
except orjson.JSONDecodeError:
|
||||
logger.warning(f"Metaso stream: could not decode JSON line: {data_str}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -41,6 +41,13 @@ class WebSurfingTool(BaseTool):
|
||||
False,
|
||||
["any", "week", "month"],
|
||||
),
|
||||
(
|
||||
"answer_mode",
|
||||
ToolParamType.BOOLEAN,
|
||||
"是否启用答案模式(仅适用于Exa搜索引擎)。启用后将返回更精简、直接的答案,减少冗余信息。默认为False。",
|
||||
False,
|
||||
None,
|
||||
),
|
||||
] # type: ignore
|
||||
|
||||
def __init__(self, plugin_config=None, chat_stream=None):
|
||||
@@ -97,13 +104,19 @@ class WebSurfingTool(BaseTool):
|
||||
) -> dict[str, Any]:
|
||||
"""并行搜索策略:同时使用所有启用的搜索引擎"""
|
||||
search_tasks = []
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if engine and engine.is_available():
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
search_tasks.append(engine.search(custom_args))
|
||||
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
search_tasks.append(engine.answer_search(custom_args))
|
||||
else:
|
||||
search_tasks.append(engine.search(custom_args))
|
||||
|
||||
if not search_tasks:
|
||||
|
||||
@@ -137,17 +150,23 @@ class WebSurfingTool(BaseTool):
|
||||
self, function_args: dict[str, Any], enabled_engines: list[str]
|
||||
) -> dict[str, Any]:
|
||||
"""回退搜索策略:按顺序尝试搜索引擎,失败则尝试下一个"""
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
logger.info("使用Exa答案模式进行搜索(fallback策略)")
|
||||
results = await engine.answer_search(custom_args)
|
||||
else:
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
if results: # 如果有结果,直接返回
|
||||
formatted_content = format_search_results(results)
|
||||
@@ -164,22 +183,30 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
async def _execute_single_search(self, function_args: dict[str, Any], enabled_engines: list[str]) -> dict[str, Any]:
|
||||
"""单一搜索策略:只使用第一个可用的搜索引擎"""
|
||||
answer_mode = function_args.get("answer_mode", False)
|
||||
|
||||
for engine_name in enabled_engines:
|
||||
engine = self.engines.get(engine_name)
|
||||
if not engine or not engine.is_available():
|
||||
|
||||
continue
|
||||
|
||||
try:
|
||||
custom_args = function_args.copy()
|
||||
custom_args["num_results"] = custom_args.get("num_results", 5)
|
||||
|
||||
results = await engine.search(custom_args)
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
# 如果启用了answer模式且是Exa引擎,使用answer_search方法
|
||||
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
|
||||
logger.info("使用Exa答案模式进行搜索")
|
||||
results = await engine.answer_search(custom_args)
|
||||
else:
|
||||
results = await engine.search(custom_args)
|
||||
|
||||
if results:
|
||||
formatted_content = format_search_results(results)
|
||||
return {
|
||||
"type": "web_search_result",
|
||||
"content": formatted_content,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{engine_name} 搜索失败: {e}")
|
||||
|
||||
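The same answer-mode dispatch now appears in the parallel, fallback and single-search strategies; a shared helper along these lines would keep the three in sync (run_engine is a hypothetical refactor, not part of the patch):

from typing import Any

async def run_engine(engine: Any, engine_name: str, args: dict[str, Any], answer_mode: bool) -> list[dict[str, Any]]:
    """Pick answer_search for Exa when answer_mode is on, otherwise plain search."""
    if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
        return await engine.answer_search(args)
    return await engine.search(args)
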