Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -10,8 +10,8 @@ import pandas as pd
|
||||
# import tqdm
|
||||
import faiss
|
||||
|
||||
from .llm_client import LLMClient
|
||||
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
|
||||
# from .llm_client import LLMClient
|
||||
from .lpmmconfig import global_config
|
||||
from .utils.hash import get_sha256
|
||||
from .global_logger import logger
|
||||
from rich.traceback import install
|
||||
@@ -25,6 +25,9 @@ from rich.progress import (
|
||||
SpinnerColumn,
|
||||
TextColumn,
|
||||
)
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.chat.utils.utils import get_embedding
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
@@ -86,21 +89,20 @@ class EmbeddingStoreItem:
|
||||
|
||||
|
||||
class EmbeddingStore:
|
||||
def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
|
||||
def __init__(self, namespace: str, dir_path: str):
|
||||
self.namespace = namespace
|
||||
self.llm_client = llm_client
|
||||
self.dir = dir_path
|
||||
self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
|
||||
self.index_file_path = dir_path + "/" + namespace + ".index"
|
||||
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
|
||||
self.index_file_path = f"{dir_path}/{namespace}.index"
|
||||
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
|
||||
|
||||
self.store = dict()
|
||||
self.store = {}
|
||||
|
||||
self.faiss_index = None
|
||||
self.idx2hash = None
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||
return get_embedding(s)
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
@@ -293,20 +295,17 @@ class EmbeddingStore:
|
||||
|
||||
|
||||
class EmbeddingManager:
|
||||
def __init__(self, llm_client: LLMClient):
|
||||
def __init__(self):
|
||||
self.paragraphs_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
PG_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.entities_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
ENT_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.relation_embedding_store = EmbeddingStore(
|
||||
llm_client,
|
||||
REL_NAMESPACE,
|
||||
local_storage['pg_namespace'],
|
||||
EMBEDDING_DATA_DIR_STR,
|
||||
)
|
||||
self.stored_pg_hashes = set()
|
||||
|
||||
@@ -4,28 +4,35 @@ from typing import List, Union
|
||||
|
||||
from .global_logger import logger
|
||||
from . import prompt_template
|
||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||
from .llm_client import LLMClient
|
||||
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from json_repair import repair_json
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
try:
|
||||
fixed_json = repair_json(text)
|
||||
if isinstance(fixed_json, str):
|
||||
parsed_json = json.loads(fixed_json)
|
||||
else:
|
||||
parsed_json = fixed_json
|
||||
|
||||
if isinstance(parsed_json, list) and parsed_json:
|
||||
parsed_json = parsed_json[0]
|
||||
|
||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||
if isinstance(parsed_json, dict):
|
||||
return parsed_json
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
_, request_result = llm_client.send_chat_request(
|
||||
global_config["entity_extract"]["llm"]["model"], entity_extract_context
|
||||
)
|
||||
|
||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
||||
if "[" in request_result:
|
||||
request_result = request_result[request_result.index("[") :]
|
||||
|
||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
||||
if "]" in request_result:
|
||||
request_result = request_result[: request_result.rindex("]") + 1]
|
||||
|
||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
json.loads(entity_extract_result)
|
||||
entity_extract_result = [
|
||||
entity
|
||||
for entity in entity_extract_result
|
||||
@@ -38,23 +45,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
||||
return entity_extract_result
|
||||
|
||||
|
||||
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
|
||||
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
|
||||
|
||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
||||
if "[" in request_result:
|
||||
request_result = request_result[request_result.index("[") :]
|
||||
|
||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
||||
if "]" in request_result:
|
||||
request_result = request_result[: request_result.rindex("]") + 1]
|
||||
|
||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
json.loads(entity_extract_result)
|
||||
for triple in entity_extract_result:
|
||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||
raise Exception("RDF提取结果格式错误")
|
||||
@@ -63,7 +63,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
|
||||
|
||||
|
||||
def info_extract_from_str(
|
||||
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
|
||||
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||
try_count = 0
|
||||
while True:
|
||||
|
||||
@@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank
|
||||
|
||||
from .utils.hash import get_sha256
|
||||
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
|
||||
from .lpmmconfig import (
|
||||
ENT_NAMESPACE,
|
||||
PG_NAMESPACE,
|
||||
RAG_ENT_CNT_NAMESPACE,
|
||||
RAG_GRAPH_NAMESPACE,
|
||||
RAG_PG_HASH_NAMESPACE,
|
||||
global_config,
|
||||
)
|
||||
from .lpmmconfig import global_config
|
||||
from src.manager.local_store_manager import local_storage
|
||||
|
||||
from .global_logger import logger
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
KG_DIR = (
|
||||
os.path.join(ROOT_PATH, "data/rag")
|
||||
if global_config["persistence"]["rag_data_dir"] is None
|
||||
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
|
||||
)
|
||||
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
|
||||
|
||||
def _get_kg_dir():
|
||||
"""
|
||||
安全地获取KG数据目录路径
|
||||
"""
|
||||
root_path = local_storage['root_path']
|
||||
if root_path is None:
|
||||
# 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
|
||||
logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")
|
||||
|
||||
# 获取RAG数据目录
|
||||
rag_data_dir = global_config["persistence"]["rag_data_dir"]
|
||||
if rag_data_dir is None:
|
||||
kg_dir = os.path.join(root_path, "data/rag")
|
||||
else:
|
||||
kg_dir = os.path.join(root_path, rag_data_dir)
|
||||
|
||||
return str(kg_dir).replace("\\", "/")
|
||||
|
||||
|
||||
# 延迟初始化,避免在模块加载时就访问可能未初始化的 local_storage
|
||||
def get_kg_dir_str():
|
||||
"""获取KG目录字符串"""
|
||||
return _get_kg_dir()
|
||||
|
||||
|
||||
class KGManager:
|
||||
@@ -46,15 +59,15 @@ class KGManager:
|
||||
# 存储段落的hash值,用于去重
|
||||
self.stored_paragraph_hashes = set()
|
||||
# 实体出现次数
|
||||
self.ent_appear_cnt = dict()
|
||||
self.ent_appear_cnt = {}
|
||||
# KG
|
||||
self.graph = di_graph.DiGraph()
|
||||
|
||||
# 持久化相关
|
||||
self.dir_path = KG_DIR_STR
|
||||
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
|
||||
# 持久化相关 - 使用延迟初始化的路径
|
||||
self.dir_path = get_kg_dir_str()
|
||||
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
|
||||
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
|
||||
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
|
||||
|
||||
def save_to_file(self):
|
||||
"""将KG数据保存到文件"""
|
||||
@@ -109,8 +122,8 @@ class KGManager:
|
||||
# 避免自连接
|
||||
continue
|
||||
# 一个triple就是一条边(同时构建双向联系)
|
||||
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
|
||||
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
|
||||
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
|
||||
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
|
||||
entity_set.add(hash_key1)
|
||||
@@ -128,8 +141,8 @@ class KGManager:
|
||||
"""构建实体节点与文段节点之间的关系"""
|
||||
for idx in triple_list_data:
|
||||
for triple in triple_list_data[idx]:
|
||||
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
|
||||
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
|
||||
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
|
||||
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
|
||||
|
||||
@staticmethod
|
||||
@@ -144,8 +157,8 @@ class KGManager:
|
||||
ent_hash_list = set()
|
||||
for triple_list in triple_list_data.values():
|
||||
for triple in triple_list:
|
||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
|
||||
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
|
||||
ent_hash_list = list(ent_hash_list)
|
||||
|
||||
synonym_hash_set = set()
|
||||
@@ -250,7 +263,7 @@ class KGManager:
|
||||
for src_tgt in node_to_node.keys():
|
||||
for node_hash in src_tgt:
|
||||
if node_hash not in existed_nodes:
|
||||
if node_hash.startswith(ENT_NAMESPACE):
|
||||
if node_hash.startswith(local_storage['ent_namespace']):
|
||||
# 新增实体节点
|
||||
node = embedding_manager.entities_embedding_store.store[node_hash]
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
@@ -259,7 +272,7 @@ class KGManager:
|
||||
node_item["type"] = "ent"
|
||||
node_item["create_time"] = now_time
|
||||
self.graph.update_node(node_item)
|
||||
elif node_hash.startswith(PG_NAMESPACE):
|
||||
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||
# 新增文段节点
|
||||
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
@@ -340,7 +353,7 @@ class KGManager:
|
||||
# 关系三元组
|
||||
triple = relation[2:-2].split("', '")
|
||||
for ent in [(triple[0]), (triple[2])]:
|
||||
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
|
||||
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
|
||||
if ent_hash in existed_nodes: # 该实体需在KG中存在
|
||||
if ent_hash not in ent_sim_scores: # 尚未记录的实体
|
||||
ent_sim_scores[ent_hash] = []
|
||||
@@ -418,7 +431,7 @@ class KGManager:
|
||||
# 获取最终结果
|
||||
# 从搜索结果中提取文段节点的结果
|
||||
passage_node_res = [
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
|
||||
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
|
||||
]
|
||||
del ppr_res
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
|
||||
from src.chat.knowledge.lpmmconfig import global_config
|
||||
from src.chat.knowledge.embedding_store import EmbeddingManager
|
||||
from src.chat.knowledge.llm_client import LLMClient
|
||||
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
|
||||
@@ -6,10 +6,80 @@ from src.chat.knowledge.qa_manager import QAManager
|
||||
from src.chat.knowledge.kg_manager import KGManager
|
||||
from src.chat.knowledge.global_logger import logger
|
||||
from src.config.config import global_config as bot_global_config
|
||||
# try:
|
||||
# import quick_algo
|
||||
# except ImportError:
|
||||
# print("quick_algo not found, please install it first")
|
||||
from src.manager.local_store_manager import local_storage
|
||||
import os
|
||||
|
||||
INVALID_ENTITY = [
|
||||
"",
|
||||
"你",
|
||||
"他",
|
||||
"她",
|
||||
"它",
|
||||
"我们",
|
||||
"你们",
|
||||
"他们",
|
||||
"她们",
|
||||
"它们",
|
||||
]
|
||||
PG_NAMESPACE = "paragraph"
|
||||
ENT_NAMESPACE = "entity"
|
||||
REL_NAMESPACE = "relation"
|
||||
|
||||
RAG_GRAPH_NAMESPACE = "rag-graph"
|
||||
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
|
||||
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
||||
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||
|
||||
def _initialize_knowledge_local_storage():
|
||||
"""
|
||||
初始化知识库相关的本地存储配置
|
||||
使用字典批量设置,避免重复的if判断
|
||||
"""
|
||||
# 定义所有需要初始化的配置项
|
||||
default_configs = {
|
||||
# 路径配置
|
||||
'root_path': ROOT_PATH,
|
||||
'data_path': f"{ROOT_PATH}/data",
|
||||
|
||||
# 实体和命名空间配置
|
||||
'lpmm_invalid_entity': INVALID_ENTITY,
|
||||
'pg_namespace': PG_NAMESPACE,
|
||||
'ent_namespace': ENT_NAMESPACE,
|
||||
'rel_namespace': REL_NAMESPACE,
|
||||
|
||||
# RAG相关命名空间配置
|
||||
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
|
||||
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
|
||||
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
|
||||
}
|
||||
|
||||
# 日志级别映射:重要配置用info,其他用debug
|
||||
important_configs = {'root_path', 'data_path'}
|
||||
|
||||
# 批量设置配置项
|
||||
initialized_count = 0
|
||||
for key, default_value in default_configs.items():
|
||||
if local_storage[key] is None:
|
||||
local_storage[key] = default_value
|
||||
|
||||
# 根据重要性选择日志级别
|
||||
if key in important_configs:
|
||||
logger.info(f"设置{key}: {default_value}")
|
||||
else:
|
||||
logger.debug(f"设置{key}: {default_value}")
|
||||
|
||||
initialized_count += 1
|
||||
|
||||
if initialized_count > 0:
|
||||
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
|
||||
else:
|
||||
logger.debug("知识库本地存储配置已存在,跳过初始化")
|
||||
|
||||
# 初始化本地存储路径
|
||||
_initialize_knowledge_local_storage()
|
||||
|
||||
# 检查LPMM知识库是否启用
|
||||
if bot_global_config.lpmm_knowledge.enable:
|
||||
@@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
)
|
||||
|
||||
# 初始化Embedding库
|
||||
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
|
||||
embed_manager = EmbeddingManager()
|
||||
logger.info("正在从文件加载Embedding库")
|
||||
try:
|
||||
embed_manager.load_from_file()
|
||||
@@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable:
|
||||
qa_manager = QAManager(
|
||||
embed_manager,
|
||||
kg_manager,
|
||||
llm_client_list[global_config["embedding"]["provider"]],
|
||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
||||
llm_client_list[global_config["qa"]["llm"]["provider"]],
|
||||
)
|
||||
|
||||
# 记忆激活(用于记忆库)
|
||||
|
||||
@@ -4,9 +4,8 @@ import glob
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
||||
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||
# from src.manager.local_store_manager import local_storage
|
||||
|
||||
|
||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||
@@ -107,7 +106,7 @@ class OpenIE:
|
||||
@staticmethod
|
||||
def load() -> "OpenIE":
|
||||
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
||||
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
|
||||
openie_dir = os.path.join(DATA_PATH, "openie")
|
||||
if not os.path.exists(openie_dir):
|
||||
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||
@@ -122,12 +121,6 @@ class OpenIE:
|
||||
openie_data = OpenIE._from_dict(data_list)
|
||||
return openie_data
|
||||
|
||||
@staticmethod
|
||||
def save(openie_data: "OpenIE"):
|
||||
"""保存OpenIE数据到文件"""
|
||||
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
|
||||
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
||||
|
||||
def extract_entity_dict(self):
|
||||
"""提取实体列表"""
|
||||
ner_output_dict = dict(
|
||||
|
||||
@@ -5,11 +5,13 @@ from .global_logger import logger
|
||||
|
||||
# from . import prompt_template
|
||||
from .embedding_store import EmbeddingManager
|
||||
from .llm_client import LLMClient
|
||||
# from .llm_client import LLMClient
|
||||
from .kg_manager import KGManager
|
||||
from .lpmmconfig import global_config
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.config.config import global_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
|
||||
@@ -19,26 +21,25 @@ class QAManager:
|
||||
self,
|
||||
embed_manager: EmbeddingManager,
|
||||
kg_manager: KGManager,
|
||||
llm_client_embedding: LLMClient,
|
||||
llm_client_filter: LLMClient,
|
||||
llm_client_qa: LLMClient,
|
||||
|
||||
):
|
||||
self.embed_manager = embed_manager
|
||||
self.kg_manager = kg_manager
|
||||
self.llm_client_list = {
|
||||
"embedding": llm_client_embedding,
|
||||
"message_filter": llm_client_filter,
|
||||
"qa": llm_client_qa,
|
||||
}
|
||||
# TODO: API-Adapter修改标记
|
||||
self.qa_model = LLMRequest(
|
||||
model=global_config.model.lpmm_qa,
|
||||
request_type="lpmm.qa"
|
||||
)
|
||||
|
||||
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||
"""处理查询"""
|
||||
|
||||
# 生成问题的Embedding
|
||||
part_start_time = time.perf_counter()
|
||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
||||
global_config["embedding"]["model"], question
|
||||
)
|
||||
question_embedding = get_embedding(question)
|
||||
if question_embedding is None:
|
||||
logger.error("生成问题Embedding失败")
|
||||
return None
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
@@ -46,14 +47,15 @@ class QAManager:
|
||||
part_start_time = time.perf_counter()
|
||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["relation_search_top_k"],
|
||||
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||
)
|
||||
if relation_search_res is not None:
|
||||
# 过滤阈值
|
||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
||||
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||
# 未找到相关关系
|
||||
logger.debug("未找到相关关系,跳过关系检索")
|
||||
relation_search_res = []
|
||||
|
||||
part_end_time = time.perf_counter()
|
||||
@@ -71,7 +73,7 @@ class QAManager:
|
||||
part_start_time = time.perf_counter()
|
||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||
question_embedding,
|
||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
||||
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||
)
|
||||
part_end_time = time.perf_counter()
|
||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||
|
||||
Reference in New Issue
Block a user