fix: resolve LPMM learning issues
bot.py
@@ -8,6 +8,7 @@ if os.path.exists(".env"):
     print("成功加载环境变量配置")
 else:
     print("未找到.env文件,请确保程序所需的环境变量被正确设置")
+    raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
 import sys
 import time
 import platform
@@ -140,81 +141,85 @@ async def graceful_shutdown():
         logger.error(f"麦麦关闭失败: {e}", exc_info=True)


+def _calculate_file_hash(file_path: Path, file_type: str) -> str:
+    """计算文件的MD5哈希值"""
+    if not file_path.exists():
+        logger.error(f"{file_type} 文件不存在")
+        raise FileNotFoundError(f"{file_type} 文件不存在")
+
+    with open(file_path, "r", encoding="utf-8") as f:
+        content = f.read()
+    return hashlib.md5(content.encode("utf-8")).hexdigest()
+
+
+def _check_agreement_status(file_hash: str, confirm_file: Path, env_var: str) -> tuple[bool, bool]:
+    """检查协议确认状态
+
+    Returns:
+        tuple[bool, bool]: (已确认, 未更新)
+    """
+    # 检查环境变量确认
+    if file_hash == os.getenv(env_var):
+        return True, False
+
+    # 检查确认文件
+    if confirm_file.exists():
+        with open(confirm_file, "r", encoding="utf-8") as f:
+            confirmed_content = f.read()
+        if file_hash == confirmed_content:
+            return True, False
+
+    return False, True
+
+
+def _prompt_user_confirmation(eula_hash: str, privacy_hash: str) -> None:
+    """提示用户确认协议"""
+    confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
+    confirm_logger.critical(
+        f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_hash}"和"PRIVACY_AGREE={privacy_hash}"继续运行'
+    )
+
+    while True:
+        user_input = input().strip().lower()
+        if user_input in ["同意", "confirmed"]:
+            return
+        confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
+
+
+def _save_confirmations(eula_updated: bool, privacy_updated: bool,
+                        eula_hash: str, privacy_hash: str) -> None:
+    """保存用户确认结果"""
+    if eula_updated:
+        logger.info(f"更新EULA确认文件{eula_hash}")
+        Path("eula.confirmed").write_text(eula_hash, encoding="utf-8")
+
+    if privacy_updated:
+        logger.info(f"更新隐私条款确认文件{privacy_hash}")
+        Path("privacy.confirmed").write_text(privacy_hash, encoding="utf-8")
+
+
 def check_eula():
-    eula_confirm_file = Path("eula.confirmed")
-    privacy_confirm_file = Path("privacy.confirmed")
-    eula_file = Path("EULA.md")
-    privacy_file = Path("PRIVACY.md")
-
-    eula_updated = True
-    privacy_updated = True
-
-    eula_confirmed = False
-    privacy_confirmed = False
-
-    # 首先计算当前EULA文件的哈希值
-    if eula_file.exists():
-        with open(eula_file, "r", encoding="utf-8") as f:
-            eula_content = f.read()
-        eula_new_hash = hashlib.md5(eula_content.encode("utf-8")).hexdigest()
-    else:
-        logger.error("EULA.md 文件不存在")
-        raise FileNotFoundError("EULA.md 文件不存在")
-
-    # 首先计算当前隐私条款文件的哈希值
-    if privacy_file.exists():
-        with open(privacy_file, "r", encoding="utf-8") as f:
-            privacy_content = f.read()
-        privacy_new_hash = hashlib.md5(privacy_content.encode("utf-8")).hexdigest()
-    else:
-        logger.error("PRIVACY.md 文件不存在")
-        raise FileNotFoundError("PRIVACY.md 文件不存在")
-
-    # 检查EULA确认文件是否存在
-    if eula_confirm_file.exists():
-        with open(eula_confirm_file, "r", encoding="utf-8") as f:
-            confirmed_content = f.read()
-        if eula_new_hash == confirmed_content:
-            eula_confirmed = True
-            eula_updated = False
-    if eula_new_hash == os.getenv("EULA_AGREE"):
-        eula_confirmed = True
-        eula_updated = False
-
-    # 检查隐私条款确认文件是否存在
-    if privacy_confirm_file.exists():
-        with open(privacy_confirm_file, "r", encoding="utf-8") as f:
-            confirmed_content = f.read()
-        if privacy_new_hash == confirmed_content:
-            privacy_confirmed = True
-            privacy_updated = False
-    if privacy_new_hash == os.getenv("PRIVACY_AGREE"):
-        privacy_confirmed = True
-        privacy_updated = False
-
-    # 如果EULA或隐私条款有更新,提示用户重新确认
+    """检查EULA和隐私条款确认状态"""
+    # 计算文件哈希值
+    eula_hash = _calculate_file_hash(Path("EULA.md"), "EULA.md")
+    privacy_hash = _calculate_file_hash(Path("PRIVACY.md"), "PRIVACY.md")
+
+    # 检查确认状态
+    eula_confirmed, eula_updated = _check_agreement_status(
+        eula_hash, Path("eula.confirmed"), "EULA_AGREE"
+    )
+    privacy_confirmed, privacy_updated = _check_agreement_status(
+        privacy_hash, Path("privacy.confirmed"), "PRIVACY_AGREE"
+    )
+
+    # 早期返回:如果都已确认且未更新
+    if eula_confirmed and privacy_confirmed:
+        return
+
+    # 如果有更新,需要重新确认
     if eula_updated or privacy_updated:
-        confirm_logger.critical("EULA或隐私条款内容已更新,请在阅读后重新确认,继续运行视为同意更新后的以上两款协议")
-        confirm_logger.critical(
-            f'输入"同意"或"confirmed"或设置环境变量"EULA_AGREE={eula_new_hash}"和"PRIVACY_AGREE={privacy_new_hash}"继续运行'
-        )
-        while True:
-            user_input = input().strip().lower()
-            if user_input in ["同意", "confirmed"]:
-                # print("确认成功,继续运行")
-                # print(f"确认成功,继续运行{eula_updated} {privacy_updated}")
-                if eula_updated:
-                    logger.info(f"更新EULA确认文件{eula_new_hash}")
-                    eula_confirm_file.write_text(eula_new_hash, encoding="utf-8")
-                if privacy_updated:
-                    logger.info(f"更新隐私条款确认文件{privacy_new_hash}")
-                    privacy_confirm_file.write_text(privacy_new_hash, encoding="utf-8")
-                break
-            else:
-                confirm_logger.critical('请输入"同意"或"confirmed"以继续运行')
-        return
-    elif eula_confirmed and privacy_confirmed:
-        return
+        _prompt_user_confirmation(eula_hash, privacy_hash)
+        _save_confirmations(eula_updated, privacy_updated, eula_hash, privacy_hash)


 def raw_main():
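Note on the bot.py change above: check_eula() is now a thin wrapper around the new MD5-hash helpers, so the acceptance rule is easier to follow. A minimal, self-contained sketch of that rule (the file and env-var names mirror the diff; the helper names here are illustrative, not the project's API):

    import hashlib
    import os
    from pathlib import Path


    def file_md5(path: Path) -> str:
        # Same hashing as _calculate_file_hash: MD5 over the UTF-8 text of the document.
        return hashlib.md5(path.read_text(encoding="utf-8").encode("utf-8")).hexdigest()


    def agreement_confirmed(doc: Path, confirm_file: Path, env_var: str) -> bool:
        doc_hash = file_md5(doc)
        if os.getenv(env_var) == doc_hash:      # accepted via environment variable
            return True
        # or accepted via the *.confirmed marker file written after a manual confirmation
        return confirm_file.exists() and confirm_file.read_text(encoding="utf-8") == doc_hash


    # Both documents must pass, otherwise the bot re-prompts and saves the new hashes:
    # agreement_confirmed(Path("EULA.md"), Path("eula.confirmed"), "EULA_AGREE")
    # agreement_confirmed(Path("PRIVACY.md"), Path("privacy.confirmed"), "PRIVACY_AGREE")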
@@ -15,6 +15,7 @@ from src.chat.knowledge.kg_manager import KGManager
 from src.common.logger import get_logger
 from src.chat.knowledge.utils.hash import get_sha256
 from src.manager.local_store_manager import local_storage
+from dotenv import load_dotenv


 # 添加项目根目录到 sys.path
@@ -23,6 +24,45 @@ OPENIE_DIR = os.path.join(ROOT_PATH, "data", "openie")

 logger = get_logger("OpenIE导入")

+ENV_FILE = os.path.join(ROOT_PATH, ".env")
+
+if os.path.exists(".env"):
+    load_dotenv(".env", override=True)
+    print("成功加载环境变量配置")
+else:
+    print("未找到.env文件,请确保程序所需的环境变量被正确设置")
+    raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
+
+env_mask = {key: os.getenv(key) for key in os.environ}
+def scan_provider(env_config: dict):
+    provider = {}
+
+    # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
+    # 避免 GPG_KEY 这样的变量干扰检查
+    env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
+
+    # 遍历 env_config 的所有键
+    for key in env_config:
+        # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
+        if key.endswith("_BASE_URL") or key.endswith("_KEY"):
+            # 提取 provider 名称
+            provider_name = key.split("_", 1)[0]  # 从左分割一次,取第一部分
+
+            # 初始化 provider 的字典(如果尚未初始化)
+            if provider_name not in provider:
+                provider[provider_name] = {"url": None, "key": None}
+
+            # 根据键的类型填充 url 或 key
+            if key.endswith("_BASE_URL"):
+                provider[provider_name]["url"] = env_config[key]
+            elif key.endswith("_KEY"):
+                provider[provider_name]["key"] = env_config[key]
+
+    # 检查每个 provider 是否同时存在 url 和 key
+    for provider_name, config in provider.items():
+        if config["url"] is None or config["key"] is None:
+            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
+            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")

 def ensure_openie_dir():
     """确保OpenIE数据目录存在"""
@@ -174,6 +214,8 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k

 def main():  # sourcery skip: dict-comprehension
     # 新增确认提示
+    env_config = {key: os.getenv(key) for key in os.environ}
+    scan_provider(env_config)
     print("=== 重要操作确认 ===")
     print("OpenIE导入时会大量发送请求,可能会撞到请求速度上限,请注意选用的模型")
     print("同之前样例:在本地模型下,在70分钟内我们发送了约8万条请求,在网络允许下,速度会更快")
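The scan_provider() block added above (and repeated in the second LPMM script below) enforces a naming convention on the environment loaded from .env: every {PROVIDER}_BASE_URL must be paired with a {PROVIDER}_KEY. Variables that already existed before .env was loaded (such as GPG_KEY) are masked out via env_mask so unrelated *_KEY entries cannot trip the check. A hedged, self-contained sketch of just the pairing rule; the provider name and values below are made up:

    def check_provider_pairs(env: dict) -> dict:
        providers: dict = {}
        for key, value in env.items():
            if key.endswith("_BASE_URL") or key.endswith("_KEY"):
                name = key.split("_", 1)[0]                      # provider name, as in the diff
                slot = "url" if key.endswith("_BASE_URL") else "key"
                providers.setdefault(name, {"url": None, "key": None})[slot] = value
        for name, cfg in providers.items():
            if cfg["url"] is None or cfg["key"] is None:
                raise ValueError(f"provider '{name}' is missing BASE_URL or KEY")
        return providers


    check_provider_pairs({"DEEPSEEK_BASE_URL": "https://example.invalid/v1",
                          "DEEPSEEK_KEY": "sk-test"})            # ok: both halves present
    # check_provider_pairs({"DEEPSEEK_BASE_URL": "https://example.invalid/v1"})  # ValueError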
@@ -27,6 +27,7 @@ from rich.progress import (
 from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
 from src.config.config import global_config
 from src.llm_models.utils_model import LLMRequest
+from dotenv import load_dotenv

 logger = get_logger("LPMM知识库-信息提取")

@@ -35,6 +36,45 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
 TEMP_DIR = os.path.join(ROOT_PATH, "temp")
 # IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data", "imported_lpmm_data")
 OPENIE_OUTPUT_DIR = os.path.join(ROOT_PATH, "data", "openie")
+ENV_FILE = os.path.join(ROOT_PATH, ".env")
+
+if os.path.exists(".env"):
+    load_dotenv(".env", override=True)
+    print("成功加载环境变量配置")
+else:
+    print("未找到.env文件,请确保程序所需的环境变量被正确设置")
+    raise FileNotFoundError(".env 文件不存在,请创建并配置所需的环境变量")
+
+env_mask = {key: os.getenv(key) for key in os.environ}
+def scan_provider(env_config: dict):
+    provider = {}
+
+    # 利用未初始化 env 时获取的 env_mask 来对新的环境变量集去重
+    # 避免 GPG_KEY 这样的变量干扰检查
+    env_config = dict(filter(lambda item: item[0] not in env_mask, env_config.items()))
+
+    # 遍历 env_config 的所有键
+    for key in env_config:
+        # 检查键是否符合 {provider}_BASE_URL 或 {provider}_KEY 的格式
+        if key.endswith("_BASE_URL") or key.endswith("_KEY"):
+            # 提取 provider 名称
+            provider_name = key.split("_", 1)[0]  # 从左分割一次,取第一部分
+
+            # 初始化 provider 的字典(如果尚未初始化)
+            if provider_name not in provider:
+                provider[provider_name] = {"url": None, "key": None}
+
+            # 根据键的类型填充 url 或 key
+            if key.endswith("_BASE_URL"):
+                provider[provider_name]["url"] = env_config[key]
+            elif key.endswith("_KEY"):
+                provider[provider_name]["key"] = env_config[key]
+
+    # 检查每个 provider 是否同时存在 url 和 key
+    for provider_name, config in provider.items():
+        if config["url"] is None or config["key"] is None:
+            logger.error(f"provider 内容:{config}\nenv_config 内容:{env_config}")
+            raise ValueError(f"请检查 '{provider_name}' 提供商配置是否丢失 BASE_URL 或 KEY 环境变量")

 def ensure_dirs():
     """确保临时目录和输出目录存在"""
@@ -118,6 +158,8 @@ def main():  # sourcery skip: comprehension-to-generator, extract-method
     # 设置信号处理器
     signal.signal(signal.SIGINT, signal_handler)
     ensure_dirs()  # 确保目录存在
+    env_config = {key: os.getenv(key) for key in os.environ}
+    scan_provider(env_config)
     # 新增用户确认提示
     print("=== 重要操作确认,请认真阅读以下内容哦 ===")
     print("实体提取操作将会花费较多api余额和时间,建议在空闲时段执行。")
@@ -2,6 +2,7 @@ from dataclasses import dataclass
 import json
 import os
 import math
+import asyncio
 from typing import Dict, List, Tuple

 import numpy as np
@@ -99,7 +100,30 @@ class EmbeddingStore:
         self.idx2hash = None

     def _get_embedding(self, s: str) -> List[float]:
-        return get_embedding(s)
+        """获取字符串的嵌入向量,处理异步调用"""
+        try:
+            # 尝试获取当前事件循环
+            asyncio.get_running_loop()
+            # 如果在事件循环中,使用线程池执行
+            import concurrent.futures
+
+            def run_in_thread():
+                return asyncio.run(get_embedding(s))
+
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                future = executor.submit(run_in_thread)
+                result = future.result()
+            if result is None:
+                logger.error(f"获取嵌入失败: {s}")
+                return []
+            return result
+        except RuntimeError:
+            # 没有运行的事件循环,直接运行
+            result = asyncio.run(get_embedding(s))
+            if result is None:
+                logger.error(f"获取嵌入失败: {s}")
+                return []
+            return result

     def get_test_file_path(self):
         return EMBEDDING_TEST_FILE
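The rewritten _get_embedding() above must call an async get_embedding() from synchronous code, and a bare asyncio.run() fails whenever an event loop is already running in the calling thread. A minimal sketch of the same bridge, using a stand-in coroutine (fake_embed and sync_get_embedding are illustrative names, not the project's API):

    import asyncio
    import concurrent.futures
    from typing import List


    async def fake_embed(text: str) -> List[float]:
        await asyncio.sleep(0)                      # stand-in for the real async embedding call
        return [float(len(text))]


    def sync_get_embedding(text: str) -> List[float]:
        try:
            asyncio.get_running_loop()              # raises RuntimeError if no loop is running
        except RuntimeError:
            return asyncio.run(fake_embed(text))    # plain sync caller: run directly
        # A loop is already running in this thread, so asyncio.run() would fail here;
        # run the coroutine on a fresh loop inside a worker thread instead, as in the diff.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(lambda: asyncio.run(fake_embed(text))).result()


    print(sync_get_embedding("hello"))              # called with no loop running

    async def caller():
        return sync_get_embedding("hello from inside a loop")

    print(asyncio.run(caller()))                    # called while a loop is running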
@@ -1,3 +1,4 @@
+import asyncio
 import json
 import time
 from typing import List, Union
@@ -7,8 +8,12 @@ from . import prompt_template
 from .knowledge_lib import INVALID_ENTITY
 from src.llm_models.utils_model import LLMRequest
 from json_repair import repair_json
-def _extract_json_from_text(text: str) -> dict:
+def _extract_json_from_text(text: str):
     """从文本中提取JSON数据的高容错方法"""
+    if text is None:
+        logger.error("输入文本为None")
+        return []
+
     try:
         fixed_json = repair_json(text)
         if isinstance(fixed_json, str):
@@ -16,23 +21,66 @@ def _extract_json_from_text(text: str) -> dict:
         else:
             parsed_json = fixed_json

-        if isinstance(parsed_json, list) and parsed_json:
-            parsed_json = parsed_json[0]
-
-        if isinstance(parsed_json, dict):
+        # 如果是列表,直接返回
+        if isinstance(parsed_json, list):
             return parsed_json

+        # 如果是字典且只有一个项目,可能包装了列表
+        if isinstance(parsed_json, dict):
+            # 如果字典只有一个键,并且值是列表,返回那个列表
+            if len(parsed_json) == 1:
+                value = list(parsed_json.values())[0]
+                if isinstance(value, list):
+                    return value
+            return parsed_json
+
+        # 其他情况,尝试转换为列表
+        logger.warning(f"解析的JSON不是预期格式: {type(parsed_json)}, 内容: {parsed_json}")
+        return []
+
     except Exception as e:
-        logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100]}...")
+        logger.error(f"JSON提取失败: {e}, 原始文本: {text[:100] if text else 'None'}...")
+        return []


 def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
     """对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
     entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
-    response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
+
+    # 使用 asyncio.run 来运行异步方法
+    try:
+        # 如果当前已有事件循环在运行,使用它
+        loop = asyncio.get_running_loop()
+        future = asyncio.run_coroutine_threadsafe(
+            llm_req.generate_response_async(entity_extract_context), loop
+        )
+        response, (reasoning_content, model_name) = future.result()
+    except RuntimeError:
+        # 如果没有运行中的事件循环,直接使用 asyncio.run
+        response, (reasoning_content, model_name) = asyncio.run(
+            llm_req.generate_response_async(entity_extract_context)
+        )
+
+    # 添加调试日志
+    logger.debug(f"LLM返回的原始响应: {response}")
+
     entity_extract_result = _extract_json_from_text(response)
-    # 尝试load JSON数据
-    json.loads(entity_extract_result)
+
+    # 检查返回的是否为有效的实体列表
+    if not isinstance(entity_extract_result, list):
+        # 如果不是列表,可能是字典格式,尝试从中提取列表
+        if isinstance(entity_extract_result, dict):
+            # 尝试常见的键名
+            for key in ['entities', 'result', 'data', 'items']:
+                if key in entity_extract_result and isinstance(entity_extract_result[key], list):
+                    entity_extract_result = entity_extract_result[key]
+                    break
+            else:
+                # 如果找不到合适的列表,抛出异常
+                raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
+        else:
+            raise Exception(f"实体提取结果格式错误,期望列表但得到: {type(entity_extract_result)}")
+
+    # 过滤无效实体
     entity_extract_result = [
         entity
         for entity in entity_extract_result
@@ -50,16 +98,47 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
     rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
         paragraph, entities=json.dumps(entities, ensure_ascii=False)
     )
-    response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)

-    entity_extract_result = _extract_json_from_text(response)
-    # 尝试load JSON数据
-    json.loads(entity_extract_result)
-    for triple in entity_extract_result:
-        if len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
+    # 使用 asyncio.run 来运行异步方法
+    try:
+        # 如果当前已有事件循环在运行,使用它
+        loop = asyncio.get_running_loop()
+        future = asyncio.run_coroutine_threadsafe(
+            llm_req.generate_response_async(rdf_extract_context), loop
+        )
+        response, (reasoning_content, model_name) = future.result()
+    except RuntimeError:
+        # 如果没有运行中的事件循环,直接使用 asyncio.run
+        response, (reasoning_content, model_name) = asyncio.run(
+            llm_req.generate_response_async(rdf_extract_context)
+        )
+
+    # 添加调试日志
+    logger.debug(f"RDF LLM返回的原始响应: {response}")
+
+    rdf_triple_result = _extract_json_from_text(response)
+
+    # 检查返回的是否为有效的三元组列表
+    if not isinstance(rdf_triple_result, list):
+        # 如果不是列表,可能是字典格式,尝试从中提取列表
+        if isinstance(rdf_triple_result, dict):
+            # 尝试常见的键名
+            for key in ['triples', 'result', 'data', 'items']:
+                if key in rdf_triple_result and isinstance(rdf_triple_result[key], list):
+                    rdf_triple_result = rdf_triple_result[key]
+                    break
+            else:
+                # 如果找不到合适的列表,抛出异常
+                raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
+        else:
+            raise Exception(f"RDF三元组提取结果格式错误,期望列表但得到: {type(rdf_triple_result)}")
+
+    # 验证三元组格式
+    for triple in rdf_triple_result:
+        if not isinstance(triple, list) or len(triple) != 3 or (triple[0] is None or triple[1] is None or triple[2] is None) or "" in triple:
             raise Exception("RDF提取结果格式错误")

-    return entity_extract_result
+    return rdf_triple_result


 def info_extract_from_str(
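Both extractors above now funnel the raw LLM text through _extract_json_from_text(), which repairs malformed JSON and then accepts either a bare list or a dict that merely wraps one list (for example {"entities": [...]}). A hedged sketch of that normalization, using the same json_repair dependency; the sample strings are made up:

    import json
    from json_repair import repair_json


    def extract_list(text: str):
        fixed = repair_json(text)                   # repaired JSON, returned as a string by default
        parsed = json.loads(fixed) if isinstance(fixed, str) else fixed
        if isinstance(parsed, list):
            return parsed
        if isinstance(parsed, dict) and len(parsed) == 1:
            (value,) = parsed.values()
            if isinstance(value, list):             # single-key wrapper such as {"entities": [...]}
                return value
        return parsed if isinstance(parsed, dict) else []


    print(extract_list('["麦麦", "LPMM",]'))               # trailing comma repaired -> list
    print(extract_list('{"entities": ["麦麦", "LPMM"]}'))  # wrapper dict unwrapped -> list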
@@ -11,12 +11,14 @@ entity_extract_system_prompt = """你是一个性能优异的实体提取系统
 """


-def build_entity_extract_context(paragraph: str) -> list[LLMMessage]:
-    messages = [
-        LLMMessage("system", entity_extract_system_prompt).to_dict(),
-        LLMMessage("user", f"""段落:\n```\n{paragraph}```""").to_dict(),
-    ]
-    return messages
+def build_entity_extract_context(paragraph: str) -> str:
+    """构建实体提取的完整提示文本"""
+    return f"""{entity_extract_system_prompt}
+
+段落:
+```
+{paragraph}
+```"""


 rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描述框架,由节点和边组成,节点表示实体/资源、属性,边则表示了实体和实体之间的关系以及实体和属性的关系。)构造系统。你的任务是根据给定的段落和实体列表构建RDF图。
@@ -36,12 +38,19 @@ rdf_triple_extract_system_prompt = """你是一个性能优异的RDF(资源描
 """


-def build_rdf_triple_extract_context(paragraph: str, entities: str) -> list[LLMMessage]:
-    messages = [
-        LLMMessage("system", rdf_triple_extract_system_prompt).to_dict(),
-        LLMMessage("user", f"""段落:\n```\n{paragraph}```\n\n实体列表:\n```\n{entities}```""").to_dict(),
-    ]
-    return messages
+def build_rdf_triple_extract_context(paragraph: str, entities: str) -> str:
+    """构建RDF三元组提取的完整提示文本"""
+    return f"""{rdf_triple_extract_system_prompt}
+
+段落:
+```
+{paragraph}
+```
+
+实体列表:
+```
+{entities}
+```"""


 qa_system_prompt = """
@@ -255,12 +255,11 @@ class LLMRequest:
         if self.temp != 0.7:
             payload["temperature"] = self.temp

-        # 添加enable_thinking参数(如果不是默认值False)
-        if not self.enable_thinking:
-            payload["enable_thinking"] = False
-        if self.thinking_budget != 4096:
-            payload["thinking_budget"] = self.thinking_budget
+        # 添加enable_thinking参数(仅在启用时添加)
+        if self.enable_thinking:
+            payload["enable_thinking"] = True
+            if self.thinking_budget != 4096:
+                payload["thinking_budget"] = self.thinking_budget

         if self.max_tokens:
             payload["max_tokens"] = self.max_tokens
@@ -670,12 +669,11 @@ class LLMRequest:
         if self.temp != 0.7:
             payload["temperature"] = self.temp

-        # 添加enable_thinking参数(如果不是默认值False)
-        if not self.enable_thinking:
-            payload["enable_thinking"] = False
-        if self.thinking_budget != 4096:
-            payload["thinking_budget"] = self.thinking_budget
+        # 添加enable_thinking参数(仅在启用时添加)
+        if self.enable_thinking:
+            payload["enable_thinking"] = True
+            if self.thinking_budget != 4096:
+                payload["thinking_budget"] = self.thinking_budget

         if self.max_tokens:
             payload["max_tokens"] = self.max_tokens
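The two payload hunks above flip the enable_thinking logic: the flag is only sent when thinking is actually enabled, and thinking_budget now rides along only in that case. A small sketch of the resulting behaviour (the function name and defaults here are illustrative, not the library's API):

    def build_thinking_payload(enable_thinking: bool, thinking_budget: int = 4096) -> dict:
        payload = {}
        if enable_thinking:
            payload["enable_thinking"] = True
            if thinking_budget != 4096:             # only send a non-default budget
                payload["thinking_budget"] = thinking_budget
        return payload


    print(build_thinking_payload(False))            # {}  (the old code sent enable_thinking=False)
    print(build_thinking_payload(True, 8192))       # {'enable_thinking': True, 'thinking_budget': 8192}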