diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 808b8013b..d732683ae 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -106,10 +106,10 @@ class EmbeddingStore: asyncio.get_running_loop() # 如果在事件循环中,使用线程池执行 import concurrent.futures - + def run_in_thread(): return asyncio.run(get_embedding(s)) - + with concurrent.futures.ThreadPoolExecutor() as executor: future = executor.submit(run_in_thread) result = future.result() @@ -294,10 +294,10 @@ class EmbeddingStore: """ if self.faiss_index is None: logger.debug("FaissIndex尚未构建,返回None") - return None + return [] if self.idx2hash is None: logger.warning("idx2hash尚未构建,返回None") - return None + return [] # L2归一化 faiss.normalize_L2(np.array([query], dtype=np.float32)) @@ -318,15 +318,15 @@ class EmbeddingStore: class EmbeddingManager: def __init__(self): self.paragraphs_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.entities_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.relation_embedding_store = EmbeddingStore( - local_storage['pg_namespace'], + local_storage["pg_namespace"], # type: ignore EMBEDDING_DATA_DIR_STR, ) self.stored_pg_hashes = set() diff --git a/src/chat/knowledge/kg_manager.py b/src/chat/knowledge/kg_manager.py index e18a7da80..083a741d6 100644 --- a/src/chat/knowledge/kg_manager.py +++ b/src/chat/knowledge/kg_manager.py @@ -30,20 +30,20 @@ def _get_kg_dir(): """ 安全地获取KG数据目录路径 """ - root_path = local_storage['root_path'] + root_path: str = local_storage["root_path"] if root_path is None: # 如果 local_storage 中没有 root_path,使用当前文件的相对路径作为备用 current_dir = os.path.dirname(os.path.abspath(__file__)) root_path = os.path.abspath(os.path.join(current_dir, "..", "..", "..")) logger.warning(f"local_storage 中未找到 root_path,使用备用路径: {root_path}") - + # 获取RAG数据目录 - rag_data_dir = global_config["persistence"]["rag_data_dir"] + rag_data_dir: str = global_config["persistence"]["rag_data_dir"] if rag_data_dir is None: kg_dir = os.path.join(root_path, "data/rag") else: kg_dir = os.path.join(root_path, rag_data_dir) - + return str(kg_dir).replace("\\", "/") @@ -65,9 +65,9 @@ class KGManager: # 持久化相关 - 使用延迟初始化的路径 self.dir_path = get_kg_dir_str() - self.graph_data_path = self.dir_path + "/" + local_storage['rag_graph_namespace'] + ".graphml" - self.ent_cnt_data_path = self.dir_path + "/" + local_storage['rag_ent_cnt_namespace'] + ".parquet" - self.pg_hash_file_path = self.dir_path + "/" + local_storage['rag_pg_hash_namespace'] + ".json" + self.graph_data_path = self.dir_path + "/" + local_storage["rag_graph_namespace"] + ".graphml" + self.ent_cnt_data_path = self.dir_path + "/" + local_storage["rag_ent_cnt_namespace"] + ".parquet" + self.pg_hash_file_path = self.dir_path + "/" + local_storage["rag_pg_hash_namespace"] + ".json" def save_to_file(self): """将KG数据保存到文件""" @@ -91,11 +91,11 @@ class KGManager: """从文件加载KG数据""" # 确保文件存在 if not os.path.exists(self.pg_hash_file_path): - raise Exception(f"KG段落hash文件{self.pg_hash_file_path}不存在") + raise FileNotFoundError(f"KG段落hash文件{self.pg_hash_file_path}不存在") if not os.path.exists(self.ent_cnt_data_path): - raise Exception(f"KG实体计数文件{self.ent_cnt_data_path}不存在") + raise FileNotFoundError(f"KG实体计数文件{self.ent_cnt_data_path}不存在") if not os.path.exists(self.graph_data_path): - raise Exception(f"KG图文件{self.graph_data_path}不存在") + raise FileNotFoundError(f"KG图文件{self.graph_data_path}不存在") # 加载段落hash with open(self.pg_hash_file_path, "r", encoding="utf-8") as f: @@ -122,8 +122,8 @@ class KGManager: # 避免自连接 continue # 一个triple就是一条边(同时构建双向联系) - hash_key1 = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) - hash_key2 = local_storage['ent_namespace'] + "-" + get_sha256(triple[2]) + hash_key1 = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) + hash_key2 = local_storage["ent_namespace"] + "-" + get_sha256(triple[2]) node_to_node[(hash_key1, hash_key2)] = node_to_node.get((hash_key1, hash_key2), 0) + 1.0 node_to_node[(hash_key2, hash_key1)] = node_to_node.get((hash_key2, hash_key1), 0) + 1.0 entity_set.add(hash_key1) @@ -141,8 +141,8 @@ class KGManager: """构建实体节点与文段节点之间的关系""" for idx in triple_list_data: for triple in triple_list_data[idx]: - ent_hash_key = local_storage['ent_namespace'] + "-" + get_sha256(triple[0]) - pg_hash_key = local_storage['pg_namespace'] + "-" + str(idx) + ent_hash_key = local_storage["ent_namespace"] + "-" + get_sha256(triple[0]) + pg_hash_key = local_storage["pg_namespace"] + "-" + str(idx) node_to_node[(ent_hash_key, pg_hash_key)] = node_to_node.get((ent_hash_key, pg_hash_key), 0) + 1.0 @staticmethod @@ -157,8 +157,8 @@ class KGManager: ent_hash_list = set() for triple_list in triple_list_data.values(): for triple in triple_list: - ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[0])) - ent_hash_list.add(local_storage['ent_namespace'] + "-" + get_sha256(triple[2])) + ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[0])) + ent_hash_list.add(local_storage["ent_namespace"] + "-" + get_sha256(triple[2])) ent_hash_list = list(ent_hash_list) synonym_hash_set = set() @@ -263,7 +263,7 @@ class KGManager: for src_tgt in node_to_node.keys(): for node_hash in src_tgt: if node_hash not in existed_nodes: - if node_hash.startswith(local_storage['ent_namespace']): + if node_hash.startswith(local_storage["ent_namespace"]): # 新增实体节点 node = embedding_manager.entities_embedding_store.store.get(node_hash) if node is None: @@ -275,7 +275,7 @@ class KGManager: node_item["type"] = "ent" node_item["create_time"] = now_time self.graph.update_node(node_item) - elif node_hash.startswith(local_storage['pg_namespace']): + elif node_hash.startswith(local_storage["pg_namespace"]): # 新增文段节点 node = embedding_manager.paragraphs_embedding_store.store.get(node_hash) if node is None: @@ -359,7 +359,7 @@ class KGManager: # 关系三元组 triple = relation[2:-2].split("', '") for ent in [(triple[0]), (triple[2])]: - ent_hash = local_storage['ent_namespace'] + "-" + get_sha256(ent) + ent_hash = local_storage["ent_namespace"] + "-" + get_sha256(ent) if ent_hash in existed_nodes: # 该实体需在KG中存在 if ent_hash not in ent_sim_scores: # 尚未记录的实体 ent_sim_scores[ent_hash] = [] @@ -437,7 +437,9 @@ class KGManager: # 获取最终结果 # 从搜索结果中提取文段节点的结果 passage_node_res = [ - (node_key, score) for node_key, score in ppr_res.items() if node_key.startswith(local_storage['pg_namespace']) + (node_key, score) + for node_key, score in ppr_res.items() + if node_key.startswith(local_storage["pg_namespace"]) ] del ppr_res diff --git a/src/chat/knowledge/prompt_template.py b/src/chat/knowledge/prompt_template.py index fe5a293c0..485103aad 100644 --- a/src/chat/knowledge/prompt_template.py +++ b/src/chat/knowledge/prompt_template.py @@ -1,5 +1,3 @@ -from .llm_client import LLMMessage - entity_extract_system_prompt = """你是一个性能优异的实体提取系统。请从段落中提取出所有实体,并以JSON列表的形式输出。 输出格式示例: @@ -63,10 +61,10 @@ qa_system_prompt = """ """ -def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: - knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) - messages = [ - LLMMessage("system", qa_system_prompt).to_dict(), - LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), - ] - return messages +# def build_qa_context(question: str, knowledge: list[tuple[str, str, str]]) -> list[LLMMessage]: +# knowledge = "\n".join([f"{i + 1}. 相关性:{k[0]}\n{k[1]}" for i, k in enumerate(knowledge)]) +# messages = [ +# LLMMessage("system", qa_system_prompt).to_dict(), +# LLMMessage("user", f"问题:\n{question}\n\n可能有帮助的信息:\n{knowledge}").to_dict(), +# ] +# return messages diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index 1346e73c5..36737eb77 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -484,25 +484,25 @@ class MessageSending(MessageProcessBase): if self.message_segment: self.processed_plain_text = await self._process_message_segments(self.message_segment) - @classmethod - def from_thinking( - cls, - thinking: MessageThinking, - message_segment: Seg, - is_head: bool = False, - is_emoji: bool = False, - ) -> "MessageSending": - """从思考状态消息创建发送状态消息""" - return cls( - message_id=thinking.message_info.message_id, # type: ignore - chat_stream=thinking.chat_stream, - message_segment=message_segment, - bot_user_info=thinking.message_info.user_info, # type: ignore - reply=thinking.reply, - is_head=is_head, - is_emoji=is_emoji, - sender_info=None, - ) + # @classmethod + # def from_thinking( + # cls, + # thinking: MessageThinking, + # message_segment: Seg, + # is_head: bool = False, + # is_emoji: bool = False, + # ) -> "MessageSending": + # """从思考状态消息创建发送状态消息""" + # return cls( + # message_id=thinking.message_info.message_id, # type: ignore + # chat_stream=thinking.chat_stream, + # message_segment=message_segment, + # bot_user_info=thinking.message_info.user_info, # type: ignore + # reply=thinking.reply, + # is_head=is_head, + # is_emoji=is_emoji, + # sender_info=None, + # ) def to_dict(self): ret = super().to_dict() diff --git a/src/chat/planner_actions/action_manager.py b/src/chat/planner_actions/action_manager.py index a4876a46d..a2f4c37bd 100644 --- a/src/chat/planner_actions/action_manager.py +++ b/src/chat/planner_actions/action_manager.py @@ -262,4 +262,4 @@ class ActionManager: """ from src.plugin_system.core.component_registry import component_registry - return component_registry.get_component_class(action_name) # type: ignore + return component_registry.get_component_class(action_name, ComponentType.ACTION) # type: ignore