Merge pull request #956 from MaiM-with-u/plugin

Plugin插件和工作记忆
This commit is contained in:
SengokuCola
2025-05-16 23:47:18 +08:00
committed by GitHub
62 changed files with 4048 additions and 1596 deletions

View File

@@ -1,8 +1,10 @@
from src.tools.tool_can_use.base_tool import BaseTool
from src.chat.utils.utils import get_embedding
from src.common.database import db
from src.common.database.database_model import Knowledges # Updated import
from src.common.logger_manager import get_logger
from typing import Any, Union
from typing import Any, Union, List # Added List
import json # Added for parsing embedding
import math # Added for cosine similarity
logger = get_logger("get_knowledge_tool")
@@ -30,6 +32,7 @@ class SearchKnowledgeTool(BaseTool):
Returns:
dict: 工具执行结果
"""
query = "" # Initialize query to ensure it's defined in except block
try:
query = function_args.get("query")
threshold = function_args.get("threshold", 0.4)
@@ -48,9 +51,19 @@ class SearchKnowledgeTool(BaseTool):
logger.error(f"知识库搜索工具执行失败: {str(e)}")
return {"type": "info", "id": query, "content": f"知识库搜索失败,炸了: {str(e)}"}
@staticmethod
def _cosine_similarity(vec1: List[float], vec2: List[float]) -> float:
    """Compute the cosine similarity between two vectors.

    Args:
        vec1: First vector of numeric components.
        vec2: Second vector of numeric components.

    Returns:
        The cosine of the angle between ``vec1`` and ``vec2``. Returns 0.0
        when the similarity is undefined: either vector has zero magnitude
        (including empty vectors) or the two vectors differ in length.
    """
    # zip() stops at the shorter input, so with mismatched lengths the dot
    # product would cover only a prefix while the magnitudes cover the full
    # vectors, yielding a meaningless score. Treat that case as "no
    # similarity" rather than raising, so one bad stored embedding does not
    # abort a scan over many items.
    if len(vec1) != len(vec2):
        return 0.0

    dot_product = sum(p * q for p, q in zip(vec1, vec2))
    magnitude1 = math.sqrt(sum(p * p for p in vec1))
    magnitude2 = math.sqrt(sum(q * q for q in vec2))

    # A zero vector has no direction; avoid dividing by zero.
    if magnitude1 == 0 or magnitude2 == 0:
        return 0.0
    return dot_product / (magnitude1 * magnitude2)
@staticmethod
def get_info_from_db(
query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
query_embedding: list[float], limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
"""从数据库中获取相关信息
@@ -66,66 +79,51 @@ class SearchKnowledgeTool(BaseTool):
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
similar_items = []
try:
all_knowledges = Knowledges.select()
for item in all_knowledges:
try:
item_embedding_str = item.embedding
if not item_embedding_str:
logger.warning(f"Knowledge item ID {item.id} has empty embedding string.")
continue
item_embedding = json.loads(item_embedding_str)
if not isinstance(item_embedding, list) or not all(
isinstance(x, (int, float)) for x in item_embedding
):
logger.warning(f"Knowledge item ID {item.id} has invalid embedding format after JSON parsing.")
continue
except json.JSONDecodeError:
logger.warning(f"Failed to parse embedding for knowledge item ID {item.id}")
continue
except AttributeError:
logger.warning(f"Knowledge item ID {item.id} missing 'embedding' attribute or it's not a string.")
continue
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
similarity = SearchKnowledgeTool._cosine_similarity(query_embedding, item_embedding)
if similarity >= threshold:
similar_items.append({"content": item.content, "similarity": similarity, "raw_item": item})
# 按相似度降序排序
similar_items.sort(key=lambda x: x["similarity"], reverse=True)
# 应用限制
results = similar_items[:limit]
logger.debug(f"知识库查询后,符合条件的结果数量: {len(results)}")
except Exception as e:
logger.error(f"从 Peewee 数据库获取知识信息失败: {str(e)}")
return "" if not return_raw else []
if not results:
return "" if not return_raw else []
if return_raw:
return results
# Peewee 模型实例不能直接序列化为 JSON如果需要原始模型调用者需要处理
# 这里返回包含内容和相似度的字典列表
return [{"content": r["content"], "similarity": r["similarity"]} for r in results]
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)