some typing
@@ -106,10 +106,10 @@ class EmbeddingStore:
             asyncio.get_running_loop()
             # 如果在事件循环中,使用线程池执行
             import concurrent.futures

             def run_in_thread():
                 return asyncio.run(get_embedding(s))

             with concurrent.futures.ThreadPoolExecutor() as executor:
                 future = executor.submit(run_in_thread)
                 result = future.result()
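The hunk above is the sync-over-async bridge in EmbeddingStore: if asyncio.get_running_loop() shows an event loop is already running, the embedding coroutine cannot be driven with asyncio.run() on the same thread, so it is handed to a worker thread that starts its own loop. A minimal, self-contained sketch of that pattern (compute_embedding and get_embedding_sync are illustrative names, not taken from this commit):

import asyncio
import concurrent.futures


async def compute_embedding(text: str) -> list[float]:
    # Stand-in for the real embedding request.
    await asyncio.sleep(0)
    return [0.0] * 8


def get_embedding_sync(text: str) -> list[float]:
    """Return an embedding whether or not an event loop is already running."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No running loop: asyncio.run() is safe to call directly.
        return asyncio.run(compute_embedding(text))

    # A loop is running: asyncio.run() would raise RuntimeError here,
    # so run it inside a worker thread that owns a separate event loop.
    def run_in_thread() -> list[float]:
        return asyncio.run(compute_embedding(text))

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(run_in_thread)
        return future.result()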
@@ -294,10 +294,10 @@ class EmbeddingStore:
         """
         if self.faiss_index is None:
             logger.debug("FaissIndex尚未构建,返回None")
-            return None
+            return []
         if self.idx2hash is None:
             logger.warning("idx2hash尚未构建,返回None")
-            return None
+            return []

         # L2归一化
         faiss.normalize_L2(np.array([query], dtype=np.float32))
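With the early returns switched from None to an empty list, the "index not ready" case has the same shape as a normal result, so callers can iterate it without a None guard. A small illustration of the caller-side difference (the function names below are illustrative, not this repository's API):

from typing import Optional


def search_old(index_ready: bool) -> Optional[list[tuple[str, float]]]:
    # Old behaviour: "not ready" is signalled with None, which callers must special-case.
    if not index_ready:
        return None
    return [("hash-a", 0.92), ("hash-b", 0.71)]


def search_new(index_ready: bool) -> list[tuple[str, float]]:
    # New behaviour: an empty list means "nothing to return yet".
    if not index_ready:
        return []
    return [("hash-a", 0.92), ("hash-b", 0.71)]


# The new version can be iterated unconditionally; an empty result just skips the loop.
for key, score in search_new(index_ready=False):
    print(key, score)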
@@ -318,15 +318,15 @@ class EmbeddingStore:
 class EmbeddingManager:
     def __init__(self):
         self.paragraphs_embedding_store = EmbeddingStore(
-            local_storage['pg_namespace'],
+            local_storage["pg_namespace"], # type: ignore
             EMBEDDING_DATA_DIR_STR,
         )
         self.entities_embedding_store = EmbeddingStore(
-            local_storage['pg_namespace'],
+            local_storage["pg_namespace"], # type: ignore
             EMBEDDING_DATA_DIR_STR,
         )
         self.relation_embedding_store = EmbeddingStore(
-            local_storage['pg_namespace'],
+            local_storage["pg_namespace"], # type: ignore
             EMBEDDING_DATA_DIR_STR,
         )
         self.stored_pg_hashes = set()
@@ -30,20 +30,20 @@ def _get_kg_dir():
     """
     安全地获取KG数据目录路径
     """
-    root_path = local_storage['root_path']
+    root_path: str = local_storage["root_path"]
     if root_path is None:
         # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用
         current_dir = os.path.dirname(os.path.abspath(__file__))
         root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))
         logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}")

     # 获取RAG数据目录
-    rag_data_dir = global_config["persistence"]["rag_data_dir"]
+    rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
     if rag_data_dir is None:
         kg_dir = os.path.join(root_path, "data/rag")
     else:
         kg_dir = os.path.join(root_path, rag_data_dir)

     return str(kg_dir).replace("\\", "/")

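The hunk above only adds type annotations, but it is the whole fallback chain for locating the KG data directory: use local_storage's root_path if set, otherwise derive a root from the module location; then use the configured rag_data_dir, otherwise data/rag. A self-contained sketch of that resolution order, with plain dicts standing in for local_storage and global_config (the real objects are not plain dicts; this only exercises the fallback logic):

import os

local_storage = {"root_path": None}                      # stand-in store
global_config = {"persistence": {"rag_data_dir": None}}  # stand-in config


def get_kg_dir() -> str:
    root_path = local_storage["root_path"]
    if root_path is None:
        # Fall back to a root derived from this file's location.
        current_dir = os.path.dirname(os.path.abspath(__file__))
        root_path = os.path.abspath(os.path.join(current_dir, "..", "..", ".."))

    rag_data_dir = global_config["persistence"]["rag_data_dir"]
    if rag_data_dir is None:
        kg_dir = os.path.join(root_path, "data/rag")
    else:
        kg_dir = os.path.join(root_path, rag_data_dir)

    # Normalise separators so the stored path is identical on Windows and POSIX.
    return str(kg_dir).replace("\\", "/")


print(get_kg_dir())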
@@ -65,9 +65,9 @@ class KGManager:

         # 持久化相关 - 使用延迟初始化的路径
         self.dir_path = get_kg_dir_str()
-        self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml"
-        self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet"
-        self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json"
+        self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
+        self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
+        self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"

     def save_to_file(self):
         """将KG数据保存到文件"""
@@ -91,11 +91,11 @@ class KGManager:
         """从文件加载KG数据"""
         # 确保文件存在
         if not os.path.exists(self.pg_hash_file_path):
-            raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在")
+            raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在")
         if not os.path.exists(self.ent_cnt_data_path):
-            raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
+            raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
         if not os.path.exists(self.graph_data_path):
-            raise Exception(f"KG图文件{self.graph_data_path}不存在")
+            raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")

         # 加载段落hash
         with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
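Raising FileNotFoundError instead of a bare Exception lets callers tell "no persisted KG yet" apart from real failures without parsing message strings. A hedged sketch of what a caller can now do (load_from_file and rebuild_from_scratch are assumed names; the loader method is only identified by its docstring in this diff):

def load_or_rebuild(kg_manager) -> None:
    try:
        kg_manager.load_from_file()
    except FileNotFoundError as e:
        # Expected on first run: the hash/count/graph files have not been written yet.
        print(f"no persisted KG found, rebuilding: {e}")
        kg_manager.rebuild_from_scratch()  # hypothetical recovery path
    # Any other exception still propagates, since it signals a genuine error.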
@@ -122,8 +122,8 @@ class KGManager:
                     # 避免自连接
                     continue
                 # 一个triple就是一条边(同时构建双向联系)
-                hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
-                hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2])
+                hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
+                hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
                 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
                 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
                 entity_set.add(hash_key1)
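This hunk and the next build node_to_node, a dict keyed by (source_hash, target_hash) whose value is an accumulated edge weight; both orderings are written so the resulting graph behaves as undirected, and self-loops are skipped. A tiny self-contained illustration of that accumulation pattern (the namespace constant and hash helper are simplified stand-ins for local_storage["ent_namespace"] and the project's get_sha256):

import hashlib


def get_sha256(text: str) -> str:
    return hashlib.sha256(text.encode("utf-8")).hexdigest()


ENT_NS = "ent"  # stand-in namespace prefix

triples = [("alice", "likes", "apples"), ("alice", "dislikes", "bananas"), ("alice", "likes", "apples")]

node_to_node: dict[tuple[str, str], float] = {}
for subj, _rel, obj in triples:
    if subj == obj:
        continue  # skip self-loops (避免自连接)
    k1 = ENT_NS + "-" + get_sha256(subj)
    k2 = ENT_NS + "-" + get_sha256(obj)
    # Accumulate the edge weight in both directions so the edge is undirected.
    node_to_node[(k1, k2)] = node_to_node.get((k1, k2), 0) + 1.0
    node_to_node[(k2, k1)] = node_to_node.get((k2, k1), 0) + 1.0

# Four directed entries; the alice/apples pair was seen twice, so its weight is 2.0.
print(len(node_to_node), node_to_node[(ENT_NS + "-" + get_sha256("alice"), ENT_NS + "-" + get_sha256("apples"))])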
@@ -141,8 +141,8 @@ class KGManager:
         """构建实体节点与文段节点之间的关系"""
         for idx in triple_list_data:
             for triple in triple_list_data[idx]:
-                ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0])
-                pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx)
+                ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
+                pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
                 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0

     @staticmethod
@@ -157,8 +157,8 @@ class KGManager:
         ent_hash_list = set()
         for triple_list in triple_list_data.values():
             for triple in triple_list:
-                ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0]))
-                ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2]))
+                ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
+                ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
         ent_hash_list = list(ent_hash_list)

         synonym_hash_set = set()
@@ -263,7 +263,7 @@ class KGManager:
         for src_tgt in node_to_node.keys():
             for node_hash in src_tgt:
                 if node_hash not in existed_nodes:
-                    if node_hash.startswith(local_storage['ent_namespace']):
+                    if node_hash.startswith(local_storage["ent_namespace"]):
                         # 新增实体节点
                         node = embedding_manager.entities_embedding_store.store.get(node_hash)
                         if node is None:
@@ -275,7 +275,7 @@ class KGManager:
                         node_item["type"] = "ent"
                         node_item["create_time"] = now_time
                         self.graph.update_node(node_item)
-                    elif node_hash.startswith(local_storage['pg_namespace']):
+                    elif node_hash.startswith(local_storage["pg_namespace"]):
                         # 新增文段节点
                         node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
                         if node is None:
@@ -359,7 +359,7 @@ class KGManager:
             # 关系三元组
             triple = relation[2:-2].split("', '")
             for ent in [(triple[0]), (triple[2])]:
-                ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent)
+                ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
                 if ent_hash in existed_nodes:  # 该实体需在KG中存在
                     if ent_hash not in ent_sim_scores:  # 尚未记录的实体
                         ent_sim_scores[ent_hash] = []
@@ -437,7 +437,9 @@ class KGManager:
         # 获取最终结果
         # 从搜索结果中提取文段节点的结果
         passage_node_res = [
-            (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace'])
+            (node_key, score)
+            for node_key, score in ppr_res.items()
+            if node_key.startswith(local_storage["pg_namespace"])
         ]
         del ppr_res

@@ -1,5 +1,3 @@
-from .llm_client import LLMMessage
-
 entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。

 输出格式示例:
@@ -63,10 +61,10 @@ qa_system_prompt = """
 """


-def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
-    knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
-    messages = [
-        LLMMessage("system", qa_system_prompt).to_dict(),
-        LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
-    ]
-    return messages
+# def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
+#     knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
+#     messages = [
+#         LLMMessage("system", qa_system_prompt).to_dict(),
+#         LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(),
+#     ]
+#     return messages
@@ -484,25 +484,25 @@ class MessageSending(MessageProcessBase):
         if self.message_segment:
             self.processed_plain_text = await self._process_message_segments(self.message_segment)

-    @classmethod
-    def from_thinking(
-        cls,
-        thinking: MessageThinking,
-        message_segment: Seg,
-        is_head: bool = False,
-        is_emoji: bool = False,
-    ) -> "MessageSending":
-        """从思考状态消息创建发送状态消息"""
-        return cls(
-            message_id=thinking.message_info.message_id, # type: ignore
-            chat_stream=thinking.chat_stream,
-            message_segment=message_segment,
-            bot_user_info=thinking.message_info.user_info, # type: ignore
-            reply=thinking.reply,
-            is_head=is_head,
-            is_emoji=is_emoji,
-            sender_info=None,
-        )
+    # @classmethod
+    # def from_thinking(
+    #     cls,
+    #     thinking: MessageThinking,
+    #     message_segment: Seg,
+    #     is_head: bool = False,
+    #     is_emoji: bool = False,
+    # ) -> "MessageSending":
+    #     """从思考状态消息创建发送状态消息"""
+    #     return cls(
+    #         message_id=thinking.message_info.message_id, # type: ignore
+    #         chat_stream=thinking.chat_stream,
+    #         message_segment=message_segment,
+    #         bot_user_info=thinking.message_info.user_info, # type: ignore
+    #         reply=thinking.reply,
+    #         is_head=is_head,
+    #         is_emoji=is_emoji,
+    #         sender_info=None,
+    #     )

     def to_dict(self):
         ret = super().to_dict()
@@ -262,4 +262,4 @@ class ActionManager:
         """
         from src.plugin_system.core.component_registry import component_registry

-        return component_registry.get_component_class(action_name) # type: ignore
+        return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore