SengokuCola
2025-05-02 19:15:44 +08:00
40 changed files with 408 additions and 110 deletions

View File

@@ -7,6 +7,9 @@ from maim_message import UserInfo
from ...config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MongoDBMessageStorage
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("chat_observer")
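
This pair of lines recurs in nearly every file of this commit: rich.traceback.install() replaces Python's default traceback handler with rich's formatted one, and extra_lines=3 controls how many lines of source context each frame shows. A minimal standalone sketch of the effect (the failing function is illustrative, not this repo's code):

from rich.traceback import install

# Replace sys.excepthook so uncaught exceptions render as rich tracebacks,
# with three lines of source context around each frame.
install(extra_lines=3)

def boom() -> None:
    raise ValueError("demo")  # rendered with syntax highlighting and context

boom()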

View File

@@ -23,6 +23,9 @@ from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .waiter import Waiter
import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("pfc")

View File

@@ -8,6 +8,9 @@ from src.plugins.chat.message import MessageSending, MessageSet
from src.plugins.chat.message_sender import message_manager
from ..storage.storage import MessageStorage
from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("message_sender")

View File

@@ -8,6 +8,9 @@ from src.individuality.individuality import Individuality
from .conversation_info import ConversationInfo
from .observation_info import ObservationInfo
from src.plugins.utils.chat_message_builder import build_readable_messages
from rich.traceback import install
install(extra_lines=3)
if TYPE_CHECKING:
pass

View File

@@ -9,6 +9,9 @@ from ...common.database import db
from maim_message import GroupInfo, UserInfo
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_stream")

View File

@@ -9,6 +9,9 @@ from src.common.logger_manager import get_logger
from .chat_stream import ChatStream
from .utils_image import image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_message")

View File

@@ -13,6 +13,9 @@ from ...config.config import global_config
from .utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("sender")

View File

@@ -13,6 +13,9 @@ from ...config.config import global_config
from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("chat_image")

View File

@@ -1,4 +1,7 @@
from fastapi import APIRouter, HTTPException
from rich.traceback import install
install(extra_lines=3)
# Create an APIRouter instead of a FastAPI instance
router = APIRouter()
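
For readers unfamiliar with the distinction the comment draws: an APIRouter only declares routes; it serves nothing until it is mounted on the main application. A hedged sketch of the usual wiring (the app setup and route are assumptions, not this repo's code):

from fastapi import APIRouter, FastAPI

router = APIRouter()

@router.get("/ping")
def ping() -> dict:
    return {"ok": True}  # illustrative route

app = FastAPI()
app.include_router(router)  # the router's routes are now served by the app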

View File

@@ -15,6 +15,9 @@ from ...config.config import global_config
from ..chat.utils_image import image_path_to_base64, image_manager
from ..models.utils_model import LLMRequest
from src.common.logger_manager import get_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("emoji")

View File

@@ -27,6 +27,9 @@ from src.plugins.chat.utils import process_llm_response
from src.plugins.respon_info_catcher.info_catcher import info_catcher_manager
from src.plugins.moods.moods import MoodManager
from src.heart_flow.utils_chat import get_chat_type_and_target_info
from rich.traceback import install
install(extra_lines=3)
WAITING_TIME_THRESHOLD = 300  # Threshold for waiting for new messages, in seconds

View File

@@ -9,6 +9,9 @@ from ..storage.storage import MessageStorage
from ..chat.utils import truncate_message
from src.common.logger_manager import get_logger
from src.plugins.chat.utils import calculate_typing_time
from rich.traceback import install
install(extra_lines=3)
logger = get_logger("sender")

View File

@@ -26,6 +26,7 @@ try:
embed_manager.load_from_file()
except Exception as e:
logger.error("从文件加载Embedding库时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("Embedding库加载完成")
# 初始化KG
kg_manager = KGManager()
@@ -34,6 +35,7 @@ try:
kg_manager.load_from_file()
except Exception as e:
logger.error("从文件加载KG时发生错误{}".format(e))
logger.error("如果你是第一次导入知识,或者还未导入知识,请忽略此错误")
logger.info("KG加载完成")
logger.info(f"KG节点数量{len(kg_manager.graph.get_node_list())}")

View File

@@ -12,6 +12,21 @@ from .llm_client import LLMClient
from .lpmmconfig import ENT_NAMESPACE, PG_NAMESPACE, REL_NAMESPACE, global_config
from .utils.hash import get_sha256
from .global_logger import logger
from rich.traceback import install
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
install(extra_lines=3)
TOTAL_EMBEDDING_TIMES = 3  # number of embedding passes, used in progress labels
@dataclass
@@ -49,20 +64,35 @@ class EmbeddingStore:
def _get_embedding(self, s: str) -> List[float]:
return self.llm_client.send_embedding_request(global_config["embedding"]["model"], s)
-   def batch_insert_strs(self, strs: List[str]) -> None:
+   def batch_insert_strs(self, strs: List[str], times: int) -> None:
        """Insert strings into the store"""
-       # Process the items one by one
-       for s in tqdm.tqdm(strs, desc="Storing embeddings", unit="items"):
-           # Hash each item for deduplication
-           item_hash = self.namespace + "-" + get_sha256(s)
-           if item_hash in self.store:
-               continue
-           # Get the embedding
-           embedding = self._get_embedding(s)
-           # Store it
-           self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+       total = len(strs)
+       with Progress(
+           SpinnerColumn(),
+           TextColumn("[progress.description]{task.description}"),
+           BarColumn(),
+           TaskProgressColumn(),
+           MofNCompleteColumn(),
+           "",
+           TimeElapsedColumn(),
+           "<",
+           TimeRemainingColumn(),
+           transient=False,
+       ) as progress:
+           task = progress.add_task(f"Storing embeddings ({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
+           for s in strs:
+               # Hash each item for deduplication
+               item_hash = self.namespace + "-" + get_sha256(s)
+               if item_hash in self.store:
+                   progress.update(task, advance=1)
+                   continue
+               # Get the embedding
+               embedding = self._get_embedding(s)
+               # Store it
+               self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+               progress.update(task, advance=1)
def save_to_file(self) -> None:
"""保存到文件"""
@@ -188,7 +218,7 @@ class EmbeddingManager:
def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
    """Encode the paragraphs into the embedding store"""
-   self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()))
+   self.paragraphs_embedding_store.batch_insert_strs(list(raw_paragraphs.values()), times=1)
def _store_ent_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将实体编码存入Embedding库"""
@@ -197,7 +227,7 @@ class EmbeddingManager:
for triple in triple_list:
entities.add(triple[0])
entities.add(triple[2])
-self.entities_embedding_store.batch_insert_strs(list(entities))
+self.entities_embedding_store.batch_insert_strs(list(entities), times=2)
def _store_rel_into_embedding(self, triple_list_data: Dict[str, List[List[str]]]):
"""将关系编码存入Embedding库"""
@@ -205,7 +235,7 @@ class EmbeddingManager:
for triples in triple_list_data.values():
graph_triples.extend([tuple(t) for t in triples])
graph_triples = list(set(graph_triples))
-self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples])
+self.relation_embedding_store.batch_insert_strs([str(triple) for triple in graph_triples], times=3)
def load_from_file(self):
"""从文件加载"""

View File

@@ -5,7 +5,16 @@ from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
import tqdm
from rich.progress import (
Progress,
BarColumn,
TimeElapsedColumn,
TimeRemainingColumn,
TaskProgressColumn,
MofNCompleteColumn,
SpinnerColumn,
TextColumn,
)
from quick_algo import di_graph, pagerank
@@ -132,41 +141,56 @@ class KGManager:
ent_hash_list = list(ent_hash_list)
synonym_hash_set = set()
synonym_result = dict()
-# For each entity node, look up similar entity nodes and add expansion links
-for ent_hash in tqdm.tqdm(ent_hash_list):
-    if ent_hash in synonym_hash_set:
-        # Avoid duplicate additions within the same batch
-        continue
-    ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
-    assert isinstance(ent, EmbeddingStoreItem)
-    if ent is None:
-        continue
-    # Look up similar entities
-    similar_ents = embedding_manager.entities_embedding_store.search_top_k(
-        ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
-    )
-    res_ent = []  # Debug
-    for res_ent_hash, similarity in similar_ents:
-        if res_ent_hash == ent_hash:
-            # Avoid self-links
-            continue
-        if similarity < global_config["rag"]["params"]["synonym_threshold"]:
-            # Similarity threshold
-            continue
-        node_to_node[(res_ent_hash, ent_hash)] = similarity
-        node_to_node[(ent_hash, res_ent_hash)] = similarity
-        synonym_hash_set.add(res_ent_hash)
-        new_edge_cnt += 1
-        res_ent.append(
-            (
-                embedding_manager.entities_embedding_store.store[res_ent_hash].str,
-                similarity,
-            )
-        )  # Debug
-    synonym_result[ent.str] = res_ent
+# rich progress bar
+total = len(ent_hash_list)
+with Progress(
+    SpinnerColumn(),
+    TextColumn("[progress.description]{task.description}"),
+    BarColumn(),
+    TaskProgressColumn(),
+    MofNCompleteColumn(),
+    "",
+    TimeElapsedColumn(),
+    "<",
+    TimeRemainingColumn(),
+    transient=False,
+) as progress:
+    task = progress.add_task("Linking synonyms", total=total)
+    for ent_hash in ent_hash_list:
+        if ent_hash in synonym_hash_set:
+            progress.update(task, advance=1)
+            continue
+        ent = embedding_manager.entities_embedding_store.store.get(ent_hash)
+        assert isinstance(ent, EmbeddingStoreItem)
+        if ent is None:
+            progress.update(task, advance=1)
+            continue
+        # Look up similar entities
+        similar_ents = embedding_manager.entities_embedding_store.search_top_k(
+            ent.embedding, global_config["rag"]["params"]["synonym_search_top_k"]
+        )
+        res_ent = []  # Debug
+        for res_ent_hash, similarity in similar_ents:
+            if res_ent_hash == ent_hash:
+                # Avoid self-links
+                continue
+            if similarity < global_config["rag"]["params"]["synonym_threshold"]:
+                # Similarity threshold
+                continue
+            node_to_node[(res_ent_hash, ent_hash)] = similarity
+            node_to_node[(ent_hash, res_ent_hash)] = similarity
+            synonym_hash_set.add(res_ent_hash)
+            new_edge_cnt += 1
+            res_ent.append(
+                (
+                    embedding_manager.entities_embedding_store.store[res_ent_hash].str,
+                    similarity,
+                )
+            )  # Debug
+        synonym_result[ent.str] = res_ent
+        progress.update(task, advance=1)
for k, v in synonym_result.items():
    print(f'Similar entities of "{k}": {v}')
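
Distilled, the new loop applies one rule per candidate: skip self-matches and anything below the similarity threshold, otherwise record a weighted edge in both directions and mark the hit so it is not reprocessed within the batch. A hedged sketch of that rule in isolation (the default threshold is illustrative; the real value comes from global_config):

def link_synonyms(ent_hash, similar_ents, node_to_node, synonym_hash_set, threshold=0.75):
    """Add bidirectional weighted edges for every sufficiently similar entity."""
    new_edges = 0
    for res_ent_hash, similarity in similar_ents:
        if res_ent_hash == ent_hash:        # no self-links
            continue
        if similarity < threshold:          # below the synonym threshold
            continue
        node_to_node[(res_ent_hash, ent_hash)] = similarity
        node_to_node[(ent_hash, res_ent_hash)] = similarity
        synonym_hash_set.add(res_ent_hash)  # skip this hit later in the batch
        new_edges += 1
    return new_edges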

View File

@@ -1,9 +1,13 @@
import json
import os
import glob
from typing import Any, Dict, List
from .lpmmconfig import INVALID_ENTITY, global_config
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
def _filter_invalid_entities(entities: List[str]) -> List[str]:
"""过滤无效的实体"""
@@ -74,12 +78,22 @@ class OpenIE:
doc["extracted_triples"] = _filter_invalid_triples(doc["extracted_triples"])
@staticmethod
-   def _from_dict(data):
-       """Build an OpenIE object from a dict"""
+   def _from_dict(data_list):
+       """Merge several dicts into a single OpenIE object"""
+       # data_list: List[dict]
+       all_docs = []
+       for data in data_list:
+           all_docs.extend(data.get("docs", []))
+       # Recompute the statistics
+       sum_phrase_chars = sum([len(e) for chunk in all_docs for e in chunk["extracted_entities"]])
+       sum_phrase_words = sum([len(e.split()) for chunk in all_docs for e in chunk["extracted_entities"]])
+       num_phrases = sum([len(chunk["extracted_entities"]) for chunk in all_docs])
+       avg_ent_chars = round(sum_phrase_chars / num_phrases, 4) if num_phrases else 0
+       avg_ent_words = round(sum_phrase_words / num_phrases, 4) if num_phrases else 0
        return OpenIE(
-           docs=data["docs"],
-           avg_ent_chars=data["avg_ent_chars"],
-           avg_ent_words=data["avg_ent_words"],
+           docs=all_docs,
+           avg_ent_chars=avg_ent_chars,
+           avg_ent_words=avg_ent_words,
        )
def _to_dict(self):
@@ -92,12 +106,20 @@ class OpenIE:
@staticmethod
def load() -> "OpenIE":
"""文件中加载OpenIE数据"""
with open(global_config["persistence"]["openie_data_path"], "r", encoding="utf-8") as f:
data = json.loads(f.read())
openie_data = OpenIE._from_dict(data)
"""OPENIE_DIR下所有json文件合并加载OpenIE数据"""
openie_dir = os.path.join(ROOT_PATH, global_config["persistence"]["openie_data_path"])
if not os.path.exists(openie_dir):
raise Exception(f"OpenIE数据目录不存在: {openie_dir}")
json_files = sorted(glob.glob(os.path.join(openie_dir, "*.json")))
data_list = []
for file in json_files:
with open(file, "r", encoding="utf-8") as f:
data = json.load(f)
data_list.append(data)
if not data_list:
# print(f"111111111111111111111Root Path : \n{ROOT_PATH}")
raise Exception(f"未在 {openie_dir} 找到任何OpenIE json文件")
openie_data = OpenIE._from_dict(data_list)
return openie_data
@staticmethod
@@ -132,3 +154,8 @@ class OpenIE:
"""提取原始段落"""
raw_paragraph_dict = dict({doc_item["idx"]: doc_item["passage"] for doc_item in self.docs})
return raw_paragraph_dict
if __name__ == "__main__":
# Test code
print(ROOT_PATH)
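
The merged loader only relies on a few fields of each input file; the per-file averages are discarded and recomputed over the merged docs. An illustrative shape of one such file, inferred from the fields read above (all values made up):

example_openie_file = {
    "docs": [
        {
            "idx": "doc-0001",
            "passage": "The capital of France is Paris.",
            "extracted_entities": ["France", "Paris"],
            "extracted_triples": [["France", "capital of", "Paris"]],
        }
    ],
    # Per-file stats are ignored on load and recomputed over the merged docs
    "avg_ent_chars": 5.5,
    "avg_ent_words": 1.0,
}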

View File

@@ -6,21 +6,25 @@ from .lpmmconfig import global_config
from .utils.hash import get_sha256
-def load_raw_data() -> tuple[list[str], list[str]]:
+def load_raw_data(path: str = None) -> tuple[list[str], list[str]]:
    """Load the raw data file

    Reads the raw data file and loads the raw data into memory

+   Args:
+       path: optional absolute path of the json file to read
+
    Returns:
-       - raw_data: raw data dict
-       - md5_set: set of SHA256 hashes of the raw data
+       - raw_data: raw data list
+       - sha256_list: SHA256 hashes of the raw data
    """
-   # Read the import.json file
-   if os.path.exists(global_config["persistence"]["raw_data_path"]) is True:
-       with open(global_config["persistence"]["raw_data_path"], "r", encoding="utf-8") as f:
+   # Read the json file at the given path, or at the default path
+   json_path = path if path else global_config["persistence"]["raw_data_path"]
+   if os.path.exists(json_path):
+       with open(json_path, "r", encoding="utf-8") as f:
            import_json = json.loads(f.read())
    else:
-       raise Exception("Failed to read the raw data file")
+       raise Exception(f"Failed to read the raw data file: {json_path}")
    # Example of import_json content:
    # import_json = [
    #     "The capital of China is Beijing. The capital of France is Paris.",

View File

@@ -20,6 +20,9 @@ from ..utils.chat_message_builder import (
)  # import build_readable_messages
from ..chat.utils import translate_timestamp_to_human_readable
from .memory_config import MemoryConfig
from rich.traceback import install
install(extra_lines=3)
def calculate_information_content(text):

View File

@@ -8,6 +8,9 @@ import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))))
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.config.config import global_config
from rich.traceback import install
install(extra_lines=3)
async def test_memory_system():

View File

@@ -9,6 +9,9 @@ from Hippocampus import Hippocampus # 海马体和记忆图
from dotenv import load_dotenv
from rich.traceback import install
install(extra_lines=3)
"""

View File

@@ -6,6 +6,9 @@ from typing import Tuple, Union
import aiohttp
import requests
from src.common.logger import get_module_logger
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("offline_llm")

View File

@@ -1,6 +1,9 @@
import numpy as np
from scipy import stats
from datetime import datetime, timedelta
from rich.traceback import install
install(extra_lines=3)
class DistributionVisualizer:

View File

@@ -14,6 +14,9 @@ import io
import os
from ...common.database import db
from ...config.config import global_config
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("model_utils")

View File

@@ -3,7 +3,11 @@ import re
from contextlib import asynccontextmanager
import asyncio
from src.common.logger import get_module_logger
# import traceback
from rich.traceback import install
install(extra_lines=3)
logger = get_module_logger("prompt_build")

View File

@@ -2,6 +2,9 @@ from time import perf_counter
from functools import wraps
from typing import Optional, Dict, Callable
import asyncio
from rich.traceback import install
install(extra_lines=3)
"""
# A better timer

View File

@@ -8,6 +8,9 @@ from abc import ABC, abstractmethod
import importlib
from typing import Dict, Optional
import asyncio
from rich.traceback import install
install(extra_lines=3)
"""
Overview of the base class methods:

View File

@@ -7,6 +7,9 @@ from datetime import datetime
from tqdm import tqdm
from rich.console import Console
from rich.table import Table
from rich.traceback import install
install(extra_lines=3)
# Add the project root to the Python path
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -15,6 +18,7 @@ sys.path.append(root_path)
# Now the src modules can be imported
from src.common.database import db # noqa E402
# Load the .env file in the root directory
env_path = os.path.join(root_path, ".env")
if not os.path.exists(env_path):
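
The hunk is cut off at this check; since load_dotenv is imported elsewhere in this commit, a plausible continuation looks like the following sketch (the error handling is an assumption, not the repo's code):

import os

from dotenv import load_dotenv

if not os.path.exists(env_path):
    raise FileNotFoundError(f"Missing env file: {env_path}")  # assumed handling
load_dotenv(env_path)  # load environment variables from the root .env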