knowledge系统对应修改
This commit is contained in:
@@ -19,14 +19,10 @@ class CompareNumbersTool(BaseTool):
|
|||||||
|
|
||||||
name = "compare_numbers"
|
name = "compare_numbers"
|
||||||
description = "使用工具 比较两个数的大小,返回较大的数"
|
description = "使用工具 比较两个数的大小,返回较大的数"
|
||||||
parameters = {
|
parameters = [
|
||||||
"type": "object",
|
("num1", "number", "第一个数字", True),
|
||||||
"properties": {
|
("num2", "number", "第二个数字", True),
|
||||||
"num1": {"type": "number", "description": "第一个数字"},
|
]
|
||||||
"num2": {"type": "number", "description": "第二个数字"},
|
|
||||||
},
|
|
||||||
"required": ["num1", "num2"],
|
|
||||||
}
|
|
||||||
|
|
||||||
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""执行比较两个数的大小
|
"""执行比较两个数的大小
|
||||||
|
|||||||
@@ -8,7 +8,10 @@ from . import prompt_template
|
|||||||
from .knowledge_lib import INVALID_ENTITY
|
from .knowledge_lib import INVALID_ENTITY
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
|
||||||
def _extract_json_from_text(text: str):
|
def _extract_json_from_text(text: str):
|
||||||
|
# sourcery skip: assign-if-exp, extract-method
|
||||||
"""从文本中提取JSON数据的高容错方法"""
|
"""从文本中提取JSON数据的高容错方法"""
|
||||||
if text is None:
|
if text is None:
|
||||||
logger.error("输入文本为None")
|
logger.error("输入文本为None")
|
||||||
@@ -42,7 +45,9 @@ def _extract_json_from_text(text: str):
|
|||||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
|
# sourcery skip: reintroduce-else, swap-if-else-branches, use-named-expression
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
|
|
||||||
@@ -50,15 +55,11 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
|||||||
try:
|
try:
|
||||||
# 如果当前已有事件循环在运行,使用它
|
# 如果当前已有事件循环在运行,使用它
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(entity_extract_context), loop)
|
||||||
llm_req.generate_response_async(entity_extract_context), loop
|
response, _ = future.result()
|
||||||
)
|
|
||||||
response, (reasoning_content, model_name) = future.result()
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
response, (reasoning_content, model_name) = asyncio.run(
|
response, _ = asyncio.run(llm_req.generate_response_async(entity_extract_context))
|
||||||
llm_req.generate_response_async(entity_extract_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加调试日志
|
# 添加调试日志
|
||||||
logger.debug(f"LLM返回的原始响应: {response}")
|
logger.debug(f"LLM返回的原始响应: {response}")
|
||||||
@@ -67,19 +68,17 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
|||||||
|
|
||||||
# 检查返回的是否为有效的实体列表
|
# 检查返回的是否为有效的实体列表
|
||||||
if not isinstance(entity_extract_result, list):
|
if not isinstance(entity_extract_result, list):
|
||||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
if not isinstance(entity_extract_result, dict):
|
||||||
if isinstance(entity_extract_result, dict):
|
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
|
|
||||||
# 尝试常见的键名
|
# 尝试常见的键名
|
||||||
for key in ['entities', 'result', 'data', 'items']:
|
for key in ["entities", "result", "data", "items"]:
|
||||||
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
||||||
entity_extract_result = entity_extract_result[key]
|
entity_extract_result = entity_extract_result[key]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# 如果找不到合适的列表,抛出异常
|
# 如果找不到合适的列表,抛出异常
|
||||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
raise ValueError(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||||
else:
|
|
||||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
|
||||||
|
|
||||||
# 过滤无效实体
|
# 过滤无效实体
|
||||||
entity_extract_result = [
|
entity_extract_result = [
|
||||||
entity
|
entity
|
||||||
@@ -87,8 +86,8 @@ def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
|||||||
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
|
if (entity is not None) and (entity != "") and (entity not in INVALID_ENTITY)
|
||||||
]
|
]
|
||||||
|
|
||||||
if len(entity_extract_result) == 0:
|
if not entity_extract_result:
|
||||||
raise Exception("实体提取结果为空")
|
raise ValueError("实体提取结果为空")
|
||||||
|
|
||||||
return entity_extract_result
|
return entity_extract_result
|
||||||
|
|
||||||
@@ -103,15 +102,11 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
|||||||
try:
|
try:
|
||||||
# 如果当前已有事件循环在运行,使用它
|
# 如果当前已有事件循环在运行,使用它
|
||||||
loop = asyncio.get_running_loop()
|
loop = asyncio.get_running_loop()
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
future = asyncio.run_coroutine_threadsafe(llm_req.generate_response_async(rdf_extract_context), loop)
|
||||||
llm_req.generate_response_async(rdf_extract_context), loop
|
response, _ = future.result()
|
||||||
)
|
|
||||||
response, (reasoning_content, model_name) = future.result()
|
|
||||||
except RuntimeError:
|
except RuntimeError:
|
||||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||||
response, (reasoning_content, model_name) = asyncio.run(
|
response, _ = asyncio.run(llm_req.generate_response_async(rdf_extract_context))
|
||||||
llm_req.generate_response_async(rdf_extract_context)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加调试日志
|
# 添加调试日志
|
||||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||||
@@ -120,23 +115,26 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
|||||||
|
|
||||||
# 检查返回的是否为有效的三元组列表
|
# 检查返回的是否为有效的三元组列表
|
||||||
if not isinstance(rdf_triple_result, list):
|
if not isinstance(rdf_triple_result, list):
|
||||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
if not isinstance(rdf_triple_result, dict):
|
||||||
if isinstance(rdf_triple_result, dict):
|
raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
|
|
||||||
# 尝试常见的键名
|
# 尝试常见的键名
|
||||||
for key in ['triples', 'result', 'data', 'items']:
|
for key in ["triples", "result", "data", "items"]:
|
||||||
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
||||||
rdf_triple_result = rdf_triple_result[key]
|
rdf_triple_result = rdf_triple_result[key]
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
# 如果找不到合适的列表,抛出异常
|
# 如果找不到合适的列表,抛出异常
|
||||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
raise ValueError(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||||
else:
|
|
||||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
|
||||||
|
|
||||||
# 验证三元组格式
|
# 验证三元组格式
|
||||||
for triple in rdf_triple_result:
|
for triple in rdf_triple_result:
|
||||||
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
if (
|
||||||
raise Exception("RDF提取结果格式错误")
|
not isinstance(triple, list)
|
||||||
|
or len(triple) != 3
|
||||||
|
or (triple[0] is None or triple[1] is None or triple[2] is None)
|
||||||
|
or "" in triple
|
||||||
|
):
|
||||||
|
raise ValueError("RDF提取结果格式错误")
|
||||||
|
|
||||||
return rdf_triple_result
|
return rdf_triple_result
|
||||||
|
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ class KGManager:
|
|||||||
ent_hash_list = list(ent_hash_list)
|
ent_hash_list = list(ent_hash_list)
|
||||||
|
|
||||||
synonym_hash_set = set()
|
synonym_hash_set = set()
|
||||||
synonym_result = dict()
|
synonym_result = {}
|
||||||
|
|
||||||
# rich 进度条
|
# rich 进度条
|
||||||
total = len(ent_hash_list)
|
total = len(ent_hash_list)
|
||||||
|
|||||||
@@ -5,13 +5,15 @@ from .global_logger import logger
|
|||||||
|
|
||||||
# from . import prompt_template
|
# from . import prompt_template
|
||||||
from .embedding_store import EmbeddingManager
|
from .embedding_store import EmbeddingManager
|
||||||
|
|
||||||
# from .llm_client import LLMClient
|
# from .llm_client import LLMClient
|
||||||
from .kg_manager import KGManager
|
from .kg_manager import KGManager
|
||||||
|
|
||||||
# from .lpmmconfig import global_config
|
# from .lpmmconfig import global_config
|
||||||
from .utils.dyn_topk import dyn_select_top_k
|
from .utils.dyn_topk import dyn_select_top_k
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config, model_config
|
||||||
|
|
||||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||||
|
|
||||||
@@ -21,15 +23,10 @@ class QAManager:
|
|||||||
self,
|
self,
|
||||||
embed_manager: EmbeddingManager,
|
embed_manager: EmbeddingManager,
|
||||||
kg_manager: KGManager,
|
kg_manager: KGManager,
|
||||||
|
|
||||||
):
|
):
|
||||||
self.embed_manager = embed_manager
|
self.embed_manager = embed_manager
|
||||||
self.kg_manager = kg_manager
|
self.kg_manager = kg_manager
|
||||||
# TODO: API-Adapter修改标记
|
self.qa_model = LLMRequest(model_set=model_config.model_task_config.lpmm_qa, request_type="lpmm.qa")
|
||||||
self.qa_model = LLMRequest(
|
|
||||||
model=global_config.model.lpmm_qa,
|
|
||||||
request_type="lpmm.qa"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
async def process_query(self, question: str) -> Tuple[List[Tuple[str, float, float]], Optional[Dict[str, float]]]:
|
||||||
"""处理查询"""
|
"""处理查询"""
|
||||||
|
|||||||
Reference in New Issue
Block a user