Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
This commit is contained in:
10
bot.py
10
bot.py
@@ -296,7 +296,7 @@ class DatabaseManager:
|
||||
# 使用线程执行器运行潜在的阻塞操作
|
||||
await initialize_sql_database()
|
||||
elapsed_time = time.time() - start_time
|
||||
|
||||
|
||||
db_type = global_config.database.database_type if global_config else "unknown"
|
||||
logger.info(
|
||||
f"数据库连接初始化成功,使用 {db_type} 数据库,耗时: {elapsed_time:.2f}秒"
|
||||
@@ -610,9 +610,9 @@ async def wait_for_user_input():
|
||||
try:
|
||||
if os.getenv("ENVIRONMENT") != "production":
|
||||
logger.info("程序执行完成,按 Ctrl+C 退出...")
|
||||
# 使用非阻塞循环
|
||||
while True:
|
||||
await asyncio.sleep(0.1)
|
||||
# 使用 asyncio.Event 而不是 sleep 循环
|
||||
shutdown_event = asyncio.Event()
|
||||
await shutdown_event.wait()
|
||||
except KeyboardInterrupt:
|
||||
logger.info("用户中断程序")
|
||||
return True
|
||||
@@ -652,7 +652,7 @@ async def main_async():
|
||||
user_input_done = asyncio.create_task(wait_for_user_input())
|
||||
|
||||
# 使用wait等待任意一个任务完成
|
||||
done, pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED)
|
||||
done, _pending = await asyncio.wait([main_task, user_input_done], return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
# 如果用户输入任务完成(用户按了Ctrl+C),取消主任务
|
||||
if user_input_done in done and main_task not in done:
|
||||
|
||||
@@ -574,6 +574,126 @@ class CacheManager:
|
||||
}
|
||||
}
|
||||
|
||||
async def recall_relevant_cache(
|
||||
self,
|
||||
query_text: str,
|
||||
tool_name: str | None = None,
|
||||
top_k: int = 3,
|
||||
similarity_threshold: float = 0.70,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
根据语义相似度主动召回相关的缓存条目
|
||||
|
||||
用于在回复前扫描缓存,找到与当前对话相关的历史搜索结果
|
||||
|
||||
Args:
|
||||
query_text: 用于语义匹配的查询文本(通常是最近几条聊天内容)
|
||||
tool_name: 可选,限制只召回特定工具的缓存(如 "web_search")
|
||||
top_k: 返回的最大结果数
|
||||
similarity_threshold: 相似度阈值(L2距离,越小越相似)
|
||||
|
||||
Returns:
|
||||
相关缓存条目列表,每个条目包含 {tool_name, query, content, similarity}
|
||||
"""
|
||||
if not query_text or not self.embedding_model:
|
||||
return []
|
||||
|
||||
try:
|
||||
# 生成查询向量
|
||||
embedding_result = await self.embedding_model.get_embedding(query_text)
|
||||
if not embedding_result:
|
||||
return []
|
||||
|
||||
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
|
||||
validated_embedding = self._validate_embedding(embedding_vector)
|
||||
if validated_embedding is None:
|
||||
return []
|
||||
|
||||
query_embedding = np.array([validated_embedding], dtype="float32")
|
||||
|
||||
# 从 L2 向量数据库查询
|
||||
results = vector_db_service.query(
|
||||
collection_name=self.semantic_cache_collection_name,
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=top_k * 2, # 多取一些,后面会过滤
|
||||
)
|
||||
|
||||
if not results or not results.get("ids") or not results["ids"][0]:
|
||||
logger.debug("[缓存召回] 未找到相关缓存")
|
||||
return []
|
||||
|
||||
recalled_items = []
|
||||
ids = results["ids"][0] if isinstance(results["ids"][0], list) else [results["ids"][0]]
|
||||
distances = results.get("distances", [[]])[0] if results.get("distances") else []
|
||||
|
||||
for i, cache_key in enumerate(ids):
|
||||
distance = distances[i] if i < len(distances) else 1.0
|
||||
|
||||
# 过滤相似度不够的
|
||||
if distance > similarity_threshold:
|
||||
continue
|
||||
|
||||
# 从数据库获取缓存数据
|
||||
cache_obj = await db_query(
|
||||
model_class=CacheEntries,
|
||||
query_type="get",
|
||||
filters={"cache_key": cache_key},
|
||||
single_result=True,
|
||||
)
|
||||
|
||||
if not cache_obj:
|
||||
continue
|
||||
|
||||
# 检查是否过期
|
||||
expires_at = getattr(cache_obj, "expires_at", 0)
|
||||
if time.time() >= expires_at:
|
||||
continue
|
||||
|
||||
# 获取工具名称并过滤
|
||||
cached_tool_name = getattr(cache_obj, "tool_name", "")
|
||||
if tool_name and cached_tool_name != tool_name:
|
||||
continue
|
||||
|
||||
# 解析缓存内容
|
||||
try:
|
||||
cache_value = getattr(cache_obj, "cache_value", "{}")
|
||||
data = orjson.loads(cache_value)
|
||||
content = data.get("content", "") if isinstance(data, dict) else str(data)
|
||||
|
||||
# 从 cache_key 中提取原始查询(格式: tool_name::{"query": "xxx", ...}::file_hash)
|
||||
original_query = ""
|
||||
try:
|
||||
key_parts = cache_key.split("::")
|
||||
if len(key_parts) >= 2:
|
||||
args_json = key_parts[1]
|
||||
args = orjson.loads(args_json)
|
||||
original_query = args.get("query", "")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
recalled_items.append({
|
||||
"tool_name": cached_tool_name,
|
||||
"query": original_query,
|
||||
"content": content,
|
||||
"similarity": 1.0 - distance, # 转换为相似度分数
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"解析缓存内容失败: {e}")
|
||||
continue
|
||||
|
||||
if len(recalled_items) >= top_k:
|
||||
break
|
||||
|
||||
if recalled_items:
|
||||
logger.info(f"[缓存召回] 找到 {len(recalled_items)} 条相关缓存")
|
||||
|
||||
return recalled_items
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[缓存召回] 语义召回失败: {e}")
|
||||
return []
|
||||
|
||||
|
||||
# 全局实例
|
||||
tool_cache = CacheManager()
|
||||
|
||||
@@ -32,15 +32,35 @@ class ToolCallRecord:
|
||||
"""后处理:生成结果预览"""
|
||||
if self.result and not self.result_preview:
|
||||
content = self.result.get("content", "")
|
||||
# 联网搜索等重要工具不截断结果
|
||||
no_truncate_tools = {"web_search", "web_surfing", "knowledge_search"}
|
||||
should_truncate = self.tool_name not in no_truncate_tools
|
||||
max_length = 500 if should_truncate else 10000 # 联网搜索给更大的限制
|
||||
|
||||
if isinstance(content, str):
|
||||
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
|
||||
if len(content) > max_length:
|
||||
self.result_preview = content[:max_length] + "..."
|
||||
else:
|
||||
self.result_preview = content
|
||||
elif isinstance(content, list | dict):
|
||||
try:
|
||||
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
|
||||
json_str = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
|
||||
if len(json_str) > max_length:
|
||||
self.result_preview = json_str[:max_length] + "..."
|
||||
else:
|
||||
self.result_preview = json_str
|
||||
except Exception:
|
||||
self.result_preview = str(content)[:500] + "..."
|
||||
str_content = str(content)
|
||||
if len(str_content) > max_length:
|
||||
self.result_preview = str_content[:max_length] + "..."
|
||||
else:
|
||||
self.result_preview = str_content
|
||||
else:
|
||||
self.result_preview = str(content)[:500] + "..."
|
||||
str_content = str(content)
|
||||
if len(str_content) > max_length:
|
||||
self.result_preview = str_content[:max_length] + "..."
|
||||
else:
|
||||
self.result_preview = str_content
|
||||
|
||||
|
||||
class StreamToolHistoryManager:
|
||||
|
||||
@@ -334,12 +334,46 @@ class KFCContextBuilder:
|
||||
|
||||
tool_executor = ToolExecutor(chat_id=self.chat_id)
|
||||
|
||||
# 首先获取当前的历史记录(在执行新工具调用之前)
|
||||
info_parts = []
|
||||
|
||||
# ========== 1. 主动召回联网搜索缓存 ==========
|
||||
try:
|
||||
from src.common.cache_manager import tool_cache
|
||||
|
||||
# 使用聊天历史作为语义查询
|
||||
query_text = chat_history if chat_history else target_message
|
||||
recalled_caches = await tool_cache.recall_relevant_cache(
|
||||
query_text=query_text,
|
||||
tool_name="web_search", # 只召回联网搜索的缓存
|
||||
top_k=2,
|
||||
similarity_threshold=0.65, # 相似度阈值
|
||||
)
|
||||
|
||||
if recalled_caches:
|
||||
recall_parts = ["### 🔍 相关的历史搜索结果"]
|
||||
for item in recalled_caches:
|
||||
original_query = item.get("query", "")
|
||||
content = item.get("content", "")
|
||||
similarity = item.get("similarity", 0)
|
||||
if content:
|
||||
# 截断过长的内容
|
||||
if len(content) > 500:
|
||||
content = content[:500] + "..."
|
||||
recall_parts.append(f"**搜索「{original_query}」** (相关度:{similarity:.0%})\n{content}")
|
||||
|
||||
info_parts.append("\n\n".join(recall_parts))
|
||||
logger.info(f"[缓存召回] 召回了 {len(recalled_caches)} 条相关搜索缓存")
|
||||
except Exception as e:
|
||||
logger.debug(f"[缓存召回] 召回失败(非关键): {e}")
|
||||
|
||||
# ========== 2. 获取工具调用历史 ==========
|
||||
tool_history_str = tool_executor.history_manager.format_for_prompt(
|
||||
max_records=3, include_results=True
|
||||
)
|
||||
if tool_history_str:
|
||||
info_parts.append(tool_history_str)
|
||||
|
||||
# 然后执行工具调用
|
||||
# ========== 3. 执行工具调用 ==========
|
||||
tool_results, _, _ = await tool_executor.execute_from_chat_message(
|
||||
sender=sender_name,
|
||||
target_message=target_message,
|
||||
@@ -347,12 +381,6 @@ class KFCContextBuilder:
|
||||
return_details=False,
|
||||
)
|
||||
|
||||
info_parts = []
|
||||
|
||||
# 显示之前的工具调用历史(不包括当前这次调用)
|
||||
if tool_history_str:
|
||||
info_parts.append(tool_history_str)
|
||||
|
||||
# 显示当前工具调用的结果(简要信息)
|
||||
if tool_results:
|
||||
current_results_parts = ["### 🔧 刚获取的工具信息"]
|
||||
|
||||
@@ -39,7 +39,7 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
return self.api_manager.is_available()
|
||||
|
||||
async def search(self, args: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
"""执行优化的Exa搜索(使用新的search API)"""
|
||||
"""执行优化的Exa搜索(使用 search_and_contents API)"""
|
||||
if not self.is_available():
|
||||
return []
|
||||
|
||||
@@ -47,13 +47,12 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
num_results = min(args.get("num_results", 5), 5) # 默认5个结果,但限制最多5个
|
||||
time_range = args.get("time_range", "any")
|
||||
|
||||
# 使用新的搜索参数格式
|
||||
# 使用 search_and_contents 的参数格式
|
||||
exa_args = {
|
||||
"query": query,
|
||||
"num_results": num_results,
|
||||
"contents": {
|
||||
"text": True,
|
||||
"summary": True, # 启用自动摘要
|
||||
},
|
||||
"type": "auto",
|
||||
"highlights": True, # 获取高亮片段
|
||||
}
|
||||
|
||||
# 时间范围过滤
|
||||
@@ -70,20 +69,20 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
return []
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
# 使用新的search方法
|
||||
func = functools.partial(exa_client.search, query, **exa_args)
|
||||
# 使用 search_and_contents 方法
|
||||
func = functools.partial(exa_client.search_and_contents, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
# 优化结果处理 - 更注重答案质量
|
||||
results = []
|
||||
for res in search_response.results:
|
||||
# 获取最佳内容片段
|
||||
summary = getattr(res, "summary", "")
|
||||
# 获取高亮内容或文本
|
||||
highlights = getattr(res, "highlights", [])
|
||||
text = getattr(res, "text", "")
|
||||
|
||||
# 智能内容选择:摘要 > 文本开头
|
||||
if summary and len(summary) > 50:
|
||||
snippet = summary.strip()
|
||||
# 智能内容选择:高亮 > 文本开头
|
||||
if highlights and len(highlights) > 0:
|
||||
snippet = " ".join(highlights[:3]).strip()
|
||||
elif text:
|
||||
snippet = text[:300] + "..." if len(text) > 300 else text
|
||||
else:
|
||||
@@ -114,13 +113,12 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
query = args["query"]
|
||||
num_results = min(args.get("num_results", 3), 3) # answer模式默认3个结果,专注质量
|
||||
|
||||
# 精简的搜索参数 - 专注快速答案
|
||||
# 精简的搜索参数 - 使用 search_and_contents
|
||||
exa_args = {
|
||||
"query": query,
|
||||
"num_results": num_results,
|
||||
"contents": {
|
||||
"text": False, # 不需要全文
|
||||
"summary": True, # 优先摘要
|
||||
},
|
||||
"type": "auto",
|
||||
"highlights": True,
|
||||
}
|
||||
|
||||
try:
|
||||
@@ -129,16 +127,16 @@ class ExaSearchEngine(BaseSearchEngine):
|
||||
return []
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
func = functools.partial(exa_client.search, query, **exa_args)
|
||||
func = functools.partial(exa_client.search_and_contents, **exa_args)
|
||||
search_response = await loop.run_in_executor(None, func)
|
||||
|
||||
# 极简结果处理 - 只保留最核心信息
|
||||
results = []
|
||||
for res in search_response.results:
|
||||
summary = getattr(res, "summary", "")
|
||||
highlights = getattr(res, "highlights", [])
|
||||
|
||||
# 使用摘要作为答案
|
||||
answer_text = summary.strip() if summary else ""
|
||||
# 使用高亮作为答案
|
||||
answer_text = " ".join(highlights[:2]).strip() if highlights else ""
|
||||
|
||||
if answer_text and len(answer_text) > 20:
|
||||
results.append({
|
||||
|
||||
@@ -28,7 +28,12 @@ class WebSurfingTool(BaseTool):
|
||||
|
||||
name: str = "web_search"
|
||||
description: str = (
|
||||
"用于执行网络搜索。当用户明确要求搜索,或者需要获取关于公司、产品、事件的最新信息、新闻或动态时,必须使用此工具"
|
||||
"联网搜索工具。使用场景:\n"
|
||||
"1. 用户问的问题你不确定答案、需要验证\n"
|
||||
"2. 涉及最新信息(新闻、产品、事件、时效性内容)\n"
|
||||
"3. 需要查找具体数据、事实、定义\n"
|
||||
"4. 用户明确要求搜索\n"
|
||||
"不要担心调用频率,搜索结果会被缓存。"
|
||||
)
|
||||
available_for_llm: bool = True
|
||||
parameters: ClassVar[list] = [
|
||||
|
||||
Reference in New Issue
Block a user