feat: 重构信息提取模块,移除LLMClient依赖,改为使用LLMRequest,优化数据加载和处理逻辑

This commit is contained in:
墨梓柒
2025-07-15 16:54:25 +08:00
parent 2b76dc2e21
commit f15e074cca
8 changed files with 119 additions and 177 deletions

View File

@@ -10,7 +10,7 @@ import pandas as pd
# import tqdm
import faiss
from .llm_client import LLMClient
# from .llm_client import LLMClient
from .lpmmconfig import global_config
from .utils.hash import get_sha256
from .global_logger import logger
@@ -295,7 +295,7 @@ class EmbeddingStore:
class EmbeddingManager:
def __init__(self, llm_client: LLMClient):
def __init__(self):
self.paragraphs_embedding_store = EmbeddingStore(
local_storage['pg_namespace'],
EMBEDDING_DATA_DIR_STR,

View File

@@ -7,25 +7,34 @@ from . import prompt_template
from .lpmmconfig import global_config, INVALID_ENTITY
from .llm_client import LLMClient
from src.chat.knowledge.utils.json_fix import new_fix_broken_generated_json
from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json
def _extract_json_from_text(text: str) -> dict:
"""从文本中提取JSON数据的高容错方法"""
try:
fixed_json = repair_json(text)
if isinstance(fixed_json, str):
parsed_json = json.loads(fixed_json)
else:
parsed_json = fixed_json
if isinstance(parsed_json, list) and parsed_json:
parsed_json = parsed_json[0]
def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
if isinstance(parsed_json, dict):
return parsed_json
except Exception as e:
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
_, request_result = llm_client.send_chat_request(
global_config["entity_extract"]["llm"]["model"], entity_extract_context
)
# 去除‘{’前的内容(结果中可能有多个‘{
if "[" in request_result:
request_result = request_result[request_result.index("[") :]
# 去除最后一个‘}’后的内容(结果中可能有多个‘}
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
entity_extract_result = [
entity
for entity in entity_extract_result
@@ -38,23 +47,16 @@ def _entity_extract(llm_client: LLMClient, paragraph: str) -> List[str]:
return entity_extract_result
def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -> List[List[str]]:
def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> List[List[str]]:
"""对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_rdf_triple_extract_context(
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
paragraph, entities=json.dumps(entities, ensure_ascii=False)
)
_, request_result = llm_client.send_chat_request(global_config["rdf_build"]["llm"]["model"], entity_extract_context)
# 去除‘{’前的内容(结果中可能有多个‘{
if "[" in request_result:
request_result = request_result[request_result.index("[") :]
# 去除最后一个‘}’后的内容(结果中可能有多个‘}
if "]" in request_result:
request_result = request_result[: request_result.rindex("]") + 1]
entity_extract_result = json.loads(new_fix_broken_generated_json(request_result))
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
entity_extract_result = _extract_json_from_text(response)
# 尝试load JSON数据
json.loads(entity_extract_result)
for triple in entity_extract_result:
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
raise Exception("RDF提取结果格式错误")
@@ -63,7 +65,7 @@ def _rdf_triple_extract(llm_client: LLMClient, paragraph: str, entities: list) -
def info_extract_from_str(
llm_client_for_ner: LLMClient, llm_client_for_rdf: LLMClient, paragraph: str
llm_client_for_ner: LLMRequest, llm_client_for_rdf: LLMRequest, paragraph: str
) -> Union[tuple[None, None], tuple[list[str], list[list[str]]]]:
try_count = 0
while True:

View File

@@ -31,6 +31,7 @@ RAG_PG_HASH_NAMESPACE = "rag-pg-hash"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
DATA_PATH = os.path.join(ROOT_PATH, "data")
def _initialize_knowledge_local_storage():
"""
@@ -42,10 +43,6 @@ def _initialize_knowledge_local_storage():
# 路径配置
'root_path': ROOT_PATH,
'data_path': f"{ROOT_PATH}/data",
'lpmm_raw_data_path': f"{ROOT_PATH}/data/raw_data",
'lpmm_openie_data_path': f"{ROOT_PATH}/data/openie",
'lpmm_embedding_data_dir': f"{ROOT_PATH}/data/embedding",
'lpmm_rag_data_dir': f"{ROOT_PATH}/data/rag",
# 实体和命名空间配置
'lpmm_invalid_entity': INVALID_ENTITY,

View File

@@ -4,9 +4,8 @@ import glob
from typing import Any, Dict, List
from .lpmmconfig import INVALID_ENTITY, global_config
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
from .knowledge_lib import INVALID_ENTITY, ROOT_PATH, DATA_PATH
# from src.manager.local_store_manager import local_storage
def _filter_invalid_entities(entities: List[str]) -> List[str]:
@@ -107,7 +106,7 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
openie_dir = os.path.join(DATA_PATH, "openie")
if not os.path.exists(openie_dir):
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
@@ -122,12 +121,6 @@ class OpenIE:
openie_data = OpenIE._from_dict(data_list)
return openie_data
@staticmethod
def save(openie_data: "OpenIE"):
"""保存OpenIE数据到文件"""
with open(global_config["persistence"]["openie_data_path"], "w", encoding="utf-8") as f:
f.write(json.dumps(openie_data._to_dict(), ensure_ascii=False, indent=4))
def extract_entity_dict(self):
"""提取实体列表"""
ner_output_dict = dict(

View File

@@ -5,11 +5,13 @@ from .global_logger import logger
# from . import prompt_template
from .embedding_store import EmbeddingManager
from .llm_client import LLMClient
# from .llm_client import LLMClient
from .kg_manager import KGManager
from .lpmmconfig import global_config
# from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding
from src.config.config import global_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -19,26 +21,25 @@ class QAManager:
self,
embed_manager: EmbeddingManager,
kg_manager: KGManager,
llm_client_embedding: LLMClient,
llm_client_filter: LLMClient,
llm_client_qa: LLMClient,
):
self.embed_manager = embed_manager
self.kg_manager = kg_manager
self.llm_client_list = {
"embedding": llm_client_embedding,
"message_filter": llm_client_filter,
"qa": llm_client_qa,
}
# TODO: API-Adapter修改标记
self.qa_model = LLMRequest(
model=global_config.model.lpmm_qa,
request_type="lpmm.qa"
)
def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
"""处理查询"""
# 生成问题的Embedding
part_start_time = time.perf_counter()
question_embedding = self.llm_client_list["embedding"].send_embedding_request(
global_config["embedding"]["model"], question
)
question_embedding = get_embedding(question)
if question_embedding is None:
logger.error("生成问题Embedding失败")
return None
part_end_time = time.perf_counter()
logger.debug(f"Embedding用时{part_end_time - part_start_time:.5f}s")
@@ -46,14 +47,15 @@ class QAManager:
part_start_time = time.perf_counter()
relation_search_res = self.embed_manager.relation_embedding_store.search_top_k(
question_embedding,
global_config["qa"]["params"]["relation_search_top_k"],
global_config.lpmm_knowledge.qa_relation_search_top_k,
)
if relation_search_res is not None:
# 过滤阈值
# 考虑动态阈值:当存在显著数值差异的结果时,保留显著结果;否则,保留所有结果
relation_search_res = dyn_select_top_k(relation_search_res, 0.5, 1.0)
if relation_search_res[0][1] < global_config["qa"]["params"]["relation_threshold"]:
if relation_search_res[0][1] < global_config.lpmm_knowledge.qa_relation_threshold:
# 未找到相关关系
logger.debug("未找到相关关系,跳过关系检索")
relation_search_res = []
part_end_time = time.perf_counter()
@@ -71,7 +73,7 @@ class QAManager:
part_start_time = time.perf_counter()
paragraph_search_res = self.embed_manager.paragraphs_embedding_store.search_top_k(
question_embedding,
global_config["qa"]["params"]["paragraph_search_top_k"],
global_config.lpmm_knowledge.qa_paragraph_search_top_k,
)
part_end_time = time.perf_counter()
logger.debug(f"文段检索用时:{part_end_time - part_start_time:.5f}s")

View File

@@ -627,3 +627,12 @@ class ModelConfig(ConfigBase):
embedding: dict[str, Any] = field(default_factory=lambda: {})
"""嵌入模型配置"""
lpmm_entity_extract: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM实体提取模型配置"""
lpmm_rdf_build: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM RDF构建模型配置"""
lpmm_qa: dict[str, Any] = field(default_factory=lambda: {})
"""LPMM问答模型配置"""