feat: 重构信息提取模块,移除LLMClient依赖,改为使用LLMRequest,优化数据加载和处理逻辑
This commit is contained in:
@@ -13,11 +13,10 @@ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|||||||
from rich.progress import Progress # 替换为 rich 进度条
|
from rich.progress import Progress # 替换为 rich 进度条
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.lpmmconfig import global_config
|
# from src.chat.knowledge.lpmmconfig import global_config
|
||||||
from src.chat.knowledge.ie_process import info_extract_from_str
|
from src.chat.knowledge.ie_process import _entity_extract, info_extract_from_str
|
||||||
from src.chat.knowledge.llm_client import LLMClient
|
from src.chat.knowledge.llm_client import LLMClient
|
||||||
from src.chat.knowledge.open_ie import OpenIE
|
from src.chat.knowledge.open_ie import OpenIE
|
||||||
from src.chat.knowledge.raw_processing import load_raw_data
|
|
||||||
from rich.progress import (
|
from rich.progress import (
|
||||||
BarColumn,
|
BarColumn,
|
||||||
TimeElapsedColumn,
|
TimeElapsedColumn,
|
||||||
@@ -27,16 +26,17 @@ from rich.progress import (
|
|||||||
SpinnerColumn,
|
SpinnerColumn,
|
||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
|
from raw_data_preprocessor import process_multi_files, load_raw_data
|
||||||
|
from src.config.config import global_config
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
logger = get_logger("LPMM知识库-信息提取")
|
logger = get_logger("LPMM知识库-信息提取")
|
||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||||
IMPORTED_DATA_PATH = global_config["persistence"]["imported_data_path"] or os.path.join(
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
|
||||||
ROOT_PATH, "data", "imported_lpmm_data"
|
OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
|
||||||
)
|
|
||||||
OPENIE_OUTPUT_DIR = global_config["persistence"]["openie_data_path"] or os.path.join(ROOT_PATH, "data", "openie")
|
|
||||||
|
|
||||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||||
file_lock = Lock()
|
file_lock = Lock()
|
||||||
@@ -45,6 +45,14 @@ open_ie_doc_lock = Lock()
|
|||||||
# 创建一个事件标志,用于控制程序终止
|
# 创建一个事件标志,用于控制程序终止
|
||||||
shutdown_event = Event()
|
shutdown_event = Event()
|
||||||
|
|
||||||
|
lpmm_entity_extract_llm = LLMRequest(
|
||||||
|
model=global_config.model.lpmm_entity_extract,
|
||||||
|
request_type="lpmm.entity_extract"
|
||||||
|
)
|
||||||
|
lpmm_rdf_build_llm = LLMRequest(
|
||||||
|
model=global_config.model.lpmm_rdf_build,
|
||||||
|
request_type="lpmm.rdf_build"
|
||||||
|
)
|
||||||
|
|
||||||
def ensure_dirs():
|
def ensure_dirs():
|
||||||
"""确保临时目录和输出目录存在"""
|
"""确保临时目录和输出目录存在"""
|
||||||
@@ -54,12 +62,9 @@ def ensure_dirs():
|
|||||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||||
if not os.path.exists(IMPORTED_DATA_PATH):
|
|
||||||
os.makedirs(IMPORTED_DATA_PATH)
|
|
||||||
logger.info(f"已创建导入数据目录: {IMPORTED_DATA_PATH}")
|
|
||||||
|
|
||||||
|
|
||||||
def process_single_text(pg_hash, raw_data, llm_client_list):
|
def process_single_text(pg_hash, raw_data):
|
||||||
"""处理单个文本的函数,用于线程池"""
|
"""处理单个文本的函数,用于线程池"""
|
||||||
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
temp_file_path = f"{TEMP_DIR}/{pg_hash}.json"
|
||||||
|
|
||||||
@@ -77,8 +82,8 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
|||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
|
|
||||||
entity_list, rdf_triple_list = info_extract_from_str(
|
entity_list, rdf_triple_list = info_extract_from_str(
|
||||||
llm_client_list[global_config["entity_extract"]["llm"]["provider"]],
|
lpmm_entity_extract_llm,
|
||||||
llm_client_list[global_config["rdf_build"]["llm"]["provider"]],
|
lpmm_rdf_build_llm,
|
||||||
raw_data,
|
raw_data,
|
||||||
)
|
)
|
||||||
if entity_list is None or rdf_triple_list is None:
|
if entity_list is None or rdf_triple_list is None:
|
||||||
@@ -130,50 +135,17 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method
|
|||||||
ensure_dirs() # 确保目录存在
|
ensure_dirs() # 确保目录存在
|
||||||
logger.info("--------进行信息提取--------\n")
|
logger.info("--------进行信息提取--------\n")
|
||||||
|
|
||||||
logger.info("创建LLM客户端")
|
# 加载原始数据
|
||||||
llm_client_list = {
|
logger.info("正在加载原始数据")
|
||||||
key: LLMClient(
|
all_sha256_list, all_raw_datas = load_raw_data()
|
||||||
global_config["llm_providers"][key]["base_url"],
|
|
||||||
global_config["llm_providers"][key]["api_key"],
|
|
||||||
)
|
|
||||||
for key in global_config["llm_providers"]
|
|
||||||
}
|
|
||||||
# 检查 openie 输出目录
|
|
||||||
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
|
||||||
os.makedirs(OPENIE_OUTPUT_DIR)
|
|
||||||
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
|
||||||
|
|
||||||
# 确保 TEMP_DIR 目录存在
|
|
||||||
if not os.path.exists(TEMP_DIR):
|
|
||||||
os.makedirs(TEMP_DIR)
|
|
||||||
logger.info(f"已创建缓存目录: {TEMP_DIR}")
|
|
||||||
|
|
||||||
# 遍历IMPORTED_DATA_PATH下所有json文件
|
|
||||||
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
|
|
||||||
if not imported_files:
|
|
||||||
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
all_sha256_list = []
|
|
||||||
all_raw_datas = []
|
|
||||||
|
|
||||||
for imported_file in imported_files:
|
|
||||||
logger.info(f"正在处理文件: {imported_file}")
|
|
||||||
try:
|
|
||||||
sha256_list, raw_datas = load_raw_data(imported_file)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
|
|
||||||
continue
|
|
||||||
all_sha256_list.extend(sha256_list)
|
|
||||||
all_raw_datas.extend(raw_datas)
|
|
||||||
|
|
||||||
failed_sha256 = []
|
failed_sha256 = []
|
||||||
open_ie_doc = []
|
open_ie_doc = []
|
||||||
|
|
||||||
workers = global_config["info_extraction"]["workers"]
|
workers = global_config.lpmm_knowledge.info_extraction_workers
|
||||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
future_to_hash = {
|
future_to_hash = {
|
||||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
executor.submit(process_single_text, pg_hash, raw_data): pg_hash
|
||||||
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas, strict=False)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +1,16 @@
|
|||||||
import json
|
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys # 新增系统模块导入
|
import sys # 新增系统模块导入
|
||||||
import datetime # 新增导入
|
from src.chat.knowledge.utils.hash import get_sha256
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.knowledge.lpmmconfig import global_config
|
|
||||||
|
|
||||||
logger = get_logger("lpmm")
|
logger = get_logger("lpmm")
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||||
# 新增:确保 RAW_DATA_PATH 存在
|
# IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||||
if not os.path.exists(RAW_DATA_PATH):
|
|
||||||
os.makedirs(RAW_DATA_PATH, exist_ok=True)
|
|
||||||
logger.info(f"已创建目录: {RAW_DATA_PATH}")
|
|
||||||
|
|
||||||
if global_config.get("persistence", {}).get("raw_data_path") is not None:
|
def _process_text_file(file_path):
|
||||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, global_config["persistence"]["raw_data_path"])
|
|
||||||
else:
|
|
||||||
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
|
||||||
|
|
||||||
|
|
||||||
def check_and_create_dirs():
|
|
||||||
"""检查并创建必要的目录"""
|
|
||||||
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
|
|
||||||
|
|
||||||
for dir_path in required_dirs:
|
|
||||||
if not os.path.exists(dir_path):
|
|
||||||
os.makedirs(dir_path)
|
|
||||||
logger.info(f"已创建目录: {dir_path}")
|
|
||||||
|
|
||||||
|
|
||||||
def process_text_file(file_path):
|
|
||||||
"""处理单个文本文件,返回段落列表"""
|
"""处理单个文本文件,返回段落列表"""
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
with open(file_path, "r", encoding="utf-8") as f:
|
||||||
raw = f.read()
|
raw = f.read()
|
||||||
@@ -55,54 +31,45 @@ def process_text_file(file_path):
|
|||||||
return paragraphs
|
return paragraphs
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def _process_multi_files() -> list:
|
||||||
# 新增用户确认提示
|
|
||||||
print("=== 数据预处理脚本 ===")
|
|
||||||
print(f"本脚本将处理 '{RAW_DATA_PATH}' 目录下的所有 .txt 文件。")
|
|
||||||
print(f"处理后的段落数据将合并,并以 MM-DD-HH-SS-imported-data.json 的格式保存在 '{IMPORTED_DATA_PATH}' 目录中。")
|
|
||||||
print("请确保原始数据已放置在正确的目录中。")
|
|
||||||
confirm = input("确认继续执行?(y/n): ").strip().lower()
|
|
||||||
if confirm != "y":
|
|
||||||
logger.info("操作已取消")
|
|
||||||
sys.exit(1)
|
|
||||||
print("\n" + "=" * 40 + "\n")
|
|
||||||
|
|
||||||
# 检查并创建必要的目录
|
|
||||||
check_and_create_dirs()
|
|
||||||
|
|
||||||
# # 检查输出文件是否存在
|
|
||||||
# if os.path.exists(RAW_DATA_PATH):
|
|
||||||
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
|
|
||||||
# sys.exit(1)
|
|
||||||
|
|
||||||
# if os.path.exists(RAW_DATA_PATH):
|
|
||||||
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
|
|
||||||
# sys.exit(1)
|
|
||||||
|
|
||||||
# 获取所有原始文本文件
|
|
||||||
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||||
if not raw_files:
|
if not raw_files:
|
||||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
# 处理所有文件
|
# 处理所有文件
|
||||||
all_paragraphs = []
|
all_paragraphs = []
|
||||||
for file in raw_files:
|
for file in raw_files:
|
||||||
logger.info(f"正在处理文件: {file.name}")
|
logger.info(f"正在处理文件: {file.name}")
|
||||||
paragraphs = process_text_file(file)
|
paragraphs = _process_text_file(file)
|
||||||
all_paragraphs.extend(paragraphs)
|
all_paragraphs.extend(paragraphs)
|
||||||
|
return all_paragraphs
|
||||||
|
|
||||||
# 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json
|
def load_raw_data() -> tuple[list[str], list[str]]:
|
||||||
now = datetime.datetime.now()
|
"""加载原始数据文件
|
||||||
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
|
|
||||||
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
|
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
|
||||||
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
|
||||||
|
|
||||||
logger.info(f"处理完成,结果已保存到: {output_path}")
|
读取原始数据文件,将原始数据加载到内存中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 可选,指定要读取的json文件绝对路径
|
||||||
|
|
||||||
if __name__ == "__main__":
|
Returns:
|
||||||
logger.info(f"原始数据路径: {RAW_DATA_PATH}")
|
- raw_data: 原始数据列表
|
||||||
logger.info(f"处理后的数据路径: {IMPORTED_DATA_PATH}")
|
- sha256_list: 原始数据的SHA256集合
|
||||||
main()
|
"""
|
||||||
|
raw_data = _process_multi_files()
|
||||||
|
sha256_list = []
|
||||||
|
sha256_set = set()
|
||||||
|
for item in raw_data:
|
||||||
|
if not isinstance(item, str):
|
||||||
|
logger.warning(f"数据类型错误:{item}")
|
||||||
|
continue
|
||||||
|
pg_hash = get_sha256(item)
|
||||||
|
if pg_hash in sha256_set:
|
||||||
|
logger.warning(f"重复数据:{item}")
|
||||||
|
continue
|
||||||
|
sha256_set.add(pg_hash)
|
||||||
|
sha256_list.append(pg_hash)
|
||||||
|
raw_data.append(item)
|
||||||
|
logger.info(f"共读取到{len(raw_data)}条数据")
|
||||||
|
|
||||||
|
return sha256_list, raw_data
|
||||||
@@ -10,7 +10,7 @@ import pandas as pd
|
|||||||
# import tqdm
|
# import tqdm
|
||||||
import faiss
|
import faiss
|
||||||
|
|
||||||
from .llm_client import LLMClient
|
# from .llm_client import LLMClient
|
||||||
from .lpmmconfig import global_config
|
from .lpmmconfig import global_config
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
@@ -295,7 +295,7 @@ class EmbeddingStore:
|
|||||||
|
|
||||||
|
|
||||||
class EmbeddingManager:
|
class EmbeddingManager:
|
||||||
def __init__(self, llm_client: LLMClient):
|
def __init__(self):
|
||||||
self.paragraphs_embedding_store = EmbeddingStore(
|
self.paragraphs_embedding_store = EmbeddingStore(
|
||||||
local_storage['pg_namespace'],
|
local_storage['pg_namespace'],
|
||||||
EMBEDDING_DATA_DIR_STR,
|
EMBEDDING_DATA_DIR_STR,
|
||||||
|
|||||||
@@ -7,25 +7,34 @@ from . import prompt_template
|
|||||||
from .lpmmconfig import global_config, INVALID_ENTITY
|
from .lpmmconfig import global_config, INVALID_ENTITY
|
||||||
from .llm_client import LLMClient
|
from .llm_client import LLMClient
|
||||||
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from json_repair import repair_json
|
||||||
|
def _extract_json_from_text(text: str) -> dict:
|
||||||
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
|
try:
|
||||||
|
fixed_json = repair_json(text)
|
||||||
|
if isinstance(fixed_json, str):
|
||||||
|
parsed_json = json.loads(fixed_json)
|
||||||
|
else:
|
||||||
|
parsed_json = fixed_json
|
||||||
|
|
||||||
|
if isinstance(parsed_json, list) and parsed_json:
|
||||||
|
parsed_json = parsed_json[0]
|
||||||
|
|
||||||
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
if isinstance(parsed_json, dict):
|
||||||
|
return parsed_json
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
|
||||||
|
|
||||||
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
_, request_result = llm_client.send_chat_request(
|
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
||||||
global_config["entity_extract"]["llm"]["model"], entity_extract_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
|
||||||
if "[" in request_result:
|
|
||||||
request_result = request_result[request_result.index("[") :]
|
|
||||||
|
|
||||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
|
||||||
if "]" in request_result:
|
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
|
||||||
|
|
||||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
|
||||||
|
|
||||||
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
|
# 尝试load JSON数据
|
||||||
|
json.loads(entity_extract_result)
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
for entity in entity_extract_result
|
for entity in entity_extract_result
|
||||||
@@ -38,23 +47,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
|
|||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
||||||
|
|
||||||
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
|
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
|
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
|
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
||||||
|
|
||||||
# 去除‘{’前的内容(结果中可能有多个‘{’)
|
|
||||||
if "[" in request_result:
|
|
||||||
request_result = request_result[request_result.index("[") :]
|
|
||||||
|
|
||||||
# 去除最后一个‘}’后的内容(结果中可能有多个‘}’)
|
|
||||||
if "]" in request_result:
|
|
||||||
request_result = request_result[: request_result.rindex("]") + 1]
|
|
||||||
|
|
||||||
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
|
|
||||||
|
|
||||||
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
|
# 尝试load JSON数据
|
||||||
|
json.loads(entity_extract_result)
|
||||||
for triple in entity_extract_result:
|
for triple in entity_extract_result:
|
||||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||||
raise Exception("RDF提取结果格式错误")
|
raise Exception("RDF提取结果格式错误")
|
||||||
@@ -63,7 +65,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
|
|||||||
|
|
||||||
|
|
||||||
def info_extract_from_str(
|
def info_extract_from_str(
|
||||||
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
|
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
|
||||||
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
|
||||||
try_count = 0
|
try_count = 0
|
||||||
while True:
|
while True:
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
|
|||||||
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
DATA_PATH = os.path.join(ROOT_PATH, "data")
|
||||||
|
|
||||||
def _initialize_knowledge_local_storage():
|
def _initialize_knowledge_local_storage():
|
||||||
"""
|
"""
|
||||||
@@ -42,10 +43,6 @@ def _initialize_knowledge_local_storage():
|
|||||||
# 路径配置
|
# 路径配置
|
||||||
'root_path': ROOT_PATH,
|
'root_path': ROOT_PATH,
|
||||||
'data_path': f"{ROOT_PATH}/data",
|
'data_path': f"{ROOT_PATH}/data",
|
||||||
'lpmm_raw_data_path': f"{ROOT_PATH}/data/raw_data",
|
|
||||||
'lpmm_openie_data_path': f"{ROOT_PATH}/data/openie",
|
|
||||||
'lpmm_embedding_data_dir': f"{ROOT_PATH}/data/embedding",
|
|
||||||
'lpmm_rag_data_dir': f"{ROOT_PATH}/data/rag",
|
|
||||||
|
|
||||||
# 实体和命名空间配置
|
# 实体和命名空间配置
|
||||||
'lpmm_invalid_entity': INVALID_ENTITY,
|
'lpmm_invalid_entity': INVALID_ENTITY,
|
||||||
|
|||||||
@@ -4,9 +4,8 @@ import glob
|
|||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
|
||||||
|
# from src.manager.local_store_manager import local_storage
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
|
|
||||||
|
|
||||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||||
@@ -107,7 +106,7 @@ class OpenIE:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def load() -> "OpenIE":
|
def load() -> "OpenIE":
|
||||||
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
||||||
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
|
openie_dir = os.path.join(DATA_PATH, "openie")
|
||||||
if not os.path.exists(openie_dir):
|
if not os.path.exists(openie_dir):
|
||||||
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
||||||
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||||
@@ -122,12 +121,6 @@ class OpenIE:
|
|||||||
openie_data = OpenIE._from_dict(data_list)
|
openie_data = OpenIE._from_dict(data_list)
|
||||||
return openie_data
|
return openie_data
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def save(openie_data: "OpenIE"):
|
|
||||||
"""保存OpenIE数据到文件"""
|
|
||||||
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
|
|
||||||
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
|
|
||||||
|
|
||||||
def extract_entity_dict(self):
|
def extract_entity_dict(self):
|
||||||
"""提取实体列表"""
|
"""提取实体列表"""
|
||||||
ner_output_dict = dict(
|
ner_output_dict = dict(
|
||||||
|
|||||||
@@ -5,11 +5,13 @@ from .global_logger import logger
|
|||||||
|
|
||||||
# from . import prompt_template
|
# from . import prompt_template
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
from .llm_client import LLMClient
|
# from .llm_client import LLMClient
|
||||||
from .kg_manager import KGManager
|
from .kg_manager import KGManager
|
||||||
from .lpmmconfig import global_config
|
# from .lpmmconfig import global_config
|
||||||
from .utils.dyn_topk import dyn_select_top_k
|
from .utils.dyn_topk import dyn_select_top_k
|
||||||
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
from src.chat.utils.utils import get_embedding
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||||
|
|
||||||
@@ -19,26 +21,25 @@ class QAManager:
|
|||||||
self,
|
self,
|
||||||
embed_manager: EmbeddingManager,
|
embed_manager: EmbeddingManager,
|
||||||
kg_manager: KGManager,
|
kg_manager: KGManager,
|
||||||
llm_client_embedding: LLMClient,
|
|
||||||
llm_client_filter: LLMClient,
|
|
||||||
llm_client_qa: LLMClient,
|
|
||||||
):
|
):
|
||||||
self.embed_manager = embed_manager
|
self.embed_manager = embed_manager
|
||||||
self.kg_manager = kg_manager
|
self.kg_manager = kg_manager
|
||||||
self.llm_client_list = {
|
# TODO: API-Adapter修改标记
|
||||||
"embedding": llm_client_embedding,
|
self.qa_model = LLMRequest(
|
||||||
"message_filter": llm_client_filter,
|
model=global_config.model.lpmm_qa,
|
||||||
"qa": llm_client_qa,
|
request_type="lpmm.qa"
|
||||||
}
|
)
|
||||||
|
|
||||||
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
|
question_embedding = get_embedding(question)
|
||||||
global_config["embedding"]["model"], question
|
if question_embedding is None:
|
||||||
)
|
logger.error("生成问题Embedding失败")
|
||||||
|
return None
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"Embedding用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|
||||||
@@ -46,14 +47,15 @@ class QAManager:
|
|||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
|
||||||
question_embedding,
|
question_embedding,
|
||||||
global_config["qa"]["params"]["relation_search_top_k"],
|
global_config.lpmm_knowledge.qa_relation_search_top_k,
|
||||||
)
|
)
|
||||||
if relation_search_res is not None:
|
if relation_search_res is not None:
|
||||||
# 过滤阈值
|
# 过滤阈值
|
||||||
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
|
||||||
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
|
||||||
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
|
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
|
||||||
# 未找到相关关系
|
# 未找到相关关系
|
||||||
|
logger.debug("未找到相关关系,跳过关系检索")
|
||||||
relation_search_res = []
|
relation_search_res = []
|
||||||
|
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
@@ -71,7 +73,7 @@ class QAManager:
|
|||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
|
||||||
question_embedding,
|
question_embedding,
|
||||||
global_config["qa"]["params"]["paragraph_search_top_k"],
|
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
|
||||||
)
|
)
|
||||||
part_end_time = time.perf_counter()
|
part_end_time = time.perf_counter()
|
||||||
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")
|
||||||
|
|||||||
@@ -627,3 +627,12 @@ class ModelConfig(ConfigBase):
|
|||||||
|
|
||||||
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
embedding: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
"""嵌入模型配置"""
|
"""嵌入模型配置"""
|
||||||
|
|
||||||
|
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM实体提取模型配置"""
|
||||||
|
|
||||||
|
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM RDF构建模型配置"""
|
||||||
|
|
||||||
|
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
|
||||||
|
"""LPMM问答模型配置"""
|
||||||
|
|||||||
Reference in New Issue
Block a user