Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
3
bot.py
3
bot.py
@@ -13,6 +13,9 @@ from src.common.logger_manager import get_logger
|
|||||||
# from src.common.logger import LogConfig, CONFIRM_STYLE_CONFIG
|
# from src.common.logger import LogConfig, CONFIRM_STYLE_CONFIG
|
||||||
from src.common.crash_logger import install_crash_handler
|
from src.common.crash_logger import install_crash_handler
|
||||||
from src.main import MainSystem
|
from src.main import MainSystem
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
|
from time import sleep
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
|
|
||||||
@@ -19,9 +20,14 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
|
|||||||
|
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
OPENIE_DIR = (
|
||||||
|
global_config["persistence"]["openie_data_path"]
|
||||||
|
if global_config["persistence"]["openie_data_path"]
|
||||||
|
else os.path.join(ROOT_PATH, "data/openie")
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = get_module_logger("OpenIE导入")
|
||||||
logger = get_module_logger("LPMM知识库-OpenIE导入")
|
|
||||||
|
|
||||||
|
|
||||||
def hash_deduplicate(
|
def hash_deduplicate(
|
||||||
@@ -66,8 +72,45 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
|
|||||||
entity_list_data = openie_data.extract_entity_dict()
|
entity_list_data = openie_data.extract_entity_dict()
|
||||||
# 索引的三元组列表
|
# 索引的三元组列表
|
||||||
triple_list_data = openie_data.extract_triple_dict()
|
triple_list_data = openie_data.extract_triple_dict()
|
||||||
|
# print(openie_data.docs)
|
||||||
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
|
||||||
logger.error("OpenIE数据存在异常")
|
logger.error("OpenIE数据存在异常")
|
||||||
|
logger.error(f"原始段落数量:{len(raw_paragraphs)}")
|
||||||
|
logger.error(f"实体列表数量:{len(entity_list_data)}")
|
||||||
|
logger.error(f"三元组列表数量:{len(triple_list_data)}")
|
||||||
|
logger.error("OpenIE数据段落数量与实体列表数量或三元组列表数量不一致")
|
||||||
|
logger.error("请保证你的原始数据分段良好,不要有类似于 “.....” 单独成一段的情况")
|
||||||
|
logger.error("或者一段中只有符号的情况")
|
||||||
|
# 新增:检查docs中每条数据的完整性
|
||||||
|
logger.error("系统将于2秒后开始检查数据完整性")
|
||||||
|
sleep(2)
|
||||||
|
found_missing = False
|
||||||
|
for doc in getattr(openie_data, "docs", []):
|
||||||
|
idx = doc.get("idx", "<无idx>")
|
||||||
|
passage = doc.get("passage", "<无passage>")
|
||||||
|
missing = []
|
||||||
|
# 检查字段是否存在且非空
|
||||||
|
if "passage" not in doc or not doc.get("passage"):
|
||||||
|
missing.append("passage")
|
||||||
|
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
|
||||||
|
missing.append("名词列表缺失")
|
||||||
|
elif len(doc.get("extracted_entities", [])) == 0:
|
||||||
|
missing.append("名词列表为空")
|
||||||
|
if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
|
||||||
|
missing.append("主谓宾三元组缺失")
|
||||||
|
elif len(doc.get("extracted_triples", [])) == 0:
|
||||||
|
missing.append("主谓宾三元组为空")
|
||||||
|
# 输出所有doc的idx
|
||||||
|
# print(f"检查: idx={idx}")
|
||||||
|
if missing:
|
||||||
|
found_missing = True
|
||||||
|
logger.error("\n")
|
||||||
|
logger.error("数据缺失:")
|
||||||
|
logger.error(f"对应哈希值:{idx}")
|
||||||
|
logger.error(f"对应文段内容内容:{passage}")
|
||||||
|
logger.error(f"非法原因:{', '.join(missing)}")
|
||||||
|
if not found_missing:
|
||||||
|
print("所有数据均完整,没有发现缺失字段。")
|
||||||
return False
|
return False
|
||||||
# 将索引换为对应段落的hash值
|
# 将索引换为对应段落的hash值
|
||||||
logger.info("正在进行段落去重与重索引")
|
logger.info("正在进行段落去重与重索引")
|
||||||
@@ -131,6 +174,7 @@ def main():
|
|||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
||||||
|
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
# 初始化KG
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
@@ -139,6 +183,7 @@ def main():
|
|||||||
kg_manager.load_from_file()
|
kg_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从文件加载KG时发生错误:{}".format(e))
|
logger.error("从文件加载KG时发生错误:{}".format(e))
|
||||||
|
logger.error("如果你是第一次导入知识,请忽略此错误")
|
||||||
logger.info("KG加载完成")
|
logger.info("KG加载完成")
|
||||||
|
|
||||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||||
@@ -163,4 +208,5 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
# logger.info(f"111111111111111111111111{ROOT_PATH}")
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -4,11 +4,13 @@ import signal
|
|||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from threading import Lock, Event
|
from threading import Lock, Event
|
||||||
import sys
|
import sys
|
||||||
|
import glob
|
||||||
|
import datetime
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
|
|
||||||
import tqdm
|
from rich.progress import Progress # 替换为 rich 进度条
|
||||||
|
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.plugins.knowledge.src.lpmmconfig import global_config
|
from src.plugins.knowledge.src.lpmmconfig import global_config
|
||||||
@@ -16,10 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
|
|||||||
from src.plugins.knowledge.src.llm_client import LLMClient
|
from src.plugins.knowledge.src.llm_client import LLMClient
|
||||||
from src.plugins.knowledge.src.open_ie import OpenIE
|
from src.plugins.knowledge.src.open_ie import OpenIE
|
||||||
from src.plugins.knowledge.src.raw_processing import load_raw_data
|
from src.plugins.knowledge.src.raw_processing import load_raw_data
|
||||||
|
from rich.progress import (
|
||||||
|
BarColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
SpinnerColumn,
|
||||||
|
TextColumn,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_module_logger("LPMM知识库-信息提取")
|
logger = get_module_logger("LPMM知识库-信息提取")
|
||||||
|
|
||||||
TEMP_DIR = "./temp"
|
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
|
||||||
|
IMPORTED_DATA_PATH = (
|
||||||
|
global_config["persistence"]["raw_data_path"]
|
||||||
|
if global_config["persistence"]["raw_data_path"]
|
||||||
|
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||||
|
)
|
||||||
|
OPENIE_OUTPUT_DIR = (
|
||||||
|
global_config["persistence"]["openie_data_path"]
|
||||||
|
if global_config["persistence"]["openie_data_path"]
|
||||||
|
else os.path.join(ROOT_PATH, "data/openie")
|
||||||
|
)
|
||||||
|
|
||||||
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
# 创建一个线程安全的锁,用于保护文件操作和共享数据
|
||||||
file_lock = Lock()
|
file_lock = Lock()
|
||||||
@@ -70,8 +93,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
|||||||
# 如果保存失败,确保不会留下损坏的文件
|
# 如果保存失败,确保不会留下损坏的文件
|
||||||
if os.path.exists(temp_file_path):
|
if os.path.exists(temp_file_path):
|
||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
# 设置shutdown_event以终止程序
|
sys.exit(0)
|
||||||
shutdown_event.set()
|
|
||||||
return None, pg_hash
|
return None, pg_hash
|
||||||
return doc_item, None
|
return doc_item, None
|
||||||
|
|
||||||
@@ -79,7 +101,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
|
|||||||
def signal_handler(_signum, _frame):
|
def signal_handler(_signum, _frame):
|
||||||
"""处理Ctrl+C信号"""
|
"""处理Ctrl+C信号"""
|
||||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||||
shutdown_event.set()
|
sys.exit(0)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
@@ -110,33 +132,61 @@ def main():
|
|||||||
global_config["llm_providers"][key]["api_key"],
|
global_config["llm_providers"][key]["api_key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.info("正在加载原始数据")
|
# 检查 openie 输出目录
|
||||||
sha256_list, raw_datas = load_raw_data()
|
if not os.path.exists(OPENIE_OUTPUT_DIR):
|
||||||
logger.info("原始数据加载完成\n")
|
os.makedirs(OPENIE_OUTPUT_DIR)
|
||||||
|
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
|
||||||
|
|
||||||
# 创建临时目录
|
# 确保 TEMP_DIR 目录存在
|
||||||
if not os.path.exists(f"{TEMP_DIR}"):
|
if not os.path.exists(TEMP_DIR):
|
||||||
os.makedirs(f"{TEMP_DIR}")
|
os.makedirs(TEMP_DIR)
|
||||||
|
logger.info(f"已创建缓存目录: {TEMP_DIR}")
|
||||||
|
|
||||||
|
# 遍历IMPORTED_DATA_PATH下所有json文件
|
||||||
|
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
|
||||||
|
if not imported_files:
|
||||||
|
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
all_sha256_list = []
|
||||||
|
all_raw_datas = []
|
||||||
|
|
||||||
|
for imported_file in imported_files:
|
||||||
|
logger.info(f"正在处理文件: {imported_file}")
|
||||||
|
try:
|
||||||
|
sha256_list, raw_datas = load_raw_data(imported_file)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
|
||||||
|
continue
|
||||||
|
all_sha256_list.extend(sha256_list)
|
||||||
|
all_raw_datas.extend(raw_datas)
|
||||||
|
|
||||||
failed_sha256 = []
|
failed_sha256 = []
|
||||||
open_ie_doc = []
|
open_ie_doc = []
|
||||||
|
|
||||||
# 创建线程池,最大线程数为50
|
|
||||||
workers = global_config["info_extraction"]["workers"]
|
workers = global_config["info_extraction"]["workers"]
|
||||||
with ThreadPoolExecutor(max_workers=workers) as executor:
|
with ThreadPoolExecutor(max_workers=workers) as executor:
|
||||||
# 提交所有任务到线程池
|
|
||||||
future_to_hash = {
|
future_to_hash = {
|
||||||
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
|
||||||
for pg_hash, raw_data in zip(sha256_list, raw_datas)
|
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
|
||||||
}
|
}
|
||||||
|
|
||||||
# 使用tqdm显示进度
|
with Progress(
|
||||||
with tqdm.tqdm(total=len(future_to_hash), postfix="正在进行提取:") as pbar:
|
SpinnerColumn(),
|
||||||
# 处理完成的任务
|
TextColumn("[progress.description]{task.description}"),
|
||||||
|
BarColumn(),
|
||||||
|
TaskProgressColumn(),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
"•",
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
"<",
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
transient=False,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task("正在进行提取:", total=len(future_to_hash))
|
||||||
try:
|
try:
|
||||||
for future in as_completed(future_to_hash):
|
for future in as_completed(future_to_hash):
|
||||||
if shutdown_event.is_set():
|
if shutdown_event.is_set():
|
||||||
# 取消所有未完成的任务
|
|
||||||
for f in future_to_hash:
|
for f in future_to_hash:
|
||||||
if not f.done():
|
if not f.done():
|
||||||
f.cancel()
|
f.cancel()
|
||||||
@@ -149,26 +199,38 @@ def main():
|
|||||||
elif doc_item:
|
elif doc_item:
|
||||||
with open_ie_doc_lock:
|
with open_ie_doc_lock:
|
||||||
open_ie_doc.append(doc_item)
|
open_ie_doc.append(doc_item)
|
||||||
pbar.update(1)
|
progress.update(task, advance=1)
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
# 如果在这里捕获到KeyboardInterrupt,说明signal_handler可能没有正常工作
|
|
||||||
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
logger.info("\n接收到中断信号,正在优雅地关闭程序...")
|
||||||
shutdown_event.set()
|
shutdown_event.set()
|
||||||
# 取消所有未完成的任务
|
|
||||||
for f in future_to_hash:
|
for f in future_to_hash:
|
||||||
if not f.done():
|
if not f.done():
|
||||||
f.cancel()
|
f.cancel()
|
||||||
|
|
||||||
# 保存信息提取结果
|
# 合并所有文件的提取结果并保存
|
||||||
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
if open_ie_doc:
|
||||||
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||||
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
|
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
|
||||||
openie_obj = OpenIE(
|
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
|
||||||
open_ie_doc,
|
openie_obj = OpenIE(
|
||||||
round(sum_phrase_chars / num_phrases, 4),
|
open_ie_doc,
|
||||||
round(sum_phrase_words / num_phrases, 4),
|
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
|
||||||
)
|
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
|
||||||
OpenIE.save(openie_obj)
|
)
|
||||||
|
# 输出文件名格式:MM-DD-HH-ss-openie.json
|
||||||
|
now = datetime.datetime.now()
|
||||||
|
filename = now.strftime("%m-%d-%H-%S-openie.json")
|
||||||
|
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
|
||||||
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
|
json.dump(
|
||||||
|
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
|
||||||
|
f,
|
||||||
|
ensure_ascii=False,
|
||||||
|
indent=4,
|
||||||
|
)
|
||||||
|
logger.info(f"信息提取结果已保存到: {output_path}")
|
||||||
|
else:
|
||||||
|
logger.warning("没有可保存的信息提取结果")
|
||||||
|
|
||||||
logger.info("--------信息提取完成--------")
|
logger.info("--------信息提取完成--------")
|
||||||
logger.info(f"提取失败的文段SHA256:{failed_sha256}")
|
logger.info(f"提取失败的文段SHA256:{failed_sha256}")
|
||||||
|
|||||||
@@ -2,18 +2,22 @@ import json
|
|||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import sys # 新增系统模块导入
|
import sys # 新增系统模块导入
|
||||||
|
import datetime # 新增导入
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
logger = get_module_logger("LPMM数据库-原始数据处理")
|
logger = get_module_logger("LPMM数据库-原始数据处理")
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||||
|
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
|
||||||
|
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
|
||||||
|
|
||||||
# 添加项目根目录到 sys.path
|
# 添加项目根目录到 sys.path
|
||||||
|
|
||||||
|
|
||||||
def check_and_create_dirs():
|
def check_and_create_dirs():
|
||||||
"""检查并创建必要的目录"""
|
"""检查并创建必要的目录"""
|
||||||
required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
|
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
|
||||||
|
|
||||||
for dir_path in required_dirs:
|
for dir_path in required_dirs:
|
||||||
if not os.path.exists(dir_path):
|
if not os.path.exists(dir_path):
|
||||||
@@ -58,17 +62,17 @@ def main():
|
|||||||
# 检查并创建必要的目录
|
# 检查并创建必要的目录
|
||||||
check_and_create_dirs()
|
check_and_create_dirs()
|
||||||
|
|
||||||
# 检查输出文件是否存在
|
# # 检查输出文件是否存在
|
||||||
if os.path.exists("data/import.json"):
|
# if os.path.exists(RAW_DATA_PATH):
|
||||||
logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
|
# logger.error("错误: data/import.json 已存在,请先处理或删除该文件")
|
||||||
sys.exit(1)
|
# sys.exit(1)
|
||||||
|
|
||||||
if os.path.exists("data/openie.json"):
|
# if os.path.exists(RAW_DATA_PATH):
|
||||||
logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
|
# logger.error("错误: data/openie.json 已存在,请先处理或删除该文件")
|
||||||
sys.exit(1)
|
# sys.exit(1)
|
||||||
|
|
||||||
# 获取所有原始文本文件
|
# 获取所有原始文本文件
|
||||||
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
|
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
|
||||||
if not raw_files:
|
if not raw_files:
|
||||||
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
logger.warning("警告: data/lpmm_raw_data 中没有找到任何 .txt 文件")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
@@ -80,8 +84,10 @@ def main():
|
|||||||
paragraphs = process_text_file(file)
|
paragraphs = process_text_file(file)
|
||||||
all_paragraphs.extend(paragraphs)
|
all_paragraphs.extend(paragraphs)
|
||||||
|
|
||||||
# 保存合并后的结果
|
# 保存合并后的结果到 IMPORTED_DATA_PATH,文件名格式为 MM-DD-HH-ss-imported-data.json
|
||||||
output_path = "data/import.json"
|
now = datetime.datetime.now()
|
||||||
|
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
|
||||||
|
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
|
||||||
with open(output_path, "w", encoding="utf-8") as f:
|
with open(output_path, "w", encoding="utf-8") as f:
|
||||||
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
|
||||||
|
|
||||||
@@ -89,4 +95,6 @@ def main():
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
print(f"Raw Data Path: {RAW_DATA_PATH}")
|
||||||
|
print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
|
||||||
main()
|
main()
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import os
|
import os
|
||||||
from pymongo import MongoClient
|
from pymongo import MongoClient
|
||||||
from pymongo.database import Database
|
from pymongo.database import Database
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
_client = None
|
_client = None
|
||||||
_db = None
|
_db = None
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
from typing import Callable, Any
|
from typing import Callable, Any
|
||||||
from .logger import logger, add_custom_style_handler
|
from .logger import logger, add_custom_style_handler
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
def use_log_style(
|
def use_log_style(
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ from fastapi import FastAPI, APIRouter
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
from uvicorn import Config, Server as UvicornServer
|
from uvicorn import Config, Server as UvicornServer
|
||||||
import os
|
import os
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from packaging.version import Version, InvalidVersion
|
|||||||
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
from packaging.specifiers import SpecifierSet, InvalidSpecifier
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
# 配置主程序日志格式
|
# 配置主程序日志格式
|
||||||
|
|||||||
@@ -4,6 +4,9 @@ import importlib
|
|||||||
import pkgutil
|
import pkgutil
|
||||||
import os
|
import os
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("base_tool")
|
logger = get_logger("base_tool")
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ from typing import Optional
|
|||||||
from .personality import Personality
|
from .personality import Personality
|
||||||
from .identity import Identity
|
from .identity import Identity
|
||||||
import random
|
import random
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class Individuality:
|
class Individuality:
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ from typing import Tuple, Union
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_module_logger("offline_llm")
|
logger = get_module_logger("offline_llm")
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,9 @@ from .common.logger_manager import get_logger
|
|||||||
from .plugins.remote import heartbeat_thread # noqa: F401
|
from .plugins.remote import heartbeat_thread # noqa: F401
|
||||||
from .individuality.individuality import Individuality
|
from .individuality.individuality import Individuality
|
||||||
from .common.server import global_server
|
from .common.server import global_server
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ from maim_message import UserInfo
|
|||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||||
from .message_storage import MongoDBMessageStorage
|
from .message_storage import MongoDBMessageStorage
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_module_logger("chat_observer")
|
logger = get_module_logger("chat_observer")
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,9 @@ from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
|||||||
from .waiter import Waiter
|
from .waiter import Waiter
|
||||||
|
|
||||||
import traceback
|
import traceback
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("pfc")
|
logger = get_logger("pfc")
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from src.plugins.chat.message import MessageSending, MessageSet
|
|||||||
from src.plugins.chat.message_sender import message_manager
|
from src.plugins.chat.message_sender import message_manager
|
||||||
from ..storage.storage import MessageStorage
|
from ..storage.storage import MessageStorage
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
logger = get_module_logger("message_sender")
|
logger = get_module_logger("message_sender")
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from src.individuality.individuality import Individuality
|
|||||||
from .conversation_info import ConversationInfo
|
from .conversation_info import ConversationInfo
|
||||||
from .observation_info import ObservationInfo
|
from .observation_info import ObservationInfo
|
||||||
from src.plugins.utils.chat_message_builder import build_readable_messages
|
from src.plugins.utils.chat_message_builder import build_readable_messages
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from ...common.database import db
|
|||||||
from maim_message import GroupInfo, UserInfo
|
from maim_message import GroupInfo, UserInfo
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("chat_stream")
|
logger = get_logger("chat_stream")
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from src.common.logger_manager import get_logger
|
|||||||
from .chat_stream import ChatStream
|
from .chat_stream import ChatStream
|
||||||
from .utils_image import image_manager
|
from .utils_image import image_manager
|
||||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("chat_message")
|
logger = get_logger("chat_message")
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ from ...config.config import global_config
|
|||||||
from .utils import truncate_message, calculate_typing_time, count_messages_between
|
from .utils import truncate_message, calculate_typing_time, count_messages_between
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|||||||
@@ -13,6 +13,9 @@ from ...config.config import global_config
|
|||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("chat_image")
|
logger = get_logger("chat_image")
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,7 @@
|
|||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
# 创建APIRouter而不是FastAPI实例
|
# 创建APIRouter而不是FastAPI实例
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
|||||||
@@ -15,6 +15,9 @@ from ...config.config import global_config
|
|||||||
from ..chat.utils_image import image_path_to_base64, image_manager
|
from ..chat.utils_image import image_path_to_base64, image_manager
|
||||||
from ..models.utils_model import LLMRequest
|
from ..models.utils_model import LLMRequest
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("emoji")
|
logger = get_logger("emoji")
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,9 @@ from src.plugins.chat.utils import process_llm_response
|
|||||||
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
|
||||||
from src.plugins.moods.moods import MoodManager
|
from src.plugins.moods.moods import MoodManager
|
||||||
from src.heart_flow.utils_chat import get_chat_type_and_target_info
|
from src.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
|
WAITING_TIME_THRESHOLD = 300 # 等待新消息时间阈值,单位秒
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from ..storage.storage import MessageStorage
|
|||||||
from ..chat.utils import truncate_message
|
from ..chat.utils import truncate_message
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.plugins.chat.utils import calculate_typing_time
|
from src.plugins.chat.utils import calculate_typing_time
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("sender")
|
logger = get_logger("sender")
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ try:
|
|||||||
embed_manager.load_from_file()
|
embed_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
logger.error("从文件加载Embedding库时发生错误:{}".format(e))
|
||||||
|
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||||
logger.info("Embedding库加载完成")
|
logger.info("Embedding库加载完成")
|
||||||
# 初始化KG
|
# 初始化KG
|
||||||
kg_manager = KGManager()
|
kg_manager = KGManager()
|
||||||
@@ -34,6 +35,7 @@ try:
|
|||||||
kg_manager.load_from_file()
|
kg_manager.load_from_file()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error("从文件加载KG时发生错误:{}".format(e))
|
logger.error("从文件加载KG时发生错误:{}".format(e))
|
||||||
|
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
|
||||||
logger.info("KG加载完成")
|
logger.info("KG加载完成")
|
||||||
|
|
||||||
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
logger.info(f"KG节点数量:{len(kg_manager.graph.get_node_list())}")
|
||||||
|
|||||||
@@ -12,6 +12,21 @@ from .llm_client import LLMClient
|
|||||||
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
|
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
|
||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
from .global_logger import logger
|
from .global_logger import logger
|
||||||
|
from rich.traceback import install
|
||||||
|
from rich.progress import (
|
||||||
|
Progress,
|
||||||
|
BarColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
SpinnerColumn,
|
||||||
|
TextColumn,
|
||||||
|
)
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
TOTAL_EMBEDDING_TIMES = 3 # 统计嵌入次数
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -49,20 +64,35 @@ class EmbeddingStore:
|
|||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
|
||||||
|
|
||||||
def batch_insert_strs(self, strs: List[str]) -> None:
|
def batch_insert_strs(self, strs: List[str], times: int) -> None:
|
||||||
"""向库中存入字符串"""
|
"""向库中存入字符串"""
|
||||||
# 逐项处理
|
total = len(strs)
|
||||||
for s in tqdm.tqdm(strs, desc="存入嵌入库", unit="items"):
|
with Progress(
|
||||||
# 计算hash去重
|
SpinnerColumn(),
|
||||||
item_hash = self.namespace + "-" + get_sha256(s)
|
TextColumn("[progress.description]{task.description}"),
|
||||||
if item_hash in self.store:
|
BarColumn(),
|
||||||
continue
|
TaskProgressColumn(),
|
||||||
|
MofNCompleteColumn(),
|
||||||
|
"•",
|
||||||
|
TimeElapsedColumn(),
|
||||||
|
"<",
|
||||||
|
TimeRemainingColumn(),
|
||||||
|
transient=False,
|
||||||
|
) as progress:
|
||||||
|
task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
|
||||||
|
for s in strs:
|
||||||
|
# 计算hash去重
|
||||||
|
item_hash = self.namespace + "-" + get_sha256(s)
|
||||||
|
if item_hash in self.store:
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
continue
|
||||||
|
|
||||||
# 获取embedding
|
# 获取embedding
|
||||||
embedding = self._get_embedding(s)
|
embedding = self._get_embedding(s)
|
||||||
|
|
||||||
# 存入
|
# 存入
|
||||||
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
|
||||||
def save_to_file(self) -> None:
|
def save_to_file(self) -> None:
|
||||||
"""保存到文件"""
|
"""保存到文件"""
|
||||||
@@ -188,7 +218,7 @@ class EmbeddingManager:
|
|||||||
|
|
||||||
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
|
||||||
"""将段落编码存入Embedding库"""
|
"""将段落编码存入Embedding库"""
|
||||||
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
|
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
|
||||||
|
|
||||||
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||||
"""将实体编码存入Embedding库"""
|
"""将实体编码存入Embedding库"""
|
||||||
@@ -197,7 +227,7 @@ class EmbeddingManager:
|
|||||||
for triple in triple_list:
|
for triple in triple_list:
|
||||||
entities.add(triple[0])
|
entities.add(triple[0])
|
||||||
entities.add(triple[2])
|
entities.add(triple[2])
|
||||||
self.entities_embedding_store.batch_insert_strs(list(entities))
|
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
|
||||||
|
|
||||||
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
|
||||||
"""将关系编码存入Embedding库"""
|
"""将关系编码存入Embedding库"""
|
||||||
@@ -205,7 +235,7 @@ class EmbeddingManager:
|
|||||||
for triples in triple_list_data.values():
|
for triples in triple_list_data.values():
|
||||||
graph_triples.extend([tuple(t) for t in triples])
|
graph_triples.extend([tuple(t) for t in triples])
|
||||||
graph_triples = list(set(graph_triples))
|
graph_triples = list(set(graph_triples))
|
||||||
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
|
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
|
||||||
|
|
||||||
def load_from_file(self):
|
def load_from_file(self):
|
||||||
"""从文件加载"""
|
"""从文件加载"""
|
||||||
|
|||||||
@@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pandas as pd
|
import pandas as pd
|
||||||
import tqdm
|
from rich.progress import (
|
||||||
|
Progress,
|
||||||
|
BarColumn,
|
||||||
|
TimeElapsedColumn,
|
||||||
|
TimeRemainingColumn,
|
||||||
|
TaskProgressColumn,
|
||||||
|
MofNCompleteColumn,
|
||||||
|
SpinnerColumn,
|
||||||
|
TextColumn,
|
||||||
|
)
|
||||||
from quick_algo import di_graph, pagerank
|
from quick_algo import di_graph, pagerank
|
||||||
|
|
||||||
|
|
||||||
@@ -132,41 +141,56 @@ class KGManager:
|
|||||||
ent_hash_list = list(ent_hash_list)
|
ent_hash_list = list(ent_hash_list)
|
||||||
|
|
||||||
synonym_hash_set = set()
|
synonym_hash_set = set()
|
||||||
|
|
||||||
synonym_result = dict()
|
synonym_result = dict()
|
||||||
|
|
||||||
# 对每个实体节点,查找其相似的实体节点,建立扩展连接
|
# rich 进度条
|
||||||
for ent_hash in tqdm.tqdm(ent_hash_list):
|
total = len(ent_hash_list)
|
||||||
if ent_hash in synonym_hash_set:
|
with Progress(
|
||||||
# 避免同一批次内重复添加
|
SpinnerColumn(),
|
||||||
continue
|
TextColumn("[progress.description]{task.description}"),
|
||||||
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
BarColumn(),
|
||||||
assert isinstance(ent, EmbeddingStoreItem)
|
TaskProgressColumn(),
|
||||||
if ent is None:
|
MofNCompleteColumn(),
|
||||||
continue
|
"•",
|
||||||
# 查询相似实体
|
TimeElapsedColumn(),
|
||||||
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
"<",
|
||||||
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
TimeRemainingColumn(),
|
||||||
)
|
transient=False,
|
||||||
res_ent = [] # Debug
|
) as progress:
|
||||||
for res_ent_hash, similarity in similar_ents:
|
task = progress.add_task("同义词连接", total=total)
|
||||||
if res_ent_hash == ent_hash:
|
for ent_hash in ent_hash_list:
|
||||||
# 避免自连接
|
if ent_hash in synonym_hash_set:
|
||||||
|
progress.update(task, advance=1)
|
||||||
continue
|
continue
|
||||||
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
|
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
|
||||||
# 相似度阈值
|
assert isinstance(ent, EmbeddingStoreItem)
|
||||||
|
if ent is None:
|
||||||
|
progress.update(task, advance=1)
|
||||||
continue
|
continue
|
||||||
node_to_node[(res_ent_hash, ent_hash)] = similarity
|
# 查询相似实体
|
||||||
node_to_node[(ent_hash, res_ent_hash)] = similarity
|
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
|
||||||
synonym_hash_set.add(res_ent_hash)
|
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
|
||||||
new_edge_cnt += 1
|
)
|
||||||
res_ent.append(
|
res_ent = [] # Debug
|
||||||
(
|
for res_ent_hash, similarity in similar_ents:
|
||||||
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
|
if res_ent_hash == ent_hash:
|
||||||
similarity,
|
# 避免自连接
|
||||||
)
|
continue
|
||||||
) # Debug
|
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
|
||||||
synonym_result[ent.str] = res_ent
|
# 相似度阈值
|
||||||
|
continue
|
||||||
|
node_to_node[(res_ent_hash, ent_hash)] = similarity
|
||||||
|
node_to_node[(ent_hash, res_ent_hash)] = similarity
|
||||||
|
synonym_hash_set.add(res_ent_hash)
|
||||||
|
new_edge_cnt += 1
|
||||||
|
res_ent.append(
|
||||||
|
(
|
||||||
|
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
|
||||||
|
similarity,
|
||||||
|
)
|
||||||
|
) # Debug
|
||||||
|
synonym_result[ent.str] = res_ent
|
||||||
|
progress.update(task, advance=1)
|
||||||
|
|
||||||
for k, v in synonym_result.items():
|
for k, v in synonym_result.items():
|
||||||
print(f'"{k}"的相似实体为:{v}')
|
print(f'"{k}"的相似实体为:{v}')
|
||||||
|
|||||||
@@ -1,9 +1,13 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import glob
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
|
||||||
from .lpmmconfig import INVALID_ENTITY, global_config
|
from .lpmmconfig import INVALID_ENTITY, global_config
|
||||||
|
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
|
||||||
|
|
||||||
|
|
||||||
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
def _filter_invalid_entities(entities: List[str]) -> List[str]:
|
||||||
"""过滤无效的实体"""
|
"""过滤无效的实体"""
|
||||||
@@ -74,12 +78,22 @@ class OpenIE:
|
|||||||
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
|
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _from_dict(data):
|
def _from_dict(data_list):
|
||||||
"""从字典中获取OpenIE对象"""
|
"""从多个字典合并OpenIE对象"""
|
||||||
|
# data_list: List[dict]
|
||||||
|
all_docs = []
|
||||||
|
for data in data_list:
|
||||||
|
all_docs.extend(data.get("docs", []))
|
||||||
|
# 重新计算统计
|
||||||
|
sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
|
||||||
|
sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
|
||||||
|
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
|
||||||
|
avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
|
||||||
|
avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
|
||||||
return OpenIE(
|
return OpenIE(
|
||||||
docs=data["docs"],
|
docs=all_docs,
|
||||||
avg_ent_chars=data["avg_ent_chars"],
|
avg_ent_chars=avg_ent_chars,
|
||||||
avg_ent_words=data["avg_ent_words"],
|
avg_ent_words=avg_ent_words,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _to_dict(self):
|
def _to_dict(self):
|
||||||
@@ -92,12 +106,20 @@ class OpenIE:
|
|||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load() -> "OpenIE":
|
def load() -> "OpenIE":
|
||||||
"""从文件中加载OpenIE数据"""
|
"""从OPENIE_DIR下所有json文件合并加载OpenIE数据"""
|
||||||
with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
|
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
|
||||||
data = json.loads(f.read())
|
if not os.path.exists(openie_dir):
|
||||||
|
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
|
||||||
openie_data = OpenIE._from_dict(data)
|
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
|
||||||
|
data_list = []
|
||||||
|
for file in json_files:
|
||||||
|
with open(file, "r", encoding="utf-8") as f:
|
||||||
|
data = json.load(f)
|
||||||
|
data_list.append(data)
|
||||||
|
if not data_list:
|
||||||
|
# print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
|
||||||
|
raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件")
|
||||||
|
openie_data = OpenIE._from_dict(data_list)
|
||||||
return openie_data
|
return openie_data
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -132,3 +154,8 @@ class OpenIE:
|
|||||||
"""提取原始段落"""
|
"""提取原始段落"""
|
||||||
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
|
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
|
||||||
return raw_paragraph_dict
|
return raw_paragraph_dict
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试代码
|
||||||
|
print(ROOT_PATH)
|
||||||
|
|||||||
@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
|
|||||||
from .utils.hash import get_sha256
|
from .utils.hash import get_sha256
|
||||||
|
|
||||||
|
|
||||||
def load_raw_data() -> tuple[list[str], list[str]]:
|
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
|
||||||
"""加载原始数据文件
|
"""加载原始数据文件
|
||||||
|
|
||||||
读取原始数据文件,将原始数据加载到内存中
|
读取原始数据文件,将原始数据加载到内存中
|
||||||
|
|
||||||
|
Args:
|
||||||
|
path: 可选,指定要读取的json文件绝对路径
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
- raw_data: 原始数据字典
|
- raw_data: 原始数据列表
|
||||||
- md5_set: 原始数据的SHA256集合
|
- sha256_list: 原始数据的SHA256集合
|
||||||
"""
|
"""
|
||||||
# 读取import.json文件
|
# 读取指定路径或默认路径的json文件
|
||||||
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
|
json_path = path if path else global_config["persistence"]["raw_data_path"]
|
||||||
with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
|
if os.path.exists(json_path):
|
||||||
|
with open(json_path, "r", encoding="utf-8") as f:
|
||||||
import_json = json.loads(f.read())
|
import_json = json.loads(f.read())
|
||||||
else:
|
else:
|
||||||
raise Exception("原始数据文件读取失败")
|
raise Exception(f"原始数据文件读取失败: {json_path}")
|
||||||
# import_json内容示例:
|
# import_json内容示例:
|
||||||
# import_json = [
|
# import_json = [
|
||||||
# "The capital of China is Beijing. The capital of France is Paris.",
|
# "The capital of China is Beijing. The capital of France is Paris.",
|
||||||
|
|||||||
@@ -20,6 +20,9 @@ from ..utils.chat_message_builder import (
|
|||||||
) # 导入 build_readable_messages
|
) # 导入 build_readable_messages
|
||||||
from ..chat.utils import translate_timestamp_to_human_readable
|
from ..chat.utils import translate_timestamp_to_human_readable
|
||||||
from .memory_config import MemoryConfig
|
from .memory_config import MemoryConfig
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
def calculate_information_content(text):
|
def calculate_information_content(text):
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ import os
|
|||||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
|
||||||
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
from src.plugins.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
async def test_memory_system():
|
async def test_memory_system():
|
||||||
|
|||||||
@@ -9,6 +9,9 @@ from Hippocampus import Hippocampus # 海马体和记忆图
|
|||||||
|
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -6,6 +6,9 @@ from typing import Tuple, Union
|
|||||||
import aiohttp
|
import aiohttp
|
||||||
import requests
|
import requests
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_module_logger("offline_llm")
|
logger = get_module_logger("offline_llm")
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,9 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from scipy import stats
|
from scipy import stats
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
class DistributionVisualizer:
|
class DistributionVisualizer:
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ import io
|
|||||||
import os
|
import os
|
||||||
from ...common.database import db
|
from ...common.database import db
|
||||||
from ...config.config import global_config
|
from ...config.config import global_config
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_module_logger("model_utils")
|
logger = get_module_logger("model_utils")
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,11 @@ import re
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import asyncio
|
import asyncio
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
|
|
||||||
# import traceback
|
# import traceback
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_module_logger("prompt_build")
|
logger = get_module_logger("prompt_build")
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,9 @@ from time import perf_counter
|
|||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import Optional, Dict, Callable
|
from typing import Optional, Dict, Callable
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
# 更好的计时器
|
# 更好的计时器
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ from abc import ABC, abstractmethod
|
|||||||
import importlib
|
import importlib
|
||||||
from typing import Dict, Optional
|
from typing import Dict, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
基类方法概览:
|
基类方法概览:
|
||||||
|
|||||||
@@ -7,6 +7,9 @@ from datetime import datetime
|
|||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
from rich.console import Console
|
from rich.console import Console
|
||||||
from rich.table import Table
|
from rich.table import Table
|
||||||
|
from rich.traceback import install
|
||||||
|
|
||||||
|
install(extra_lines=3)
|
||||||
|
|
||||||
# 添加项目根目录到 Python 路径
|
# 添加项目根目录到 Python 路径
|
||||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||||
@@ -15,6 +18,7 @@ sys.path.append(root_path)
|
|||||||
# 现在可以导入src模块
|
# 现在可以导入src模块
|
||||||
from src.common.database import db # noqa E402
|
from src.common.database import db # noqa E402
|
||||||
|
|
||||||
|
|
||||||
# 加载根目录下的env.edv文件
|
# 加载根目录下的env.edv文件
|
||||||
env_path = os.path.join(root_path, ".env")
|
env_path = os.path.join(root_path, ".env")
|
||||||
if not os.path.exists(env_path):
|
if not os.path.exists(env_path):
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ res_top_k = 3 # 最终提供的文段TopK
|
|||||||
[persistence]
|
[persistence]
|
||||||
# 持久化配置(存储中间数据,防止重复计算)
|
# 持久化配置(存储中间数据,防止重复计算)
|
||||||
data_root_path = "data" # 数据根目录
|
data_root_path = "data" # 数据根目录
|
||||||
raw_data_path = "data/import.json" # 原始数据路径
|
raw_data_path = "data/imported_lpmm_data" # 原始数据路径
|
||||||
openie_data_path = "data/openie.json" # OpenIE数据路径
|
openie_data_path = "data/openie" # OpenIE数据路径
|
||||||
embedding_data_dir = "data/embedding" # 嵌入数据目录
|
embedding_data_dir = "data/embedding" # 嵌入数据目录
|
||||||
rag_data_dir = "data/rag" # RAG数据目录
|
rag_data_dir = "data/rag" # RAG数据目录
|
||||||
|
|||||||
Reference in New Issue
Block a user