refactor: 优化知识库搜索逻辑,移除冗余代码
This commit is contained in:
@@ -38,125 +38,22 @@ class SearchKnowledgeFromLPMMTool(BaseTool):
|
|||||||
# threshold = function_args.get("threshold", 0.4)
|
# threshold = function_args.get("threshold", 0.4)
|
||||||
|
|
||||||
# 调用知识库搜索
|
# 调用知识库搜索
|
||||||
embedding = await get_embedding(query, request_type="info_retrieval")
|
|
||||||
if embedding:
|
knowledge_info = qa_manager.get_knowledge(query)
|
||||||
knowledge_info = qa_manager.get_knowledge(query)
|
|
||||||
logger.debug(f"知识库查询结果: {knowledge_info}")
|
logger.debug(f"知识库查询结果: {knowledge_info}")
|
||||||
if knowledge_info:
|
|
||||||
content = f"你知道这些知识: {knowledge_info}"
|
if knowledge_info:
|
||||||
else:
|
content = f"你知道这些知识: {knowledge_info}"
|
||||||
content = f"你不太了解有关{query}的知识"
|
else:
|
||||||
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
content = f"你不太了解有关{query}的知识"
|
||||||
# 如果获取嵌入失败
|
return {"type": "lpmm_knowledge", "id": query, "content": content}
|
||||||
return {"type": "info", "id": query, "content": f"无法获取关于'{query}'的嵌入向量,你lpmm知识库炸了"}
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
# 捕获异常并记录错误
|
||||||
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
logger.error(f"知识库搜索工具执行失败: {str(e)}")
|
||||||
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
# 在其他异常情况下,确保 id 仍然是 query (如果它被定义了)
|
||||||
query_id = query if "query" in locals() else "unknown_query"
|
query_id = query if "query" in locals() else "unknown_query"
|
||||||
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
return {"type": "info", "id": query_id, "content": f"lpmm知识库搜索失败,炸了: {str(e)}"}
|
||||||
|
|
||||||
# def get_info_from_db(
|
|
||||||
# self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
|
|
||||||
# ) -> Union[str, list]:
|
|
||||||
# """从数据库中获取相关信息
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# query_embedding: 查询的嵌入向量
|
|
||||||
# limit: 最大返回结果数
|
|
||||||
# threshold: 相似度阈值
|
|
||||||
# return_raw: 是否返回原始结果
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# Union[str, list]: 格式化的信息字符串或原始结果列表
|
|
||||||
# """
|
|
||||||
# if not query_embedding:
|
|
||||||
# return "" if not return_raw else []
|
|
||||||
|
|
||||||
# # 使用余弦相似度计算
|
|
||||||
# pipeline = [
|
|
||||||
# {
|
|
||||||
# "$addFields": {
|
|
||||||
# "dotProduct": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": {"$range": [0, {"$size": "$embedding"}]},
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {
|
|
||||||
# "$add": [
|
|
||||||
# "$$value",
|
|
||||||
# {
|
|
||||||
# "$multiply": [
|
|
||||||
# {"$arrayElemAt": ["$embedding", "$$this"]},
|
|
||||||
# {"$arrayElemAt": [query_embedding, "$$this"]},
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# ]
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# "magnitude1": {
|
|
||||||
# "$sqrt": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": "$embedding",
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# "magnitude2": {
|
|
||||||
# "$sqrt": {
|
|
||||||
# "$reduce": {
|
|
||||||
# "input": query_embedding,
|
|
||||||
# "initialValue": 0,
|
|
||||||
# "in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
|
|
||||||
# }
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# {"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
|
|
||||||
# {
|
|
||||||
# "$match": {
|
|
||||||
# "similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
|
|
||||||
# }
|
|
||||||
# },
|
|
||||||
# {"$sort": {"similarity": -1}},
|
|
||||||
# {"$limit": limit},
|
|
||||||
# {"$project": {"content": 1, "similarity": 1}},
|
|
||||||
# ]
|
|
||||||
|
|
||||||
# results = list(db.knowledges.aggregate(pipeline))
|
|
||||||
# logger.debug(f"知识库查询结果数量: {len(results)}")
|
|
||||||
|
|
||||||
# if not results:
|
|
||||||
# return "" if not return_raw else []
|
|
||||||
|
|
||||||
# if return_raw:
|
|
||||||
# return results
|
|
||||||
# else:
|
|
||||||
# # 返回所有找到的内容,用换行分隔
|
|
||||||
# return "\n".join(str(result["content"]) for result in results)
|
|
||||||
|
|
||||||
def _format_results(self, results: list) -> str:
|
|
||||||
"""格式化结果"""
|
|
||||||
if not results:
|
|
||||||
return "未找到相关知识。"
|
|
||||||
|
|
||||||
formatted_string = "我找到了一些相关知识:\n"
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
# chunk_id = result.get("chunk_id")
|
|
||||||
text = result.get("text", "")
|
|
||||||
source = result.get("source", "未知来源")
|
|
||||||
source_type = result.get("source_type", "未知类型")
|
|
||||||
similarity = result.get("similarity", 0.0)
|
|
||||||
|
|
||||||
formatted_string += (
|
|
||||||
f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source} \n内容片段: {text}\n\n"
|
|
||||||
)
|
|
||||||
# 暂时去掉chunk_id
|
|
||||||
# formatted_string += f"{i + 1}. (相似度: {similarity:.2f}) 类型: {source_type}, 来源: {source}, Chunk ID: {chunk_id} \n内容片段: {text}\n\n"
|
|
||||||
|
|
||||||
return formatted_string
|
|
||||||
|
|
||||||
|
|
||||||
# 注册工具
|
|
||||||
# register_tool(SearchKnowledgeTool)
|
|
||||||
|
|||||||
Reference in New Issue
Block a user