some typing

This commit is contained in:
UnCLAS-Prommer
2025-07-19 19:14:52 +08:00
parent 57536e60fa
commit 32cb4dc726
5 changed files with 56 additions and 56 deletions

View File

@@ -294,10 +294,10 @@ class EmbeddingStore:
""" """
if self.faiss_index is None: if self.faiss_index is None:
logger.debug("FaissIndex尚未构建,返回None") logger.debug("FaissIndex尚未构建,返回None")
return None return []
if self.idx2hash is None: if self.idx2hash is None:
logger.warning("idx2hash尚未构建,返回None") logger.warning("idx2hash尚未构建,返回None")
return None return []
# L2归一化 # L2归一化
faiss.normalize_L2(np.array([query], dtype=np.float32)) faiss.normalize_L2(np.array([query], dtype=np.float32))
@@ -318,15 +318,15 @@ class EmbeddingStore:
class EmbeddingManager: class EmbeddingManager:
def __init__(self): def __init__(self):
self.paragraphs_embedding_store = EmbeddingStore( self.paragraphs_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.entities_embedding_store = EmbeddingStore( self.entities_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.relation_embedding_store = EmbeddingStore( self.relation_embedding_store = EmbeddingStore(
local_storage['pg_namespace'], local_storage["pg_namespace"], # type: ignore
EMBEDDING_DATA_DIR_STR, EMBEDDING_DATA_DIR_STR,
) )
self.stored_pg_hashes = set() self.stored_pg_hashes = set()

View File

@@ -30,7 +30,7 @@ def _get_kg_dir():
""" """
安全地获取KG数据目录路径 安全地获取KG数据目录路径
""" """
root_path = local_storage['root_path'] root_path: str = local_storage["root_path"]
if root_path is None: if root_path is None:
# 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用 # 如果 local_storage 中没有 root_path使用当前文件的相对路径作为备用
current_dir = os.path.dirname(os.path.abspath(__file__)) current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -38,7 +38,7 @@ def _get_kg_dir():
logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}") logger.warning(f"local_storage 中未找到 root_path使用备用路径: {root_path}")
# 获取RAG数据目录 # 获取RAG数据目录
rag_data_dir = global_config["persistence"]["rag_data_dir"] rag_data_dir: str = global_config["persistence"]["rag_data_dir"]
if rag_data_dir is None: if rag_data_dir is None:
kg_dir = os.path.join(root_path, "data/rag") kg_dir = os.path.join(root_path, "data/rag")
else: else:
@@ -65,9 +65,9 @@ class KGManager:
# 持久化相关 - 使用延迟初始化的路径 # 持久化相关 - 使用延迟初始化的路径
self.dir_path = get_kg_dir_str() self.dir_path = get_kg_dir_str()
self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml"
self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet"
self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json"
def save_to_file(self): def save_to_file(self):
"""将KG数据保存到文件""" """将KG数据保存到文件"""
@@ -91,11 +91,11 @@ class KGManager:
"""从文件加载KG数据""" """从文件加载KG数据"""
# 确保文件存在 # 确保文件存在
if not os.path.exists(self.pg_hash_file_path): if not os.path.exists(self.pg_hash_file_path):
raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在") raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在")
if not os.path.exists(self.ent_cnt_data_path): if not os.path.exists(self.ent_cnt_data_path):
raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在") raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在")
if not os.path.exists(self.graph_data_path): if not os.path.exists(self.graph_data_path):
raise Exception(f"KG图文件{self.graph_data_path}不存在") raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在")
# 加载段落hash # 加载段落hash
with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: with open(self.pg_hash_file_path, "r", encoding="utf-8") as f:
@@ -122,8 +122,8 @@ class KGManager:
# 避免自连接 # 避免自连接
continue continue
# 一个triple就是一条边同时构建双向联系 # 一个triple就是一条边同时构建双向联系
hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2])
node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0
node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0
entity_set.add(hash_key1) entity_set.add(hash_key1)
@@ -141,8 +141,8 @@ class KGManager:
"""构建实体节点与文段节点之间的关系""" """构建实体节点与文段节点之间的关系"""
for idx in triple_list_data: for idx in triple_list_data:
for triple in triple_list_data[idx]: for triple in triple_list_data[idx]:
ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0])
pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx)
node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0
@staticmethod @staticmethod
@@ -157,8 +157,8 @@ class KGManager:
ent_hash_list = set() ent_hash_list = set()
for triple_list in triple_list_data.values(): for triple_list in triple_list_data.values():
for triple in triple_list: for triple in triple_list:
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0]))
ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2]))
ent_hash_list = list(ent_hash_list) ent_hash_list = list(ent_hash_list)
synonym_hash_set = set() synonym_hash_set = set()
@@ -263,7 +263,7 @@ class KGManager:
for src_tgt in node_to_node.keys(): for src_tgt in node_to_node.keys():
for node_hash in src_tgt: for node_hash in src_tgt:
if node_hash not in existed_nodes: if node_hash not in existed_nodes:
if node_hash.startswith(local_storage['ent_namespace']): if node_hash.startswith(local_storage["ent_namespace"]):
# 新增实体节点 # 新增实体节点
node = embedding_manager.entities_embedding_store.store.get(node_hash) node = embedding_manager.entities_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -275,7 +275,7 @@ class KGManager:
node_item["type"] = "ent" node_item["type"] = "ent"
node_item["create_time"] = now_time node_item["create_time"] = now_time
self.graph.update_node(node_item) self.graph.update_node(node_item)
elif node_hash.startswith(local_storage['pg_namespace']): elif node_hash.startswith(local_storage["pg_namespace"]):
# 新增文段节点 # 新增文段节点
node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) node = embedding_manager.paragraphs_embedding_store.store.get(node_hash)
if node is None: if node is None:
@@ -359,7 +359,7 @@ class KGManager:
# 关系三元组 # 关系三元组
triple = relation[2:-2].split("', '") triple = relation[2:-2].split("', '")
for ent in [(triple[0]), (triple[2])]: for ent in [(triple[0]), (triple[2])]:
ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent)
if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash in existed_nodes: # 该实体需在KG中存在
if ent_hash not in ent_sim_scores: # 尚未记录的实体 if ent_hash not in ent_sim_scores: # 尚未记录的实体
ent_sim_scores[ent_hash] = [] ent_sim_scores[ent_hash] = []
@@ -437,7 +437,9 @@ class KGManager:
# 获取最终结果 # 获取最终结果
# 从搜索结果中提取文段节点的结果 # 从搜索结果中提取文段节点的结果
passage_node_res = [ passage_node_res = [
(node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) (node_key, score)
for node_key, score in ppr_res.items()
if node_key.startswith(local_storage["pg_namespace"])
] ]
del ppr_res del ppr_res

View File

@@ -1,5 +1,3 @@
from .llm_client import LLMMessage
entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。 entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体并以JSON列表的形式输出。
输出格式示例: 输出格式示例:
@@ -63,10 +61,10 @@ qa_system_prompt = """
""" """
def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: # def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]:
knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) # knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)])
messages = [ # messages = [
LLMMessage("system", qa_system_prompt).to_dict(), # LLMMessage("system", qa_system_prompt).to_dict(),
LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), # LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息\n{knowledge}").to_dict(),
] # ]
return messages # return messages

View File

@@ -484,25 +484,25 @@ class MessageSending(MessageProcessBase):
if self.message_segment: if self.message_segment:
self.processed_plain_text = await self._process_message_segments(self.message_segment) self.processed_plain_text = await self._process_message_segments(self.message_segment)
@classmethod # @classmethod
def from_thinking( # def from_thinking(
cls, # cls,
thinking: MessageThinking, # thinking: MessageThinking,
message_segment: Seg, # message_segment: Seg,
is_head: bool = False, # is_head: bool = False,
is_emoji: bool = False, # is_emoji: bool = False,
) -> "MessageSending": # ) -> "MessageSending":
"""从思考状态消息创建发送状态消息""" # """从思考状态消息创建发送状态消息"""
return cls( # return cls(
message_id=thinking.message_info.message_id, # type: ignore # message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream, # chat_stream=thinking.chat_stream,
message_segment=message_segment, # message_segment=message_segment,
bot_user_info=thinking.message_info.user_info, # type: ignore # bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply, # reply=thinking.reply,
is_head=is_head, # is_head=is_head,
is_emoji=is_emoji, # is_emoji=is_emoji,
sender_info=None, # sender_info=None,
) # )
def to_dict(self): def to_dict(self):
ret = super().to_dict() ret = super().to_dict()

View File

@@ -262,4 +262,4 @@ class ActionManager:
""" """
from src.plugin_system.core.component_registry import component_registry from src.plugin_system.core.component_registry import component_registry
return component_registry.get_component_class(action_name) # type: ignore return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore