diff --git a/bot.py b/bot.py index 1a5e6694b..5548c1725 100644 --- a/bot.py +++ b/bot.py @@ -8,6 +8,7 @@ if os.path.exists(".env"): print("成功加载环境变量配置") else: print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") import sys import time import platform @@ -140,81 +141,85 @@ async def graceful_shutdown(): logger.error(f"麦麦关闭失败: {e}", exc_info=True) +def _calculate_file_hash(file_path: Path, file_type: str) -> str: + """计算文件的MD5哈希值""" + if not file_path.exists(): + logger.error(f"{file_type} 文件不存在") + raise FileNotFoundError(f"{file_type} 文件不存在") + + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + return hashlib.md5(content.encode("utf-8")).hexdigest() + + +def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]: + """检查协议确认状态 + + Returns: + tuple[bool, bool]: (已确认, 未更新) + """ + # 检查环境变量确认 + if file_hash == os.getenv(env_var): + return True, False + + # 检查确认文件 + if confirm_file.exists(): + with open(confirm_file, "r", encoding="utf-8") as f: + confirmed_content = f.read() + if file_hash == confirmed_content: + return True, False + + return False, True + + +def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None: + """提示用户确认协议""" + confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") + confirm_logger.critical( + f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行' + ) + + while True: + user_input = input().strip().lower() + if user_input in ["同意", "confirmed"]: + return + confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') + + +def _save_confirmations(eula_updated: bool, privacy_updated: bool, + eula_hash: str, privacy_hash: str) -> None: + """保存用户确认结果""" + if eula_updated: + logger.info(f"更新EULA确认文件{eula_hash}") + Path("eula.confirmed").write_text(eula_hash, encoding="utf-8") + + if privacy_updated: + logger.info(f"更新隐私条款确认文件{privacy_hash}") + Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8") + + def check_eula(): - eula_confirm_file = Path("eula.confirmed") - privacy_confirm_file = Path("privacy.confirmed") - eula_file = Path("EULA.md") - privacy_file = Path("PRIVACY.md") - - eula_updated = True - privacy_updated = True - - eula_confirmed = False - privacy_confirmed = False - - # 首先计算当前EULA文件的哈希值 - if eula_file.exists(): - with open(eula_file, "r", encoding="utf-8") as f: - eula_content = f.read() - eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest() - else: - logger.error("EULA.md 文件不存在") - raise FileNotFoundError("EULA.md 文件不存在") - - # 首先计算当前隐私条款文件的哈希值 - if privacy_file.exists(): - with open(privacy_file, "r", encoding="utf-8") as f: - privacy_content = f.read() - privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest() - else: - logger.error("PRIVACY.md 文件不存在") - raise FileNotFoundError("PRIVACY.md 文件不存在") - - # 检查EULA确认文件是否存在 - if eula_confirm_file.exists(): - with open(eula_confirm_file, "r", encoding="utf-8") as f: - confirmed_content = f.read() - if eula_new_hash == confirmed_content: - eula_confirmed = True - eula_updated = False - if eula_new_hash == os.getenv("EULA_AGREE"): - eula_confirmed = True - eula_updated = False - - # 检查隐私条款确认文件是否存在 - if privacy_confirm_file.exists(): - with open(privacy_confirm_file, "r", encoding="utf-8") as f: - confirmed_content = f.read() - if privacy_new_hash == confirmed_content: - privacy_confirmed = True - privacy_updated = False - if privacy_new_hash == os.getenv("PRIVACY_AGREE"): - privacy_confirmed = True - privacy_updated = False - - # 如果EULA或隐私条款有更新,提示用户重新确认 + """检查EULA和隐私条款确认状态""" + # 计算文件哈希值 + eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md") + privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md") + + # 检查确认状态 + eula_confirmed, eula_updated = _check_agreement_status( + eula_hash, Path("eula.confirmed"), "EULA_AGREE" + ) + privacy_confirmed, privacy_updated = _check_agreement_status( + privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE" + ) + + # 早期返回:如果都已确认且未更新 + if eula_confirmed and privacy_confirmed: + return + + # 如果有更新,需要重新确认 if eula_updated or privacy_updated: - confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议") - confirm_logger.critical( - f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行' - ) - while True: - user_input = input().strip().lower() - if user_input in ["同意", "confirmed"]: - # print("确认成功,继续运行") - # print(f"确认成功,继续运行{eula_updated} {privacy_updated}") - if eula_updated: - logger.info(f"更新EULA确认文件{eula_new_hash}") - eula_confirm_file.write_text(eula_new_hash, encoding="utf-8") - if privacy_updated: - logger.info(f"更新隐私条款确认文件{privacy_new_hash}") - privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8") - break - else: - confirm_logger.critical('请输入"同意"或"confirmed"以继续运行') - return - elif eula_confirmed and privacy_confirmed: - return + _prompt_user_confirmation(eula_hash, privacy_hash) + _save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash) def raw_main(): diff --git a/scripts/import_openie.py b/scripts/import_openie.py index 791c64672..63a4d9852 100644 --- a/scripts/import_openie.py +++ b/scripts/import_openie.py @@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager from src.common.logger import get_logger from src.chat.knowledge.utils.hash import get_sha256 from src.manager.local_store_manager import local_storage +from dotenv import load_dotenv # 添加项目根目录到 sys.path @@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie") logger = get_logger("OpenIE导入") +ENV_FILE = os.path.join(ROOT_PATH, ".env") + +if os.path.exists(".env"): + load_dotenv(".env", override=True) + print("成功加载环境变量配置") +else: + print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") + +env_mask = {key: os.getenv(key) for key in os.environ} +def scan_provider(env_config: dict): + provider = {} + + # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 + # 避免 GPG_KEY 这样的变量干扰检查 + env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) + + # 遍历 env_config 的所有键 + for key in env_config: + # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 + if key.endswith("_BASE_URL") or key.endswith("_KEY"): + # 提取 provider 名称 + provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 + + # 初始化 provider 的字典(如果尚未初始化) + if provider_name not in provider: + provider[provider_name] = {"url": None, "key": None} + + # 根据键的类型填充 url 或 key + if key.endswith("_BASE_URL"): + provider[provider_name]["url"] = env_config[key] + elif key.endswith("_KEY"): + provider[provider_name]["key"] = env_config[key] + + # 检查每个 provider 是否同时存在 url 和 key + for provider_name, config in provider.items(): + if config["url"] is None or config["key"] is None: + logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") + raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_openie_dir(): """确保OpenIE数据目录存在""" @@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k def main(): # sourcery skip: dict-comprehension # 新增确认提示 + env_config = {key: os.getenv(key) for key in os.environ} + scan_provider(env_config) print("=== 重要操作确认 ===") print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型") print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快") diff --git a/scripts/info_extraction.py b/scripts/info_extraction.py index 90f0c80ea..c36a77892 100644 --- a/scripts/info_extraction.py +++ b/scripts/info_extraction.py @@ -27,6 +27,7 @@ from rich.progress import ( from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data from src.config.config import global_config from src.llm_models.utils_model import LLMRequest +from dotenv import load_dotenv logger = get_logger("LPMM知识库-信息提取") @@ -35,6 +36,45 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) TEMP_DIR = os.path.join(ROOT_PATH, "temp") # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data") OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie") +ENV_FILE = os.path.join(ROOT_PATH, ".env") + +if os.path.exists(".env"): + load_dotenv(".env", override=True) + print("成功加载环境变量配置") +else: + print("未找到.env文件,请确保程序所需的环境变量被正确设置") + raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量") + +env_mask = {key: os.getenv(key) for key in os.environ} +def scan_provider(env_config: dict): + provider = {} + + # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重 + # 避免 GPG_KEY 这样的变量干扰检查 + env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items())) + + # 遍历 env_config 的所有键 + for key in env_config: + # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式 + if key.endswith("_BASE_URL") or key.endswith("_KEY"): + # 提取 provider 名称 + provider_name = key.split("_", 1)[0] # 从左分割一次,取第一部分 + + # 初始化 provider 的字典(如果尚未初始化) + if provider_name not in provider: + provider[provider_name] = {"url": None, "key": None} + + # 根据键的类型填充 url 或 key + if key.endswith("_BASE_URL"): + provider[provider_name]["url"] = env_config[key] + elif key.endswith("_KEY"): + provider[provider_name]["key"] = env_config[key] + + # 检查每个 provider 是否同时存在 url 和 key + for provider_name, config in provider.items(): + if config["url"] is None or config["key"] is None: + logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}") + raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量") def ensure_dirs(): """确保临时目录和输出目录存在""" @@ -118,6 +158,8 @@ def main(): # sourcery skip: comprehension-to-generator, extract-method # 设置信号处理器 signal.signal(signal.SIGINT, signal_handler) ensure_dirs() # 确保目录存在 + env_config = {key: os.getenv(key) for key in os.environ} + scan_provider(env_config) # 新增用户确认提示 print("=== 重要操作确认,请认真阅读以下内容哦 ===") print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。") diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 3eb466d21..808b8013b 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import json import os import math +import asyncio from typing import Dict, List, Tuple import numpy as np @@ -99,7 +100,30 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return get_embedding(s) + """获取字符串的嵌入向量,处理异步调用""" + try: + # 尝试获取当前事件循环 + asyncio.get_running_loop() + # 如果在事件循环中,使用线程池执行 + import concurrent.futures + + def run_in_thread(): + return asyncio.run(get_embedding(s)) + + with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(run_in_thread) + result = future.result() + if result is None: + logger.error(f"获取嵌入失败: {s}") + return [] + return result + except RuntimeError: + # 没有运行的事件循环,直接运行 + result = asyncio.run(get_embedding(s)) + if result is None: + logger.error(f"获取嵌入失败: {s}") + return [] + return result def get_test_file_path(self): return EMBEDDING_TEST_FILE diff --git a/src/chat/knowledge/ie_process.py b/src/chat/knowledge/ie_process.py index bd0e17684..16d4e0804 100644 --- a/src/chat/knowledge/ie_process.py +++ b/src/chat/knowledge/ie_process.py @@ -1,3 +1,4 @@ +import asyncio import json import time from typing import List, Union @@ -7,8 +8,12 @@ from . import prompt_template from .knowledge_lib import INVALID_ENTITY from src.llm_models.utils_model import LLMRequest from json_repair import repair_json -def _extract_json_from_text(text: str) -> dict: +def _extract_json_from_text(text: str): """从文本中提取JSON数据的高容错方法""" + if text is None: + logger.error("输入文本为None") + return [] + try: fixed_json = repair_json(text) if isinstance(fixed_json, str): @@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict: else: parsed_json = fixed_json - if isinstance(parsed_json, list) and parsed_json: - parsed_json = parsed_json[0] - - if isinstance(parsed_json, dict): + # 如果是列表,直接返回 + if isinstance(parsed_json, list): return parsed_json + + # 如果是字典且只有一个项目,可能包装了列表 + if isinstance(parsed_json, dict): + # 如果字典只有一个键,并且值是列表,返回那个列表 + if len(parsed_json) == 1: + value = list(parsed_json.values())[0] + if isinstance(value, list): + return value + return parsed_json + + # 其他情况,尝试转换为列表 + logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}") + return [] except Exception as e: - logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...") + logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...") + return [] def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]: """对段落进行实体提取,返回提取出的实体列表(JSON格式)""" entity_extract_context = prompt_template.build_entity_extract_context(paragraph) - response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context) + + # 使用 asyncio.run 来运行异步方法 + try: + # 如果当前已有事件循环在运行,使用它 + loop = asyncio.get_running_loop() + future = asyncio.run_coroutine_threadsafe( + llm_req.generate_response_async(entity_extract_context), loop + ) + response, (reasoning_content, model_name) = future.result() + except RuntimeError: + # 如果没有运行中的事件循环,直接使用 asyncio.run + response, (reasoning_content, model_name) = asyncio.run( + llm_req.generate_response_async(entity_extract_context) + ) + # 添加调试日志 + logger.debug(f"LLM返回的原始响应: {response}") + entity_extract_result = _extract_json_from_text(response) - # 尝试load JSON数据 - json.loads(entity_extract_result) + + # 检查返回的是否为有效的实体列表 + if not isinstance(entity_extract_result, list): + # 如果不是列表,可能是字典格式,尝试从中提取列表 + if isinstance(entity_extract_result, dict): + # 尝试常见的键名 + for key in ['entities', 'result', 'data', 'items']: + if key in entity_extract_result and isinstance(entity_extract_result[key], list): + entity_extract_result = entity_extract_result[key] + break + else: + # 如果找不到合适的列表,抛出异常 + raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + else: + raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}") + + # 过滤无效实体 entity_extract_result = [ entity for entity in entity_extract_result @@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) -> rdf_extract_context = prompt_template.build_rdf_triple_extract_context( paragraph, entities=json.dumps(entities, ensure_ascii=False) ) - response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context) + + # 使用 asyncio.run 来运行异步方法 + try: + # 如果当前已有事件循环在运行,使用它 + loop = asyncio.get_running_loop() + future = asyncio.run_coroutine_threadsafe( + llm_req.generate_response_async(rdf_extract_context), loop + ) + response, (reasoning_content, model_name) = future.result() + except RuntimeError: + # 如果没有运行中的事件循环,直接使用 asyncio.run + response, (reasoning_content, model_name) = asyncio.run( + llm_req.generate_response_async(rdf_extract_context) + ) - entity_extract_result = _extract_json_from_text(response) - # 尝试load JSON数据 - json.loads(entity_extract_result) - for triple in entity_extract_result: - if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: + # 添加调试日志 + logger.debug(f"RDF LLM返回的原始响应: {response}") + + rdf_triple_result = _extract_json_from_text(response) + + # 检查返回的是否为有效的三元组列表 + if not isinstance(rdf_triple_result, list): + # 如果不是列表,可能是字典格式,尝试从中提取列表 + if isinstance(rdf_triple_result, dict): + # 尝试常见的键名 + for key in ['triples', 'result', 'data', 'items']: + if key in rdf_triple_result and isinstance(rdf_triple_result[key], list): + rdf_triple_result = rdf_triple_result[key] + break + else: + # 如果找不到合适的列表,抛出异常 + raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + else: + raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}") + + # 验证三元组格式 + for triple in rdf_triple_result: + if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple: raise Exception("RDF提取结果格式错误") - return entity_extract_result + return rdf_triple_result def info_extract_from_str( diff --git a/src/chat/knowledge/prompt_template.py b/src/chat/knowledge/prompt_template.py index 14a360083..fe5a293c0 100644 --- a/src/chat/knowledge/prompt_template.py +++ b/src/chat/knowledge/prompt_template.py @@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统 """ -def build_entity_extract_context(paragraph: str) -> list[LLMMessage]: - messages = [ - LLMMessage("system", entity_extract_system_prompt).to_dict(), - LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(), - ] - return messages +def build_entity_extract_context(paragraph: str) -> str: + """构建实体提取的完整提示文本""" + return f"""{entity_extract_system_prompt} + +段落: +``` +{paragraph} +```""" rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。 @@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描 """ -def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]: - messages = [ - LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(), - LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(), - ] - return messages +def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str: + """构建RDF三元组提取的完整提示文本""" + return f"""{rdf_triple_extract_system_prompt} + +段落: +``` +{paragraph} +``` + +实体列表: +``` +{entities} +```""" qa_system_prompt = """ diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1077cfa09..b9a419c33 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -255,12 +255,11 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False - - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(仅在启用时添加) + if self.enable_thinking: + payload["enable_thinking"] = True + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens @@ -670,12 +669,11 @@ class LLMRequest: if self.temp != 0.7: payload["temperature"] = self.temp - # 添加enable_thinking参数(如果不是默认值False) - if not self.enable_thinking: - payload["enable_thinking"] = False - - if self.thinking_budget != 4096: - payload["thinking_budget"] = self.thinking_budget + # 添加enable_thinking参数(仅在启用时添加) + if self.enable_thinking: + payload["enable_thinking"] = True + if self.thinking_budget != 4096: + payload["thinking_budget"] = self.thinking_budget if self.max_tokens: payload["max_tokens"] = self.max_tokens