SengokuCola
2025-05-02 19:15:44 +08:00
40 changed files with 408 additions and 110 deletions

bot.py
View File

@@ -13,6 +13,9 @@ from src.common.logger_manager import get_logger
# from src.common.logger import LogConfig, CONFIRM_STYLE_CONFIG
from src.common.crash_logger import install_crash_handler
from src.main import MainSystem
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("main") logger = get_logger("main")

View File

@@ -6,6 +6,7 @@
import sys
import os
from time import sleep
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
@@ -19,9 +20,14 @@ from src.plugins.knowledge.src.utils.hash import get_sha256
# Add the project root directory to sys.path
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
OPENIE_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
logger = get_module_logger("OpenIE导入")
logger = get_module_logger("LPMM知识库-OpenIE导入")
def hash_deduplicate(
@@ -66,8 +72,45 @@ def handle_import_openie(openie_data: OpenIE, embed_manager: EmbeddingManager, k
entity_list_data = openie_data.extract_entity_dict()
# Triple list keyed by index
triple_list_data = openie_data.extract_triple_dict()
# print(openie_data.docs)
if len(raw_paragraphs) != len(entity_list_data) or len(raw_paragraphs) != len(triple_list_data):
logger.error("OpenIE data is inconsistent")
logger.error(f"Number of raw paragraphs: {len(raw_paragraphs)}")
logger.error(f"Number of entity lists: {len(entity_list_data)}")
logger.error(f"Number of triple lists: {len(triple_list_data)}")
logger.error("The paragraph count of the OpenIE data does not match the entity list or triple list count")
logger.error("Please make sure your raw data is well segmented; avoid paragraphs consisting of nothing but '.....'")
logger.error("or paragraphs that contain only symbols")
# New: check the completeness of every entry in docs
logger.error("The system will start checking data integrity in 2 seconds")
sleep(2)
found_missing = False
for doc in getattr(openie_data, "docs", []):
idx = doc.get("idx", "<no idx>")
passage = doc.get("passage", "<no passage>")
missing = []
# Check that each field exists and is non-empty
if "passage" not in doc or not doc.get("passage"):
missing.append("passage")
if "extracted_entities" not in doc or not isinstance(doc.get("extracted_entities"), list):
missing.append("entity list missing")
elif len(doc.get("extracted_entities", [])) == 0:
missing.append("entity list empty")
if "extracted_triples" not in doc or not isinstance(doc.get("extracted_triples"), list):
missing.append("subject-predicate-object triples missing")
elif len(doc.get("extracted_triples", [])) == 0:
missing.append("subject-predicate-object triples empty")
# Print the idx of every doc
# print(f"check: idx={idx}")
if missing:
found_missing = True
logger.error("\n")
logger.error("Missing data:")
logger.error(f"Corresponding hash: {idx}")
logger.error(f"Corresponding passage content: {passage}")
logger.error(f"Invalid because: {', '.join(missing)}")
if not found_missing:
print("All data is complete; no missing fields found.")
return False
# Replace the index with the hash of the corresponding paragraph
logger.info("Deduplicating and re-indexing paragraphs")
@@ -131,6 +174,7 @@ def main():
embed_manager.load_from_file()
except Exception as e:
logger.error("Error occurred while loading the Embedding store from file: {}".format(e))
logger.error("If this is your first time importing knowledge, ignore this error")
logger.info("Embedding store loaded")
# Initialize the KG
kg_manager = KGManager()
@@ -139,6 +183,7 @@ def main():
kg_manager.load_from_file()
except Exception as e:
logger.error("Error occurred while loading the KG from file: {}".format(e))
logger.error("If this is your first time importing knowledge, ignore this error")
logger.info("KG loaded")
logger.info(f"Number of KG nodes: {len(kg_manager.graph.get_node_list())}")
@@ -163,4 +208,5 @@ def main():
if __name__ == "__main__":
# logger.info(f"111111111111111111111111{ROOT_PATH}")
main()

View File

@@ -4,11 +4,13 @@ import signal
from concurrent.futures import ThreadPoolExecutor, as_completed
from threading import Lock, Event
import sys
import glob
import datetime
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
# Add the project root directory to sys.path
import tqdm
from rich.progress import Progress  # replaced with a rich progress bar
from src.common.logger import get_module_logger
from src.plugins.knowledge.src.lpmmconfig import global_config
@@ -16,10 +18,31 @@ from src.plugins.knowledge.src.ie_process import info_extract_from_str
from src.plugins.knowledge.src.llm_client import LLMClient
from src.plugins.knowledge.src.open_ie import OpenIE
from src.plugins.knowledge.src.raw_processing import load_raw_data
from rich.progress import (
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
logger = get_module_logger("LPMM知识库-信息提取") logger = get_module_logger("LPMM知识库-信息提取")
TEMP_DIR = "./temp"
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
TEMP_DIR = os.path.join(ROOT_PATH, "temp")
IMPORTED_DATA_PATH = (
global_config["persistence"]["raw_data_path"]
if global_config["persistence"]["raw_data_path"]
else os.path.join(ROOT_PATH, "data/imported_lpmm_data")
)
OPENIE_OUTPUT_DIR = (
global_config["persistence"]["openie_data_path"]
if global_config["persistence"]["openie_data_path"]
else os.path.join(ROOT_PATH, "data/openie")
)
# Create a thread-safe lock to protect file operations and shared data
file_lock = Lock()
@@ -70,8 +93,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
# If saving fails, make sure no corrupted file is left behind
if os.path.exists(temp_file_path):
os.remove(temp_file_path)
# Set shutdown_event to terminate the program
shutdown_event.set()
sys.exit(0)
return None, pg_hash
return doc_item, None
@@ -79,7 +101,7 @@ def process_single_text(pg_hash, raw_data, llm_client_list):
def signal_handler(_signum, _frame):
"""Handle the Ctrl+C signal"""
logger.info("\nInterrupt signal received, shutting down gracefully...")
shutdown_event.set()
sys.exit(0)
def main():
@@ -110,33 +132,61 @@ def main():
global_config["llm_providers"][key]["api_key"], global_config["llm_providers"][key]["api_key"],
) )
logger.info("正在加载原始数据") # 检查 openie 输出目录
sha256_list, raw_datas = load_raw_data() if not os.path.exists(OPENIE_OUTPUT_DIR):
logger.info("原始数据加载完成\n") os.makedirs(OPENIE_OUTPUT_DIR)
logger.info(f"已创建输出目录: {OPENIE_OUTPUT_DIR}")
# 创建临时目录 # 确保 TEMP_DIR 目录存在
if not os.path.exists(f"{TEMP_DIR}"): if not os.path.exists(TEMP_DIR):
os.makedirs(f"{TEMP_DIR}") os.makedirs(TEMP_DIR)
logger.info(f"已创建缓存目录: {TEMP_DIR}")
# 遍历IMPORTED_DATA_PATH下所有json文件
imported_files = sorted(glob.glob(os.path.join(IMPORTED_DATA_PATH, "*.json")))
if not imported_files:
logger.error(f"未在 {IMPORTED_DATA_PATH} 下找到任何json文件")
sys.exit(1)
all_sha256_list = []
all_raw_datas = []
for imported_file in imported_files:
logger.info(f"正在处理文件: {imported_file}")
try:
sha256_list, raw_datas = load_raw_data(imported_file)
except Exception as e:
logger.error(f"读取文件失败: {imported_file}, 错误: {e}")
continue
all_sha256_list.extend(sha256_list)
all_raw_datas.extend(raw_datas)
failed_sha256 = []
open_ie_doc = []
# Create the thread pool, with a maximum of 50 threads
workers = global_config["info_extraction"]["workers"]
with ThreadPoolExecutor(max_workers=workers) as executor:
# Submit all tasks to the thread pool
future_to_hash = {
executor.submit(process_single_text, pg_hash, raw_data, llm_client_list): pg_hash
for pg_hash, raw_data in zip(sha256_list, raw_datas)
for pg_hash, raw_data in zip(all_sha256_list, all_raw_datas)
}
# Display progress with tqdm
with tqdm.tqdm(total=len(future_to_hash), postfix="Extracting:") as pbar:
# Handle completed tasks
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("Extracting:", total=len(future_to_hash))
try:
for future in as_completed(future_to_hash):
if shutdown_event.is_set():
# Cancel all unfinished tasks
for f in future_to_hash:
if not f.done():
f.cancel()
@@ -149,26 +199,38 @@ def main():
elif doc_item:
with open_ie_doc_lock:
open_ie_doc.append(doc_item)
pbar.update(1)
progress.update(task, advance=1)
except KeyboardInterrupt:
# If KeyboardInterrupt is caught here, signal_handler may not have worked properly
logger.info("\nInterrupt signal received, shutting down gracefully...")
shutdown_event.set()
# Cancel all unfinished tasks
for f in future_to_hash:
if not f.done():
f.cancel()
# Save the information extraction results
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4),
round(sum_phrase_words / num_phrases, 4),
)
OpenIE.save(openie_obj)
# Merge the extraction results from all files and save them
if open_ie_doc:
sum_phrase_chars = sum([len(e) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in open_ie_doc for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in open_ie_doc])
openie_obj = OpenIE(
open_ie_doc,
round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0,
round(sum_phrase_words / num_phrases, 4) if num_phrases else 0,
)
# Output filename format: MM-DD-HH-ss-openie.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-openie.json")
output_path = os.path.join(OPENIE_OUTPUT_DIR, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(
openie_obj.to_dict() if hasattr(openie_obj, "to_dict") else openie_obj.__dict__,
f,
ensure_ascii=False,
indent=4,
)
logger.info(f"Information extraction results saved to: {output_path}")
else:
logger.warning("No information extraction results to save")
logger.info("--------Information extraction finished--------")
logger.info(f"SHA256 of passages that failed extraction: {failed_sha256}")

View File

@@ -2,18 +2,22 @@ import json
import os
from pathlib import Path
import sys  # new: system module import
import datetime  # new import
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from src.common.logger import get_module_logger
logger = get_module_logger("LPMM数据库-原始数据处理")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
RAW_DATA_PATH = os.path.join(ROOT_PATH, "data/lpmm_raw_data")
IMPORTED_DATA_PATH = os.path.join(ROOT_PATH, "data/imported_lpmm_data")
# Add the project root directory to sys.path
def check_and_create_dirs():
"""Check and create the required directories"""
required_dirs = ["data/lpmm_raw_data", "data/imported_lpmm_data"]
required_dirs = [RAW_DATA_PATH, IMPORTED_DATA_PATH]
for dir_path in required_dirs:
if not os.path.exists(dir_path):
@@ -58,17 +62,17 @@ def main():
# Check and create the required directories
check_and_create_dirs()
# Check whether the output files already exist
if os.path.exists("data/import.json"):
logger.error("Error: data/import.json already exists; process or delete it first")
sys.exit(1)
if os.path.exists("data/openie.json"):
logger.error("Error: data/openie.json already exists; process or delete it first")
sys.exit(1)
# # Check whether the output files already exist
# if os.path.exists(RAW_DATA_PATH):
# logger.error("Error: data/import.json already exists; process or delete it first")
# sys.exit(1)
# if os.path.exists(RAW_DATA_PATH):
# logger.error("Error: data/openie.json already exists; process or delete it first")
# sys.exit(1)
# Collect all raw text files
raw_files = list(Path("data/lpmm_raw_data").glob("*.txt"))
raw_files = list(Path(RAW_DATA_PATH).glob("*.txt"))
if not raw_files:
logger.warning("Warning: no .txt files found in data/lpmm_raw_data")
sys.exit(1)
@@ -80,8 +84,10 @@ def main():
paragraphs = process_text_file(file)
all_paragraphs.extend(paragraphs)
# Save the merged result
output_path = "data/import.json"
# Save the merged result to IMPORTED_DATA_PATH; filename format: MM-DD-HH-ss-imported-data.json
now = datetime.datetime.now()
filename = now.strftime("%m-%d-%H-%S-imported-data.json")
output_path = os.path.join(IMPORTED_DATA_PATH, filename)
with open(output_path, "w", encoding="utf-8") as f:
json.dump(all_paragraphs, f, ensure_ascii=False, indent=4)
@@ -89,4 +95,6 @@ def main():
if __name__ == "__main__":
print(f"Raw Data Path: {RAW_DATA_PATH}")
print(f"Imported Data Path: {IMPORTED_DATA_PATH}")
main()
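A self-contained sketch of the timestamped-output pattern introduced above (the helper name is illustrative, not from the repository). Note that the format string from the diff, "%m-%d-%H-%S", encodes month-day-hour-second:

```python
import datetime
import json
import os

IMPORTED_DATA_PATH = "data/imported_lpmm_data"  # assumed default from this commit


def save_paragraphs(paragraphs: list[str]) -> str:
    """Write merged paragraphs to a timestamped json file and return its path."""
    os.makedirs(IMPORTED_DATA_PATH, exist_ok=True)
    filename = datetime.datetime.now().strftime("%m-%d-%H-%S-imported-data.json")
    output_path = os.path.join(IMPORTED_DATA_PATH, filename)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(paragraphs, f, ensure_ascii=False, indent=4)
    return output_path
```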

View File

@@ -1,6 +1,9 @@
import os
from pymongo import MongoClient
from pymongo.database import Database
from rich.traceback import install
install(extra_lines=3)
_client = None
_db = None

View File

@@ -2,6 +2,9 @@ import functools
import inspect
from typing import Callable, Any
from .logger import logger, add_custom_style_handler
from rich.traceback import install
install(extra_lines=3)
def use_log_style(

View File

@@ -2,6 +2,9 @@ from fastapi import FastAPI, APIRouter
from typing import Optional
from uvicorn import Config, Server as UvicornServer
import os
from rich.traceback import install
install(extra_lines=3)
class Server:

View File

@@ -14,6 +14,9 @@ from packaging.version import Version, InvalidVersion
from packaging.specifiers import SpecifierSet, InvalidSpecifier
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
# Configure the main program's log format

View File

@@ -4,6 +4,9 @@ import importlib
import pkgutil
import os
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("base_tool") logger = get_logger("base_tool")

View File

@@ -2,6 +2,9 @@ from typing import Optional
from .personality import Personality
from .identity import Identity
import random
from rich.traceback import install
install(extra_lines=3)
class Individuality:

View File

@@ -6,6 +6,9 @@ from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")

View File

@@ -17,6 +17,9 @@ from .common.logger_manager import get_logger
from .plugins.remote import heartbeat_thread  # noqa: F401
from .individuality.individuality import Individuality
from .common.server import global_server
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("main") logger = get_logger("main")

View File

@@ -7,6 +7,9 @@ from maim_message import UserInfo
from ...config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")

View File

@@ -23,6 +23,9 @@ from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("pfc") logger = get_logger("pfc")

View File

@@ -8,6 +8,9 @@ from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager
from ..storage.storage import MessageStorage
from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("message_sender") logger = get_module_logger("message_sender")

View File

@@ -8,6 +8,9 @@ from src.individuality.individuality import Individuality
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from rich.traceback import install
install(extra_lines=3)
if TYPE_CHECKING:
pass

View File

@@ -9,6 +9,9 @@ from ...common.database import db
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_stream") logger = get_logger("chat_stream")

View File

@@ -9,6 +9,9 @@ from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
from .utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_message") logger = get_logger("chat_message")

View File

@@ -13,6 +13,9 @@ from ...config.config import global_config
from .utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("sender") logger = get_logger("sender")

View File

@@ -13,6 +13,9 @@ from ...config.config import global_config
from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_image") logger = get_logger("chat_image")

View File

@@ -1,4 +1,7 @@
from fastapi import APIRouter, HTTPException
from rich.traceback import install
install(extra_lines=3)
# Create an APIRouter instead of a FastAPI instance
router = APIRouter()

View File

@@ -15,6 +15,9 @@ from ...config.config import global_config
from ..chat.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("emoji") logger = get_logger("emoji")

View File

@@ -27,6 +27,9 @@ from src.plugins.chat.utils import process_llm_response
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from src.plugins.moods.moods import MoodManager
from src.heart_flow.utils_chat import get_chat_type_and_target_info
from rich.traceback import install
install(extra_lines=3)
WAITING_TIME_THRESHOLD = 300  # Threshold for waiting for new messages, in seconds

View File

@@ -9,6 +9,9 @@ from ..storage.storage import MessageStorage
from ..chat.utils import truncate_message
from src.common.logger_manager import get_logger
from src.plugins.chat.utils import calculate_typing_time
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("sender") logger = get_logger("sender")

View File

@@ -26,6 +26,7 @@ try:
embed_manager.load_from_file()
except Exception as e:
logger.error("Error occurred while loading the Embedding store from file: {}".format(e))
logger.error("If this is your first time importing knowledge, or you have not imported any yet, ignore this error")
logger.info("Embedding store loaded")
# Initialize the KG
kg_manager = KGManager()
@@ -34,6 +35,7 @@ try:
kg_manager.load_from_file()
except Exception as e:
logger.error("Error occurred while loading the KG from file: {}".format(e))
logger.error("If this is your first time importing knowledge, or you have not imported any yet, ignore this error")
logger.info("KG loaded")
logger.info(f"Number of KG nodes: {len(kg_manager.graph.get_node_list())}")

View File

@@ -12,6 +12,21 @@ from .llm_client import LLMClient
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
install(extra_lines=3)
TOTAL_EMBEDDING_TIMES = 3  # Number of embedding passes (for progress reporting)
@dataclass
@@ -49,20 +64,35 @@ class EmbeddingStore:
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
def batch_insert_strs(self, strs: List[str]) -> None:
def batch_insert_strs(self, strs: List[str], times: int) -> None:
"""Store strings in the embedding store"""
# Process item by item
for s in tqdm.tqdm(strs, desc="Storing into embedding store", unit="items"):
# Compute the hash for deduplication
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store:
continue
total = len(strs)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task(f"Storing into embedding store: ({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
for s in strs:
# Compute the hash for deduplication
item_hash = self.namespace + "-" + get_sha256(s)
if item_hash in self.store:
progress.update(task, advance=1)
continue
# Get the embedding
embedding = self._get_embedding(s)
# Store it
self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
progress.update(task, advance=1)
def save_to_file(self) -> None:
"""Save to file"""
@@ -188,7 +218,7 @@ class EmbeddingManager:
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
"""Encode paragraphs and store them in the embedding store"""
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""Encode entities and store them in the embedding store"""
@@ -197,7 +227,7 @@ class EmbeddingManager:
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
self.entities_embedding_store.batch_insert_strs(list(entities))
self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""Encode relations and store them in the embedding store"""
@@ -205,7 +235,7 @@ class EmbeddingManager:
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self):
"""Load from file"""

View File

@@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import tqdm
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank from quick_algo import di_graph, pagerank
@@ -132,41 +141,56 @@ class KGManager:
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
synonym_result = dict()
# For each entity node, find its similar entity nodes and create expansion links
for ent_hash in tqdm.tqdm(ent_hash_list):
if ent_hash in synonym_hash_set:
# Avoid adding duplicates within the same batch
continue
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
continue
# Look up similar entities
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
)
res_ent = []  # Debug
for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash:
# Avoid self-connections
continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
# Similarity threshold
continue
node_to_node[(res_ent_hash, ent_hash)] = similarity
node_to_node[(ent_hash, res_ent_hash)] = similarity
synonym_hash_set.add(res_ent_hash)
new_edge_cnt += 1
res_ent.append(
(
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
similarity,
)
)  # Debug
synonym_result[ent.str] = res_ent
# rich progress bar
total = len(ent_hash_list)
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
MofNCompleteColumn(),
"",
TimeElapsedColumn(),
"<",
TimeRemainingColumn(),
transient=False,
) as progress:
task = progress.add_task("Synonym linking", total=total)
for ent_hash in ent_hash_list:
if ent_hash in synonym_hash_set:
progress.update(task, advance=1)
continue
ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
assert isinstance(ent, EmbeddingStoreItem)
if ent is None:
progress.update(task, advance=1)
continue
# Look up similar entities
similar_ents = embedding_manager.entities_embedding_store.search_top_k(
ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
)
res_ent = []  # Debug
for res_ent_hash, similarity in similar_ents:
if res_ent_hash == ent_hash:
# Avoid self-connections
continue
if similarity < global_config["rag"]["params"]["synonym_threshold"]:
# Similarity threshold
continue
node_to_node[(res_ent_hash, ent_hash)] = similarity
node_to_node[(ent_hash, res_ent_hash)] = similarity
synonym_hash_set.add(res_ent_hash)
new_edge_cnt += 1
res_ent.append(
(
embedding_manager.entities_embedding_store.store[res_ent_hash].str,
similarity,
)
)  # Debug
synonym_result[ent.str] = res_ent
progress.update(task, advance=1)
for k, v in synonym_result.items():
print(f'Similar entities of "{k}": {v}')

View File

@@ -1,9 +1,13 @@
import json
import os
import glob
from typing import Any, Dict, List
from .lpmmconfig import INVALID_ENTITY, global_config
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
def _filter_invalid_entities(entities: List[str]) -> List[str]:
"""Filter out invalid entities"""
@@ -74,12 +78,22 @@ class OpenIE:
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"]) doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@staticmethod @staticmethod
def _from_dict(data): def _from_dict(data_list):
"""字典中获取OpenIE对象""" """多个字典合并OpenIE对象"""
# data_list: List[dict]
all_docs = []
for data in data_list:
all_docs.extend(data.get("docs", []))
# 重新计算统计
sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
return OpenIE( return OpenIE(
docs=data["docs"], docs=all_docs,
avg_ent_chars=data["avg_ent_chars"], avg_ent_chars=avg_ent_chars,
avg_ent_words=data["avg_ent_words"], avg_ent_words=avg_ent_words,
) )
def _to_dict(self): def _to_dict(self):
@@ -92,12 +106,20 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""Load OpenIE data from a file"""
with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
data = json.loads(f.read())
openie_data = OpenIE._from_dict(data)
"""Load and merge OpenIE data from all json files under OPENIE_DIR"""
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
if not os.path.exists(openie_dir):
raise Exception(f"OpenIE data directory does not exist: {openie_dir}")
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
data_list = []
for file in json_files:
with open(file, "r", encoding="utf-8") as f:
data = json.load(f)
data_list.append(data)
if not data_list:
# print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
raise Exception(f"No OpenIE json files found under {openie_dir}")
openie_data = OpenIE._from_dict(data_list)
return openie_data
@staticmethod
@@ -132,3 +154,8 @@ class OpenIE:
"""提取原始段落""" """提取原始段落"""
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs}) raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
return raw_paragraph_dict return raw_paragraph_dict
if __name__ == "__main__":
# 测试代码
print(ROOT_PATH)
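Assumed usage after this change: OpenIE.load() now globs every *.json under the configured openie directory and merges their docs lists (recomputing the average entity statistics), instead of reading a single openie.json:

```python
# Sketch of the new load path; paths and attributes follow the diff above.
from src.plugins.knowledge.src.open_ie import OpenIE

openie_data = OpenIE.load()   # merges all json files under data/openie
print(len(openie_data.docs))  # total docs across every imported OpenIE file
```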

View File

@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
from .utils.hash import get_sha256
def load_raw_data() -> tuple[list[str], list[str]]:
def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
"""Load the raw data file
Read the raw data file and load the raw data into memory
Args:
path: optional absolute path of a specific json file to read
Returns:
- raw_data: raw data dict
- md5_set: set of SHA256 hashes of the raw data
- raw_data: raw data list
- sha256_list: SHA256 hashes of the raw data
"""
# Read the import.json file
if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
# Read the json file at the specified path, or the default path
json_path = path if path else global_config["persistence"]["raw_data_path"]
if os.path.exists(json_path):
with open(json_path, "r", encoding="utf-8") as f:
import_json = json.loads(f.read())
else:
raise Exception("Failed to read the raw data file")
raise Exception(f"Failed to read the raw data file: {json_path}")
# Example of import_json contents
# import_json = [
# "The capital of China is Beijing. The capital of France is Paris.",

View File

@@ -20,6 +20,9 @@ from ..utils.chat_message_builder import (
)  # import build_readable_messages
from ..chat.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
install(extra_lines=3)
def calculate_information_content(text):

View File

@@ -8,6 +8,9 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
async def test_memory_system():

View File

@@ -9,6 +9,9 @@ from Hippocampus import Hippocampus  # hippocampus and memory graph
from dotenv import load_dotenv
from rich.traceback import install
install(extra_lines=3)
""" """

View File

@@ -6,6 +6,9 @@ from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("offline_llm") logger = get_module_logger("offline_llm")

View File

@@ -1,6 +1,9 @@
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class DistributionVisualizer:

View File

@@ -14,6 +14,9 @@ import io
import os
from ...common.database import db
from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("model_utils") logger = get_module_logger("model_utils")

View File

@@ -3,7 +3,11 @@ import re
from contextlib import asynccontextmanager
import asyncio
from src.common.logger import get_module_logger
# import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("prompt_build") logger = get_module_logger("prompt_build")

View File

@@ -2,6 +2,9 @@ from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(extra_lines=3)
""" """
# 更好的计时器 # 更好的计时器

View File

@@ -8,6 +8,9 @@ from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(extra_lines=3)
""" """
基类方法概览: 基类方法概览:

View File

@@ -7,6 +7,9 @@ from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich.traceback import install
install(extra_lines=3)
# Add the project root directory to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -15,6 +18,7 @@ sys.path.append(root_path)
# Now src modules can be imported
from src.common.database import db  # noqa E402
# Load the env.edv file in the root directory
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):

View File

@@ -51,7 +51,7 @@ res_top_k = 3  # TopK of passages ultimately provided
[persistence]
# Persistence config (stores intermediate data to avoid recomputation)
data_root_path = "data"  # data root directory
raw_data_path = "data/import.json"  # raw data path
raw_data_path = "data/imported_lpmm_data"  # raw data path
openie_data_path = "data/openie.json"  # OpenIE data path
openie_data_path = "data/openie"  # OpenIE data path
embedding_data_dir = "data/embedding"  # embedding data directory
rag_data_dir = "data/rag"  # RAG data directory