knowledge系统对应修改

This commit is contained in:
UnCLAS-Prommer
2025-07-31 13:38:56 +08:00
parent 37e52a1566
commit 52acfe5958
4 changed files with 58 additions and 67 deletions

View File

@@ -19,14 +19,10 @@ class CompareNumbersTool(BaseTool):
name = "compare_numbers" name = "compare_numbers"
description = "使用工具 比较两个数的大小,返回较大的数" description = "使用工具 比较两个数的大小,返回较大的数"
parameters = { parameters = [
"type": "object", ("num1", "number", "第一个数字", True),
"properties": { ("num2", "number", "第二个数字", True),
"num1": {"type": "number", "description": "第一个数字"}, ]
"num2": {"type": "number", "description": "第二个数字"},
},
"required": ["num1", "num2"],
}
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""执行比较两个数的大小 """执行比较两个数的大小

View File

@@ -8,7 +8,10 @@ from . import prompt_template
from .knowledge_lib import INVALID_ENTITY from .knowledge_lib import INVALID_ENTITY
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from json_repair import repair_json from json_repair import repair_json
def _extract_json_from_text(text: str): def _extract_json_from_text(text: str):
# sourcery skip: assign-if-exp, extract-method
"""从文本中提取JSON数据的高容错方法""" """从文本中提取JSON数据的高容错方法"""
if text is None: if text is None:
logger.error("输入文本为None") logger.error("输入文本为None")
@@ -42,7 +45,9 @@ def _extract_json_from_text(text: str):
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
return [] return []
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
"""对段落进行实体提取返回提取出的实体列表JSON格式""" """对段落进行实体提取返回提取出的实体列表JSON格式"""
entity_extract_context = prompt_template.build_entity_extract_context(paragraph) entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
@@ -50,15 +55,11 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
llm_req.generate_response_async(entity_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
llm_req.generate_response_async(entity_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"LLM返回的原始响应: {response}") logger.debug(f"LLM返回的原始响应: {response}")
@@ -67,19 +68,17 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
# 检查返回的是否为有效的实体列表 # 检查返回的是否为有效的实体列表
if not isinstance(entity_extract_result, list): if not isinstance(entity_extract_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(entity_extract_result, dict):
if isinstance(entity_extract_result, dict): raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 尝试常见的键名 # 尝试常见的键名
for key in ['entities', 'result', 'data', 'items']: for key in ["entities", "result", "data", "items"]:
if key in entity_extract_result and isinstance(entity_extract_result[key], list): if key in entity_extract_result and isinstance(entity_extract_result[key], list):
entity_extract_result = entity_extract_result[key] entity_extract_result = entity_extract_result[key]
break break
else: else:
# 如果找不到合适的列表,抛出异常 # 如果找不到合适的列表,抛出异常
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
else:
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
# 过滤无效实体 # 过滤无效实体
entity_extract_result = [ entity_extract_result = [
entity entity
@@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY) if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
] ]
if len(entity_extract_result) == 0: if not entity_extract_result:
raise Exception("实体提取结果为空") raise ValueError("实体提取结果为空")
return entity_extract_result return entity_extract_result
@@ -103,15 +102,11 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
try: try:
# 如果当前已有事件循环在运行,使用它 # 如果当前已有事件循环在运行,使用它
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
future = asyncio.run_coroutine_threadsafe( future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
llm_req.generate_response_async(rdf_extract_context), loop response, _ = future.result()
)
response, (reasoning_content, model_name) = future.result()
except RuntimeError: except RuntimeError:
# 如果没有运行中的事件循环,直接使用 asyncio.run # 如果没有运行中的事件循环,直接使用 asyncio.run
response, (reasoning_content, model_name) = asyncio.run( response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
llm_req.generate_response_async(rdf_extract_context)
)
# 添加调试日志 # 添加调试日志
logger.debug(f"RDF LLM返回的原始响应: {response}") logger.debug(f"RDF LLM返回的原始响应: {response}")
@@ -120,23 +115,26 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
# 检查返回的是否为有效的三元组列表 # 检查返回的是否为有效的三元组列表
if not isinstance(rdf_triple_result, list): if not isinstance(rdf_triple_result, list):
# 如果不是列表,可能是字典格式,尝试从中提取列表 if not isinstance(rdf_triple_result, dict):
if isinstance(rdf_triple_result, dict): raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 尝试常见的键名 # 尝试常见的键名
for key in ['triples', 'result', 'data', 'items']: for key in ["triples", "result", "data", "items"]:
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
rdf_triple_result = rdf_triple_result[key] rdf_triple_result = rdf_triple_result[key]
break break
else: else:
# 如果找不到合适的列表,抛出异常 # 如果找不到合适的列表,抛出异常
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}") raise ValueError(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
else:
raise Exception(f"RDF三元组提取结果格式错误期望列表但得到: {type(rdf_triple_result)}")
# 验证三元组格式 # 验证三元组格式
for triple in rdf_triple_result: for triple in rdf_triple_result:
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: if (
raise Exception("RDF提取结果格式错误") not isinstance(triple, list)
or len(triple) != 3
or (triple[0] is None or triple[1] is None or triple[2] is None)
or "" in triple
):
raise ValueError("RDF提取结果格式错误")
return rdf_triple_result return rdf_triple_result

View File

@@ -162,7 +162,7 @@ class KGManager:
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
synonym_result = dict() synonym_result = {}
# rich 进度条 # rich 进度条
total = len(ent_hash_list) total = len(ent_hash_list)

View File

@@ -5,13 +5,15 @@ from .global_logger import logger
# from . import prompt_template # from . import prompt_template
from .embedding_store import EmbeddingManager from .embedding_store import EmbeddingManager
# from .llm_client import LLMClient # from .llm_client import LLMClient
from .kg_manager import KGManager from .kg_manager import KGManager
# from .lpmmconfig import global_config # from .lpmmconfig import global_config
from .utils.dyn_topk import dyn_select_top_k from .utils.dyn_topk import dyn_select_top_k
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.chat.utils.utils import get_embedding from src.chat.utils.utils import get_embedding
from src.config.config import global_config from src.config.config import global_config, model_config
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
@@ -21,15 +23,10 @@ class QAManager:
self, self,
embed_manager: EmbeddingManager, embed_manager: EmbeddingManager,
kg_manager: KGManager, kg_manager: KGManager,
): ):
self.embed_manager = embed_manager self.embed_manager = embed_manager
self.kg_manager = kg_manager self.kg_manager = kg_manager
# TODO: API-Adapter修改标记 self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
self.qa_model = LLMRequest(
model=global_config.model.lpmm_qa,
request_type="lpmm.qa"
)
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]: async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
"""处理查询""" """处理查询"""