feat(context): 为大语言模型提供过去网页搜索的上下文记忆

此更改使聊天机器人能够记住并引用过去网页搜索的相关信息,从而显著提高响应质量和连贯性。

系统不再将每个查询视为孤立事件,而是在生成新响应之前,对之前的 `web_search` 结果缓存进行向量相似度搜索。如果发现过去的相关信息,会自动作为“相关历史搜索结果”注入到大语言模型的提示中。

这使模型能够立即访问相关背景信息,避免对已经讨论过的主题重复搜索。

为了支持这一新功能:
- 对 `web_search` 工具的提示进行了改写,以通过确保结果被高效缓存和调用,鼓励大语言模型更频繁地使用它。
- 重要工具结果(如网页搜索)的预览长度已增加
This commit is contained in:
tt-P607
2025-12-04 04:12:36 +08:00
parent f519f87884
commit 22767ce234
4 changed files with 186 additions and 13 deletions

View File

@@ -574,6 +574,126 @@ class CacheManager:
}
}
async def recall_relevant_cache(
self,
query_text: str,
tool_name: str | None = None,
top_k: int = 3,
similarity_threshold: float = 0.70,
) -> list[dict[str, Any]]:
"""
根据语义相似度主动召回相关的缓存条目
用于在回复前扫描缓存,找到与当前对话相关的历史搜索结果
Args:
query_text: 用于语义匹配的查询文本(通常是最近几条聊天内容)
tool_name: 可选,限制只召回特定工具的缓存(如 "web_search"
top_k: 返回的最大结果数
similarity_threshold: 相似度阈值L2距离越小越相似
Returns:
相关缓存条目列表,每个条目包含 {tool_name, query, content, similarity}
"""
if not query_text or not self.embedding_model:
return []
try:
# 生成查询向量
embedding_result = await self.embedding_model.get_embedding(query_text)
if not embedding_result:
return []
embedding_vector = embedding_result[0] if isinstance(embedding_result, tuple) else embedding_result
validated_embedding = self._validate_embedding(embedding_vector)
if validated_embedding is None:
return []
query_embedding = np.array([validated_embedding], dtype="float32")
# 从 L2 向量数据库查询
results = vector_db_service.query(
collection_name=self.semantic_cache_collection_name,
query_embeddings=query_embedding.tolist(),
n_results=top_k * 2, # 多取一些,后面会过滤
)
if not results or not results.get("ids") or not results["ids"][0]:
logger.debug("[缓存召回] 未找到相关缓存")
return []
recalled_items = []
ids = results["ids"][0] if isinstance(results["ids"][0], list) else [results["ids"][0]]
distances = results.get("distances", [[]])[0] if results.get("distances") else []
for i, cache_key in enumerate(ids):
distance = distances[i] if i < len(distances) else 1.0
# 过滤相似度不够的
if distance > similarity_threshold:
continue
# 从数据库获取缓存数据
cache_obj = await db_query(
model_class=CacheEntries,
query_type="get",
filters={"cache_key": cache_key},
single_result=True,
)
if not cache_obj:
continue
# 检查是否过期
expires_at = getattr(cache_obj, "expires_at", 0)
if time.time() >= expires_at:
continue
# 获取工具名称并过滤
cached_tool_name = getattr(cache_obj, "tool_name", "")
if tool_name and cached_tool_name != tool_name:
continue
# 解析缓存内容
try:
cache_value = getattr(cache_obj, "cache_value", "{}")
data = orjson.loads(cache_value)
content = data.get("content", "") if isinstance(data, dict) else str(data)
# 从 cache_key 中提取原始查询(格式: tool_name::{"query": "xxx", ...}::file_hash
original_query = ""
try:
key_parts = cache_key.split("::")
if len(key_parts) >= 2:
args_json = key_parts[1]
args = orjson.loads(args_json)
original_query = args.get("query", "")
except Exception:
pass
recalled_items.append({
"tool_name": cached_tool_name,
"query": original_query,
"content": content,
"similarity": 1.0 - distance, # 转换为相似度分数
})
except Exception as e:
logger.warning(f"解析缓存内容失败: {e}")
continue
if len(recalled_items) >= top_k:
break
if recalled_items:
logger.info(f"[缓存召回] 找到 {len(recalled_items)} 条相关缓存")
return recalled_items
except Exception as e:
logger.error(f"[缓存召回] 语义召回失败: {e}")
return []
# 全局实例
tool_cache = CacheManager()