UnCLASPrommer
2025-07-15 17:05:34 +08:00
10 changed files with 265 additions and 252 deletions

View File

@@ -9,19 +9,17 @@ import os
from time import sleep
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.kg_manager import KGManager
from src.common.logger import get_logger
from src.chat.knowledge.utils.hash import get_sha256
from src.manager.local_store_manager import local_storage
# Add the project root directory to sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")
logger = get_logger("OpenIE导入")
@@ -63,7 +61,7 @@ def hash_deduplicate(
):
# Paragraph hash
paragraph_hash = get_sha256(raw_paragraph)
if f"{PG_NAMESPACE}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
if f"{local_storage['pg_namespace']}-{paragraph_hash}" in stored_pg_hashes and paragraph_hash in stored_paragraph_hashes:
continue
new_raw_paragraphs[paragraph_hash] = raw_paragraph
new_triple_list_data[paragraph_hash] = triple_list
@@ -193,15 +191,9 @@ def main(): # sourcery skip: dict-comprehension
logger.info("----开始导入openie数据----\n")
logger.info("创建LLM客户端")
llm_client_list = {}
for key in global_config["llm_providers"]:
llm_client_list[key] = LLMClient(
global_config["llm_providers"][key]["base_url"],
global_config["llm_providers"][key]["api_key"],
)
# Initialize the embedding store
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
embed_manager = EmbeddingManager()
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
@@ -230,7 +222,7 @@ def main(): # sourcery skip: dict-comprehension
# Data comparison: paragraph hash sets of the embedding store vs. the KG
for pg_hash in kg_manager.stored_paragraph_hashes:
key = f"{PG_NAMESPACE}-{pg_hash}"
key = f"{local_storage['pg_namespace']}-{pg_hash}"
if key not in embed_manager.stored_pg_hashes:
logger.warning(f"KG中存在Embedding库中不存在的段落{key}")

View File

@@ -4,7 +4,6 @@ import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import glob
import datetime
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -13,11 +12,9 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from rich.progress import Progress # Switched to the rich progress bar
from src.common.logger import get_logger
from src.chat.knowledge.lpmmconfig import global_config
# from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.ie_process import info_extract_from_str
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.open_ie import OpenIE
from src.chat.knowledge.raw_processing import load_raw_data
from rich.progress import (
BarColumn,
TimeElapsedColumn,
@@ -27,24 +24,17 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
from raw_data_preprocessor import RAW_DATA_PATH, process_multi_files, load_raw_data
from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest
logger = get_logger("LPMM知识库-信息提取")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
ROOT_PATH, "data", "imported_lpmm_data"
)
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
# Create a thread-safe lock to protect file operations and shared data
file_lock = Lock()
open_ie_doc_lock = Lock()
# Create an event flag to signal program termination
shutdown_event = Event()
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
def ensure_dirs():
"""确保临时目录和输出目录存在"""
@@ -54,12 +44,26 @@ def ensure_dirs():
if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
if not os.path.exists(IMPORTED_DATA_PATH):
os.makedirs(IMPORTED_DATA_PATH)
logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
if not os.path.exists(RAW_DATA_PATH):
os.makedirs(RAW_DATA_PATH)
logger.info(f"已创建原始数据目录: {RAW_DATA_PATH}")
# Create a thread-safe lock to protect file operations and shared data
file_lock = Lock()
open_ie_doc_lock = Lock()
def process_single_text(pg_hash, raw_data, llm_client_list):
# Create an event flag to signal program termination
shutdown_event = Event()
lpmm_entity_extract_llm = LLMRequest(
model=global_config.model.lpmm_entity_extract,
request_type="lpmm.entity_extract"
)
lpmm_rdf_build_llm = LLMRequest(
model=global_config.model.lpmm_rdf_build,
request_type="lpmm.rdf_build"
)
def process_single_text(pg_hash, raw_data):
"""处理单个文本的函数,用于线程池"""
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
@@ -77,8 +81,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
os.remove(temp_file_path)
entity_list, rdf_triple_list = info_extract_from_str(
llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
lpmm_entity_extract_llm,
lpmm_rdf_build_llm,
raw_data,
)
if entity_list is None or rdf_triple_list is None:
@@ -113,7 +117,7 @@ def signal_handler(_signum, _frame):
def main(): # sourcery skip: comprehension-to-generator, extract-method
# Set up the signal handler
signal.signal(signal.SIGINT, signal_handler)
ensure_dirs() # Ensure directories exist
# New user confirmation prompt
print("=== Important operation confirmation, please read the following carefully ===")
print("Entity extraction will consume a fair amount of API credit and time; running it during idle hours is recommended.")
@@ -130,50 +134,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
ensure_dirs() # Ensure directories exist
logger.info("-------- Performing information extraction --------\n")
logger.info("Creating LLM clients")
llm_client_list = {
key: LLMClient(
global_config["llm_providers"][key]["base_url"],
global_config["llm_providers"][key]["api_key"],
)
for key in global_config["llm_providers"]
}
# Check the openie output directory
if not os.path.exists(OPENIE_OUTPUT_DIR):
os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
# Ensure TEMP_DIR exists
if not os.path.exists(TEMP_DIR):
os.makedirs(TEMP_DIR)
logger.info(f"已创建缓存目录: {TEMP_DIR}")
# 遍历IMPORTED_DATA_PATH下所有json文件
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
if not imported_files:
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
sys.exit(1)
all_sha256_list = []
all_raw_datas = []
for imported_file in imported_files:
logger.info(f"正在处理文件: {imported_file}")
try:
sha256_list, raw_datas = load_raw_data(imported_file)
except Exception as e:
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
continue
all_sha256_list.extend(sha256_list)
all_raw_datas.extend(raw_datas)
# Load the raw data
logger.info("Loading raw data")
all_sha256_list, all_raw_datas = load_raw_data()
failed_sha256 = []
open_ie_doc = []
workers = global_config["info_extraction"]["workers"]
workers = global_config.lpmm_knowledge.info_extraction_workers
with ThreadPoolExecutor(max_workers=workers) as executor:
future_to_hash = {
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
}
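The script above also migrates from the old lpmmconfig dict to the attribute-style config object. A short sketch of the before/after access pattern, limited to fields that appear in this diff:

from src.config.config import global_config
from src.llm_models.utils_model import LLMRequest

# Old: global_config["info_extraction"]["workers"]
workers = global_config.lpmm_knowledge.info_extraction_workers

# Old: one LLMClient per provider; new: one LLMRequest per task type.
entity_llm = LLMRequest(
    model=global_config.model.lpmm_entity_extract,
    request_type="lpmm.entity_extract",
)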

View File

@@ -1,40 +1,16 @@
import json
import os
from pathlib import Path
import sys # New: system module import
import datetime # New import
from src.chat.knowledge.utils.hash import get_sha256
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_logger
from src.chat.knowledge.lpmmconfig import global_config
logger = get_logger("lpmm")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
# New: ensure RAW_DATA_PATH exists
if not os.path.exists(RAW_DATA_PATH):
os.makedirs(RAW_DATA_PATH, exist_ok=True)
logger.info(f"已创建目录: {RAW_DATA_PATH}")
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
if global_config.get("persistence", {}).get("raw_data_path") is not None:
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
else:
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
# Add the project root directory to sys.path
def check_and_create_dirs():
"""检查并创建必要的目录"""
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
for dir_path in required_dirs:
if not os.path.exists(dir_path):
os.makedirs(dir_path)
logger.info(f"已创建目录: {dir_path}")
def process_text_file(file_path):
def _process_text_file(file_path):
"""处理单个文本文件,返回段落列表"""
with open(file_path, "r", encoding="utf-8") as f:
raw = f.read()
@@ -55,54 +31,45 @@ def process_text_file(file_path):
return paragraphs
def main():
# New user confirmation prompt
print("=== Data preprocessing script ===")
print(f"This script will process all .txt files under '{RAW_DATA_PATH}'.")
print(f"The processed paragraphs will be merged and saved to '{IMPORTED_DATA_PATH}' as MM-DD-HH-SS-imported-data.json.")
print("Please make sure the raw data has been placed in the correct directory.")
confirm = input("Confirm and continue? (y/n): ").strip().lower()
if confirm != "y":
logger.info("Operation cancelled")
sys.exit(1)
print("\n" + "=" * 40 + "\n")
# Check for and create the required directories
check_and_create_dirs()
# # Check whether the output file already exists
# if os.path.exists(RAW_DATA_PATH):
# logger.error("Error: data/import.json already exists; please process or delete it first")
# sys.exit(1)
# if os.path.exists(RAW_DATA_PATH):
# logger.error("Error: data/openie.json already exists; please process or delete it first")
# sys.exit(1)
# Collect all raw text files
def _process_multi_files() -> list:
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
if not raw_files:
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
sys.exit(1)
# Process all files
all_paragraphs = []
for file in raw_files:
logger.info(f"正在处理文件: {file.name}")
paragraphs = process_text_file(file)
paragraphs = _process_text_file(file)
all_paragraphs.extend(paragraphs)
return all_paragraphs
# Save the merged result to IMPORTED_DATA_PATH; filename format: MM-DD-HH-SS-imported-data.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
def load_raw_data() -> tuple[list[str], list[str]]:
"""Load the raw data files
logger.info(f"Processing complete; result saved to: {output_path}")
Read the raw data files and load the raw data into memory
Args:
path: optional absolute path of a json file to read
if __name__ == "__main__":
logger.info(f"Raw data path: {RAW_DATA_PATH}")
logger.info(f"Processed data path: {IMPORTED_DATA_PATH}")
main()
Returns:
- raw_data: list of raw data items
- sha256_list: SHA-256 hashes of the raw data
"""
raw_data = _process_multi_files()
sha256_list = []
sha256_set = set()
deduped_data = []
for item in raw_data:
if not isinstance(item, str):
logger.warning(f"Invalid data type: {item}")
continue
pg_hash = get_sha256(item)
if pg_hash in sha256_set:
logger.warning(f"Duplicate data: {item}")
continue
sha256_set.add(pg_hash)
sha256_list.append(pg_hash)
deduped_data.append(item)
logger.info(f"Read {len(deduped_data)} items in total")
return sha256_list, deduped_data
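The dedup rule load_raw_data applies above is one SHA-256 per distinct paragraph, with duplicates skipped. A self-contained sketch of that rule, assuming get_sha256 is a plain SHA-256 hex digest:

import hashlib

def dedup(paragraphs: list[str]) -> tuple[list[str], list[str]]:
    seen: set[str] = set()
    hashes: list[str] = []
    kept: list[str] = []
    for p in paragraphs:
        h = hashlib.sha256(p.encode("utf-8")).hexdigest()
        if h in seen:
            continue  # duplicate paragraph, skipped (the real script logs a warning)
        seen.add(h)
        hashes.append(h)
        kept.append(p)
    return hashes, kept

assert dedup(["a", "b", "a"])[1] == ["a", "b"]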

View File

@@ -10,8 +10,8 @@ import pandas as pd
# import tqdm
import faiss
from .llm_client import LLMClient
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
# from .llm_client import LLMClient
from .lpmmconfig import global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
@@ -25,6 +25,9 @@ from rich.progress import (
SpinnerColumn,
TextColumn,
)
from src.manager.local_store_manager import local_storage
from src.chat.utils.utils import get_embedding
install(extra_lines=3)
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
@@ -86,21 +89,20 @@ class EmbeddingStoreItem:
class EmbeddingStore:
def __init__(self, llm_client: LLMClient, namespace: str, dir_path: str):
def __init__(self, namespace: str, dir_path: str):
self.namespace = namespace
self.llm_client = llm_client
self.dir = dir_path
self.embedding_file_path = dir_path + "/" + namespace + ".parquet"
self.index_file_path = dir_path + "/" + namespace + ".index"
self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
self.index_file_path = f"{dir_path}/{namespace}.index"
self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
self.store = dict()
self.store = {}
self.faiss_index = None
self.idx2hash = None
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
return get_embedding(s)
def get_test_file_path(self):
return EMBEDDING_TEST_FILE
@@ -293,20 +295,17 @@ class EmbeddingStore:
class EmbeddingManager:
def __init__(self, llm_client: LLMClient):
def __init__(self):
self.paragraphs_embedding_store = EmbeddingStore(
llm_client,
PG_NAMESPACE,
local_storage['pg_namespace'],
EMBEDDING_DATA_DIR_STR,
)
self.entities_embedding_store = EmbeddingStore(
llm_client,
ENT_NAMESPACE,
local_storage['ent_namespace'],
EMBEDDING_DATA_DIR_STR,
)
self.relation_embedding_store = EmbeddingStore(
llm_client,
REL_NAMESPACE,
local_storage['rel_namespace'],
EMBEDDING_DATA_DIR_STR,
)
self.stored_pg_hashes = set()
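With the LLMClient parameter gone, EmbeddingManager is constructed with no arguments and resolves each store's namespace from local_storage at runtime; paragraphs, entities, and relations each need their own key ('pg_namespace', 'ent_namespace', 'rel_namespace'). A usage sketch, assuming knowledge_lib has already seeded local_storage:

embed_manager = EmbeddingManager()
embed_manager.load_from_file()
# Each store persists as "<dir>/<namespace>.parquet" plus a faiss "<namespace>.index",
# e.g. ".../paragraph.parquet" for the paragraph store.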

View File

@@ -4,28 +4,35 @@ from typing import List, Union
from .global_logger import logger
from . import prompt_template
from .lpmmconfig import global_config, INVALID_ENTITY
from .llm_client import LLMClient
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
parsed_json = json.loads(fixed_json)
else:
parsed_json = fixed_json
if isinstance(parsed_json, list) and parsed_json:
parsed_json = parsed_json[0]
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
if isinstance(parsed_json, dict):
return parsed_json
except Exception as e:
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
_, request_result = llm_client.send_chat_request(
global_config["entity_extract"]["llm"]["model"], entity_extract_context
)
# Strip content before the '{' (the result may contain multiple '{')
if "[" in request_result:
request_result = request_result[request_result.index("[") :]
# Strip content after the last '}' (the result may contain multiple '}')
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
entity_extract_result = _extract_json_from_text(response)
# Try to load the JSON data
json.loads(entity_extract_result)
entity_extract_result = [
entity
for entity in entity_extract_result
@@ -38,23 +45,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
return entity_extract_result
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
# Strip content before the '{' (the result may contain multiple '{')
if "[" in request_result:
request_result = request_result[request_result.index("[") :]
# Strip content after the last '}' (the result may contain multiple '}')
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
entity_extract_result = _extract_json_from_text(response)
# Try to load the JSON data
json.loads(entity_extract_result)
for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
raise Exception("RDF提取结果格式错误")
@@ -63,7 +63,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
def info_extract_from_str(
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
try_count = 0
while True:
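_extract_json_from_text above delegates the cleanup to the json_repair package instead of the old manual bracket slicing. A minimal sketch of what repair_json tolerates (single quotes, trailing commas), independent of this codebase:

import json
from json_repair import repair_json

fixed = repair_json("{'name': 'cat', 'eats': ['fish',]}")  # returns a repaired JSON string
data = json.loads(fixed) if isinstance(fixed, str) else fixed
print(data)  # {'name': 'cat', 'eats': ['fish']}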

View File

@@ -20,24 +20,37 @@ from quick_algo import di_graph, pagerank
from .utils.hash import get_sha256
from .embedding_store import EmbeddingManager, EmbeddingStoreItem
from .lpmmconfig import (
ENT_NAMESPACE,
PG_NAMESPACE,
RAG_ENT_CNT_NAMESPACE,
RAG_GRAPH_NAMESPACE,
RAG_PG_HASH_NAMESPACE,
global_config,
)
from .lpmmconfig import global_config
from src.manager.local_store_manager import local_storage
from .global_logger import logger
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
KG_DIR = (
os.path.join(ROOT_PATH, "data/rag")
if global_config["persistence"]["rag_data_dir"] is None
else os.path.join(ROOT_PATH, global_config["persistence"]["rag_data_dir"])
)
KG_DIR_STR = str(KG_DIR).replace("\\", "/")
def _get_kg_dir():
"""
安全地获取KG数据目录路径
"""
root_path = local_storage['root_path']
if root_path is None:
# If local_storage has no root_path, fall back to a path relative to this file
current_dir = os.path.dirname(os.path.abspath(__file__))
root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录
rag_data_dir = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag")
else:
kg_dir = os.path.join(root_path, rag_data_dir)
return str(kg_dir).replace("\\", "/")
# Lazy initialization, to avoid accessing a possibly uninitialized local_storage at module load time
def get_kg_dir_str():
"""获取KG目录字符串"""
return _get_kg_dir()
class KGManager:
@@ -46,15 +59,15 @@ class KGManager:
# Store paragraph hashes for deduplication
self.stored_paragraph_hashes = set()
# Entity occurrence counts
self.ent_appear_cnt = dict()
self.ent_appear_cnt = {}
# KG
self.graph = di_graph.DiGraph()
# Persistence-related
self.dir_path = KG_DIR_STR
self.graph_data_path = self.dir_path + "/" + RAG_GRAPH_NAMESPACE + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + RAG_ENT_CNT_NAMESPACE + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + RAG_PG_HASH_NAMESPACE + ".json"
# Persistence-related: use the lazily initialized path
self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
def save_to_file(self):
"""将KG数据保存到文件"""
@@ -109,8 +122,8 @@ class KGManager:
# Avoid self-loops
continue
# One triple is one edge; build the link in both directions
hash_key1 = ENT_NAMESPACE + "-" + get_sha256(triple[0])
hash_key2 = ENT_NAMESPACE + "-" + get_sha256(triple[2])
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1)
@@ -128,8 +141,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系"""
for idx in triple_list_data:
for triple in triple_list_data[idx]:
ent_hash_key = ENT_NAMESPACE + "-" + get_sha256(triple[0])
pg_hash_key = PG_NAMESPACE + "-" + str(idx)
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod
@@ -144,8 +157,8 @@ class KGManager:
ent_hash_list = set()
for triple_list in triple_list_data.values():
for triple in triple_list:
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[0]))
ent_hash_list.add(ENT_NAMESPACE + "-" + get_sha256(triple[2]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
@@ -250,7 +263,7 @@ class KGManager:
for src_tgt in node_to_node.keys():
for node_hash in src_tgt:
if node_hash not in existed_nodes:
if node_hash.startswith(ENT_NAMESPACE):
if node_hash.startswith(local_storage['ent_namespace']):
# Add a new entity node
node = embedding_manager.entities_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
@@ -259,7 +272,7 @@ class KGManager:
node_item["type"] = "ent"
node_item["create_time"] = now_time
self.graph.update_node(node_item)
elif node_hash.startswith(PG_NAMESPACE):
elif node_hash.startswith(local_storage['pg_namespace']):
# Add a new paragraph node
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
assert isinstance(node, EmbeddingStoreItem)
@@ -340,7 +353,7 @@ class KGManager:
# Relation triple
triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]:
ent_hash = ENT_NAMESPACE + "-" + get_sha256(ent)
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # The entity must already exist in the KG
if ent_hash not in ent_sim_scores: # Entity not yet recorded
ent_sim_scores[ent_hash] = []
@@ -418,7 +431,7 @@ class KGManager:
# Get the final result
# Extract paragraph-node results from the search results
passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(PG_NAMESPACE)
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
]
del ppr_res
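KGManager now resolves its node-key prefixes through local_storage as well; entity and paragraph nodes share one graph and are distinguished purely by prefix. A small sketch of that convention, using the default namespace values seeded below ("entity", "paragraph") and a made-up entity string:

from src.manager.local_store_manager import local_storage
from src.chat.knowledge.utils.hash import get_sha256

node_key = f"{local_storage['ent_namespace']}-{get_sha256('some entity')}"
if node_key.startswith(local_storage['ent_namespace']):
    kind = "ent"  # entity node
elif node_key.startswith(local_storage['pg_namespace']):
    kind = "pg"   # paragraph node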

View File

@@ -1,4 +1,4 @@
from src.chat.knowledge.lpmmconfig import PG_NAMESPACE, global_config
from src.chat.knowledge.lpmmconfig import global_config
from src.chat.knowledge.embedding_store import EmbeddingManager
from src.chat.knowledge.llm_client import LLMClient
from src.chat.knowledge.mem_active_manager import MemoryActiveManager
@@ -6,10 +6,80 @@ from src.chat.knowledge.qa_manager import QAManager
from src.chat.knowledge.kg_manager import KGManager
from src.chat.knowledge.global_logger import logger
from src.config.config import global_config as bot_global_config
# try:
# import quick_algo
# except ImportError:
# print("quick_algo not found, please install it first")
from src.manager.local_store_manager import local_storage
import os
INVALID_ENTITY = [
"",
"",
"",
"",
"",
"我们",
"你们",
"他们",
"她们",
"它们",
]
PG_NAMESPACE = "paragraph"
ENT_NAMESPACE = "entity"
REL_NAMESPACE = "relation"
RAG_GRAPH_NAMESPACE = "rag-graph"
RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage():
"""
Initialize the knowledge-base-related local storage configuration
Set entries in bulk from a dict to avoid repeated if checks
"""
# Define all config entries that need initialization
default_configs = {
# Path configuration
'root_path': ROOT_PATH,
'data_path': f"{ROOT_PATH}/data",
# Entity and namespace configuration
'lpmm_invalid_entity': INVALID_ENTITY,
'pg_namespace': PG_NAMESPACE,
'ent_namespace': ENT_NAMESPACE,
'rel_namespace': REL_NAMESPACE,
# RAG-related namespace configuration
'rag_graph_namespace': RAG_GRAPH_NAMESPACE,
'rag_ent_cnt_namespace': RAG_ENT_CNT_NAMESPACE,
'rag_pg_hash_namespace': RAG_PG_HASH_NAMESPACE
}
# Log-level mapping: important configs use info, the rest use debug
important_configs = {'root_path', 'data_path'}
# Set the config entries in bulk
initialized_count = 0
for key, default_value in default_configs.items():
if local_storage[key] is None:
local_storage[key] = default_value
# Choose the log level based on importance
if key in important_configs:
logger.info(f"设置{key}: {default_value}")
else:
logger.debug(f"设置{key}: {default_value}")
initialized_count += 1
if initialized_count > 0:
logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
else:
logger.debug("知识库本地存储配置已存在,跳过初始化")
# 初始化本地存储路径
_initialize_knowledge_local_storage()
# 检查LPMM知识库是否启用
if bot_global_config.lpmm_knowledge.enable:
@@ -23,7 +93,7 @@ if bot_global_config.lpmm_knowledge.enable:
)
# Initialize the embedding store
embed_manager = EmbeddingManager(llm_client_list[global_config["embedding"]["provider"]])
embed_manager = EmbeddingManager()
logger.info("正在从文件加载Embedding库")
try:
embed_manager.load_from_file()
@@ -54,9 +124,6 @@ if bot_global_config.lpmm_knowledge.enable:
qa_manager = QAManager(
embed_manager,
kg_manager,
llm_client_list[global_config["embedding"]["provider"]],
llm_client_list[global_config["qa"]["llm"]["provider"]],
llm_client_list[global_config["qa"]["llm"]["provider"]],
)
# Memory activation (for the memory store)
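_initialize_knowledge_local_storage writes only keys that are still None, so importing this module repeatedly is safe. A self-contained sketch of that idempotent seeding (a plain dict stands in for local_storage, which the code above assumes returns None for unset keys):

def seed_defaults(store: dict, defaults: dict) -> int:
    # Write only the missing keys; a second call becomes a no-op.
    written = 0
    for key, value in defaults.items():
        if store.get(key) is None:
            store[key] = value
            written += 1
    return written

store: dict = {}
assert seed_defaults(store, {"pg_namespace": "paragraph"}) == 1
assert seed_defaults(store, {"pg_namespace": "paragraph"}) == 0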

View File

@@ -4,9 +4,8 @@ import glob
from typing import Any, Dict, List
from .lpmmconfig import INVALID_ENTITY, global_config
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
# from src.manager.local_store_manager import local_storage
def _filter_invalid_entities(entities: List[str]) -> List[str]:
@@ -107,7 +106,7 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
openie_dir = os.path.join(DATA_PATH, "openie")
if not os.path.exists(openie_dir):
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
@@ -122,12 +121,6 @@ class OpenIE:
openie_data = OpenIE._from_dict(data_list)
return openie_data
@staticmethod
def save(openie_data: "OpenIE"):
"""保存OpenIE数据到文件"""
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
def extract_entity_dict(self):
"""提取实体列表"""
ner_output_dict = dict(

View File

@@ -5,11 +5,13 @@ from .global_logger import logger
# from . import prompt_template
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
# from .llm_client import LLMClient
from .kg_manager import KGManager
from .lpmmconfig import global_config
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
MAX_KNOWLEDGE_LENGTH = 10000 # Maximum knowledge length
@@ -19,26 +21,25 @@ class QAManager:
self,
embed_manager: EmbeddingManager,
kg_manager: KGManager,
llm_client_embedding: LLMClient,
llm_client_filter: LLMClient,
llm_client_qa: LLMClient,
):
self.embed_manager = embed_manager
self.kg_manager = kg_manager
self.llm_client_list = {
"embedding": llm_client_embedding,
"message_filter": llm_client_filter,
"qa": llm_client_qa,
}
# TODO: API-Adapter change marker
self.qa_model = LLMRequest(
model=global_config.model.lpmm_qa,
request_type="lpmm.qa"
)
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
"""处理查询"""
# 生成问题的Embedding
part_start_time = time.perf_counter()
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
global_config["embedding"]["model"], question
)
question_embedding = get_embedding(question)
if question_embedding is None:
logger.error("生成问题Embedding失败")
return None
part_end_time = time.perf_counter()
logger.debug(f"Embedding用时{part_end_time - part_start_time:.5f}s")
@@ -46,14 +47,15 @@ class QAManager:
part_start_time = time.perf_counter()
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
question_embedding,
global_config["qa"]["params"]["relation_search_top_k"],
global_config.lpmm_knowledge.qa_relation_search_top_k,
)
if relation_search_res is not None:
# Filtering threshold
# Dynamic threshold: when some results stand out numerically, keep only those; otherwise keep all results
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
# No relevant relation found
logger.debug("No relevant relation found; skipping relation retrieval")
relation_search_res = []
part_end_time = time.perf_counter()
@@ -71,7 +73,7 @@ class QAManager:
part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
question_embedding,
global_config["qa"]["params"]["paragraph_search_top_k"],
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
)
part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")

View File

@@ -625,3 +625,12 @@ class ModelConfig(ConfigBase):
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM实体提取模型配置"""
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM RDF构建模型配置"""
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM问答模型配置"""