Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
||||
import json
|
||||
import os
|
||||
import math
|
||||
import asyncio
|
||||
from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
@@ -99,7 +100,30 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return get_embedding(s)
|
||||
"""获取字符串的嵌入向量,处理异步调用"""
|
||||
try:
|
||||
# 尝试获取当前事件循环
|
||||
asyncio.get_running_loop()
|
||||
# 如果在事件循环中,使用线程池执行
|
||||
import concurrent.futures
|
||||
|
||||
def run_in_thread():
|
||||
return asyncio.run(get_embedding(s))
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
future = executor.submit(run_in_thread)
|
||||
result = future.result()
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
except RuntimeError:
|
||||
# 没有运行的事件循环,直接运行
|
||||
result = asyncio.run(get_embedding(s))
|
||||
if result is None:
|
||||
logger.error(f"获取嵌入失败: {s}")
|
||||
return []
|
||||
return result
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import List, Union
|
||||
@@ -7,8 +8,12 @@ from . import prompt_template
|
||||
from .knowledge_lib import INVALID_ENTITY
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from json_repair import repair_json
|
||||
def _extract_json_from_text(text: str) -> dict:
|
||||
def _extract_json_from_text(text: str):
|
||||
"""从文本中提取JSON数据的高容错方法"""
|
||||
if text is None:
|
||||
logger.error("输入文本为None")
|
||||
return []
|
||||
|
||||
try:
|
||||
fixed_json = repair_json(text)
|
||||
if isinstance(fixed_json, str):
|
||||
@@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict:
|
||||
else:
|
||||
parsed_json = fixed_json
|
||||
|
||||
if isinstance(parsed_json, list) and parsed_json:
|
||||
parsed_json = parsed_json[0]
|
||||
|
||||
if isinstance(parsed_json, dict):
|
||||
# 如果是列表,直接返回
|
||||
if isinstance(parsed_json, list):
|
||||
return parsed_json
|
||||
|
||||
# 如果是字典且只有一个项目,可能包装了列表
|
||||
if isinstance(parsed_json, dict):
|
||||
# 如果字典只有一个键,并且值是列表,返回那个列表
|
||||
if len(parsed_json) == 1:
|
||||
value = list(parsed_json.values())[0]
|
||||
if isinstance(value, list):
|
||||
return value
|
||||
return parsed_json
|
||||
|
||||
# 其他情况,尝试转换为列表
|
||||
logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
|
||||
return []
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
|
||||
logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
|
||||
return []
|
||||
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
llm_req.generate_response_async(entity_extract_context), loop
|
||||
)
|
||||
response, (reasoning_content, model_name) = future.result()
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, (reasoning_content, model_name) = asyncio.run(
|
||||
llm_req.generate_response_async(entity_extract_context)
|
||||
)
|
||||
|
||||
# 添加调试日志
|
||||
logger.debug(f"LLM返回的原始响应: {response}")
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
json.loads(entity_extract_result)
|
||||
|
||||
# 检查返回的是否为有效的实体列表
|
||||
if not isinstance(entity_extract_result, list):
|
||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||
if isinstance(entity_extract_result, dict):
|
||||
# 尝试常见的键名
|
||||
for key in ['entities', 'result', 'data', 'items']:
|
||||
if key in entity_extract_result and isinstance(entity_extract_result[key], list):
|
||||
entity_extract_result = entity_extract_result[key]
|
||||
break
|
||||
else:
|
||||
# 如果找不到合适的列表,抛出异常
|
||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||
else:
|
||||
raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
|
||||
|
||||
# 过滤无效实体
|
||||
entity_extract_result = [
|
||||
entity
|
||||
for entity in entity_extract_result
|
||||
@@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
||||
|
||||
# 使用 asyncio.run 来运行异步方法
|
||||
try:
|
||||
# 如果当前已有事件循环在运行,使用它
|
||||
loop = asyncio.get_running_loop()
|
||||
future = asyncio.run_coroutine_threadsafe(
|
||||
llm_req.generate_response_async(rdf_extract_context), loop
|
||||
)
|
||||
response, (reasoning_content, model_name) = future.result()
|
||||
except RuntimeError:
|
||||
# 如果没有运行中的事件循环,直接使用 asyncio.run
|
||||
response, (reasoning_content, model_name) = asyncio.run(
|
||||
llm_req.generate_response_async(rdf_extract_context)
|
||||
)
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
json.loads(entity_extract_result)
|
||||
for triple in entity_extract_result:
|
||||
if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||
# 添加调试日志
|
||||
logger.debug(f"RDF LLM返回的原始响应: {response}")
|
||||
|
||||
rdf_triple_result = _extract_json_from_text(response)
|
||||
|
||||
# 检查返回的是否为有效的三元组列表
|
||||
if not isinstance(rdf_triple_result, list):
|
||||
# 如果不是列表,可能是字典格式,尝试从中提取列表
|
||||
if isinstance(rdf_triple_result, dict):
|
||||
# 尝试常见的键名
|
||||
for key in ['triples', 'result', 'data', 'items']:
|
||||
if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
|
||||
rdf_triple_result = rdf_triple_result[key]
|
||||
break
|
||||
else:
|
||||
# 如果找不到合适的列表,抛出异常
|
||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||
else:
|
||||
raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
|
||||
|
||||
# 验证三元组格式
|
||||
for triple in rdf_triple_result:
|
||||
if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
|
||||
raise Exception("RDF提取结果格式错误")
|
||||
|
||||
return entity_extract_result
|
||||
return rdf_triple_result
|
||||
|
||||
|
||||
def info_extract_from_str(
|
||||
|
||||
@@ -184,10 +184,10 @@ class KGManager:
|
||||
progress.update(task, advance=1)
|
||||
continue
|
||||
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
||||
assert isinstance(ent, EmbeddingStoreItem)
|
||||
if ent is None:
|
||||
progress.update(task, advance=1)
|
||||
continue
|
||||
assert isinstance(ent, EmbeddingStoreItem)
|
||||
# 查询相似实体
|
||||
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
||||
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
||||
@@ -265,7 +265,10 @@ class KGManager:
|
||||
if node_hash not in existed_nodes:
|
||||
if node_hash.startswith(local_storage['ent_namespace']):
|
||||
# 新增实体节点
|
||||
node = embedding_manager.entities_embedding_store.store[node_hash]
|
||||
node = embedding_manager.entities_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
logger.warning(f"实体节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||
continue
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
node_item = self.graph[node_hash]
|
||||
node_item["content"] = node.str
|
||||
@@ -274,7 +277,10 @@ class KGManager:
|
||||
self.graph.update_node(node_item)
|
||||
elif node_hash.startswith(local_storage['pg_namespace']):
|
||||
# 新增文段节点
|
||||
node = embedding_manager.paragraphs_embedding_store.store[node_hash]
|
||||
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
|
||||
if node is None:
|
||||
logger.warning(f"段落节点 {node_hash} 在嵌入库中不存在,跳过")
|
||||
continue
|
||||
assert isinstance(node, EmbeddingStoreItem)
|
||||
content = node.str.replace("\n", " ")
|
||||
node_item = self.graph[node_hash]
|
||||
|
||||
@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
|
||||
"""
|
||||
|
||||
|
||||
def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", entity_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
|
||||
]
|
||||
return messages
|
||||
def build_entity_extract_context(paragraph: str) -> str:
|
||||
"""构建实体提取的完整提示文本"""
|
||||
return f"""{entity_extract_system_prompt}
|
||||
|
||||
段落:
|
||||
```
|
||||
{paragraph}
|
||||
```"""
|
||||
|
||||
|
||||
rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
|
||||
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
|
||||
"""
|
||||
|
||||
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
|
||||
messages = [
|
||||
LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
|
||||
LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
|
||||
]
|
||||
return messages
|
||||
def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
|
||||
"""构建RDF三元组提取的完整提示文本"""
|
||||
return f"""{rdf_triple_extract_system_prompt}
|
||||
|
||||
段落:
|
||||
```
|
||||
{paragraph}
|
||||
```
|
||||
|
||||
实体列表:
|
||||
```
|
||||
{entities}
|
||||
```"""
|
||||
|
||||
|
||||
qa_system_prompt = """
|
||||
|
||||
Reference in New Issue
Block a user