Merge branch 'dev' of github.com:MaiM-with-u/MaiBot into dev
@@ -25,7 +25,7 @@ from rich.progress import (
     TextColumn,
 )
 from raw_data_preprocessor import RAW_DATA_PATH, load_raw_data
-from src.config.config import global_config
+from src.config.config import global_config, model_config
 from src.llm_models.utils_model import LLMRequest
 from dotenv import load_dotenv
 
@@ -96,11 +96,11 @@ open_ie_doc_lock = Lock()
 shutdown_event = Event()
 
 lpmm_entity_extract_llm = LLMRequest(
-    model=global_config.model.lpmm_entity_extract,
+    model_set=model_config.model_task_config.lpmm_entity_extract,
     request_type="lpmm.entity_extract"
 )
 lpmm_rdf_build_llm = LLMRequest(
-    model=global_config.model.lpmm_rdf_build,
+    model_set=model_config.model_task_config.lpmm_rdf_build,
     request_type="lpmm.rdf_build"
 )
 def process_single_text(pg_hash, raw_data):
@@ -3,6 +3,7 @@ import json
 import os
 import math
 import asyncio
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from typing import Dict, List, Tuple
 
 import numpy as np
@@ -26,12 +27,20 @@ from rich.progress import (
     SpinnerColumn,
     TextColumn,
 )
-from src.manager.local_store_manager import local_storage
 from src.chat.utils.utils import get_embedding
 from src.config.config import global_config
 
 
 install(extra_lines=3)
 
+# 多线程embedding配置常量
+DEFAULT_MAX_WORKERS = 10  # 默认最大线程数
+DEFAULT_CHUNK_SIZE = 10  # 默认每个线程处理的数据块大小
+MIN_CHUNK_SIZE = 1  # 最小分块大小
+MAX_CHUNK_SIZE = 50  # 最大分块大小
+MIN_WORKERS = 1  # 最小线程数
+MAX_WORKERS = 20  # 最大线程数
+
 ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
 EMBEDDING_DATA_DIR = os.path.join(ROOT_PATH, "data", "embedding")
 EMBEDDING_DATA_DIR_STR = str(EMBEDDING_DATA_DIR).replace("\\", "/")
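
These constants bound the threading configuration introduced in this commit. As a rough, standalone illustration of how they interact (the batch size below is assumed, not from the repository): with the defaults, 100 strings are split into ceil(100 / 10) = 10 chunks, and at most 10 worker threads are busy at once.

import math

# Standalone illustration of the default chunking arithmetic (hypothetical batch size).
DEFAULT_MAX_WORKERS = 10
DEFAULT_CHUNK_SIZE = 10

n_strings = 100                                          # hypothetical batch
n_chunks = math.ceil(n_strings / DEFAULT_CHUNK_SIZE)     # -> 10 chunks
concurrent = min(DEFAULT_MAX_WORKERS, n_chunks)          # -> 10 threads busy at once
print(n_chunks, concurrent)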
@@ -87,13 +96,23 @@ class EmbeddingStoreItem:
 
 
 class EmbeddingStore:
-    def __init__(self, namespace: str, dir_path: str):
+    def __init__(self, namespace: str, dir_path: str, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
         self.namespace = namespace
         self.dir = dir_path
         self.embedding_file_path = f"{dir_path}/{namespace}.parquet"
         self.index_file_path = f"{dir_path}/{namespace}.index"
         self.idx2hash_file_path = dir_path + "/" + namespace + "_i2h.json"
 
+        # 多线程配置参数验证和设置
+        self.max_workers = max(MIN_WORKERS, min(MAX_WORKERS, max_workers))
+        self.chunk_size = max(MIN_CHUNK_SIZE, min(MAX_CHUNK_SIZE, chunk_size))
+
+        # 如果配置值被调整,记录日志
+        if self.max_workers != max_workers:
+            logger.warning(f"max_workers 已从 {max_workers} 调整为 {self.max_workers} (范围: {MIN_WORKERS}-{MAX_WORKERS})")
+        if self.chunk_size != chunk_size:
+            logger.warning(f"chunk_size 已从 {chunk_size} 调整为 {self.chunk_size} (范围: {MIN_CHUNK_SIZE}-{MAX_CHUNK_SIZE})")
+
         self.store = {}
 
         self.faiss_index = None
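
The new constructor clamps both knobs into their allowed ranges with the usual max(lo, min(hi, x)) idiom and warns when a value had to be adjusted. A minimal sketch of that clamping behaviour in isolation (input value assumed):

# Standalone sketch of the clamp-and-warn pattern used in __init__.
MIN_WORKERS, MAX_WORKERS = 1, 20

def clamp(value: int, lo: int, hi: int) -> int:
    # Pull value back into [lo, hi].
    return max(lo, min(hi, value))

requested_workers = 64  # deliberately out of range
effective_workers = clamp(requested_workers, MIN_WORKERS, MAX_WORKERS)
if effective_workers != requested_workers:
    print(f"max_workers adjusted from {requested_workers} to {effective_workers}")
# -> max_workers adjusted from 64 to 20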
@@ -125,17 +144,135 @@ class EmbeddingStore:
             return []
         return result
 
+    def _get_embeddings_batch_threaded(self, strs: List[str], chunk_size: int = 10, max_workers: int = 10, progress_callback=None) -> List[Tuple[str, List[float]]]:
+        """使用多线程批量获取嵌入向量
+
+        Args:
+            strs: 要获取嵌入的字符串列表
+            chunk_size: 每个线程处理的数据块大小
+            max_workers: 最大线程数
+            progress_callback: 进度回调函数,接收一个参数表示完成的数量
+
+        Returns:
+            包含(原始字符串, 嵌入向量)的元组列表,保持与输入顺序一致
+        """
+        if not strs:
+            return []
+
+        # 分块
+        chunks = []
+        for i in range(0, len(strs), chunk_size):
+            chunk = strs[i:i + chunk_size]
+            chunks.append((i, chunk))  # 保存起始索引以维持顺序
+
+        # 结果存储,使用字典按索引存储以保证顺序
+        results = {}
+
+        def process_chunk(chunk_data):
+            """处理单个数据块的函数"""
+            start_idx, chunk_strs = chunk_data
+            chunk_results = []
+
+            # 为每个线程创建独立的LLMRequest实例
+            from src.llm_models.utils_model import LLMRequest
+            from src.config.config import model_config
+
+            try:
+                # 创建线程专用的LLM实例
+                llm = LLMRequest(model_set=model_config.model_task_config.embedding, request_type="embedding")
+
+                for i, s in enumerate(chunk_strs):
+                    try:
+                        # 直接使用异步函数
+                        embedding = asyncio.run(llm.get_embedding(s))
+                        if embedding and len(embedding) > 0:
+                            chunk_results.append((start_idx + i, s, embedding[0]))  # embedding[0] 是实际的向量
+                        else:
+                            logger.error(f"获取嵌入失败: {s}")
+                            chunk_results.append((start_idx + i, s, []))
+
+                        # 每完成一个嵌入立即更新进度
+                        if progress_callback:
+                            progress_callback(1)
+
+                    except Exception as e:
+                        logger.error(f"获取嵌入时发生异常: {s}, 错误: {e}")
+                        chunk_results.append((start_idx + i, s, []))
+
+                        # 即使失败也要更新进度
+                        if progress_callback:
+                            progress_callback(1)
+
+            except Exception as e:
+                logger.error(f"创建LLM实例失败: {e}")
+                # 如果创建LLM实例失败,返回空结果
+                for i, s in enumerate(chunk_strs):
+                    chunk_results.append((start_idx + i, s, []))
+                    # 即使失败也要更新进度
+                    if progress_callback:
+                        progress_callback(1)
+
+            return chunk_results
+
+        # 使用线程池处理
+        with ThreadPoolExecutor(max_workers=max_workers) as executor:
+            # 提交所有任务
+            future_to_chunk = {executor.submit(process_chunk, chunk): chunk for chunk in chunks}
+
+            # 收集结果(进度已在process_chunk中实时更新)
+            for future in as_completed(future_to_chunk):
+                try:
+                    chunk_results = future.result()
+                    for idx, s, embedding in chunk_results:
+                        results[idx] = (s, embedding)
+                except Exception as e:
+                    chunk = future_to_chunk[future]
+                    logger.error(f"处理数据块时发生异常: {chunk}, 错误: {e}")
+                    # 为失败的块添加空结果
+                    start_idx, chunk_strs = chunk
+                    for i, s in enumerate(chunk_strs):
+                        results[start_idx + i] = (s, [])
+
+        # 按原始顺序返回结果
+        ordered_results = []
+        for i in range(len(strs)):
+            if i in results:
+                ordered_results.append(results[i])
+            else:
+                # 防止遗漏
+                ordered_results.append((strs[i], []))
+
+        return ordered_results
+
     def get_test_file_path(self):
         return EMBEDDING_TEST_FILE
 
     def save_embedding_test_vectors(self):
-        """保存测试字符串的嵌入到本地"""
+        """保存测试字符串的嵌入到本地(使用多线程优化)"""
+        logger.info("开始保存测试字符串的嵌入向量...")
+
+        # 使用多线程批量获取测试字符串的嵌入
+        embedding_results = self._get_embeddings_batch_threaded(
+            EMBEDDING_TEST_STRINGS,
+            chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
+            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+        )
+
+        # 构建测试向量字典
         test_vectors = {}
-        for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
-            test_vectors[str(idx)] = self._get_embedding(s)
+        for idx, (s, embedding) in enumerate(embedding_results):
+            if embedding:
+                test_vectors[str(idx)] = embedding
+            else:
+                logger.error(f"获取测试字符串嵌入失败: {s}")
+                # 使用原始单线程方法作为后备
+                test_vectors[str(idx)] = self._get_embedding(s)
+
         with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
             json.dump(test_vectors, f, ensure_ascii=False, indent=2)
 
+        logger.info("测试字符串嵌入向量保存完成")
+
     def load_embedding_test_vectors(self):
         """加载本地保存的测试字符串嵌入"""
         path = self.get_test_file_path()
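
The method added above follows a common fan-out pattern: split the input into (start_index, chunk) pairs, hand the chunks to a ThreadPoolExecutor, collect per-item results into a dict keyed by global index, and rebuild a list in the original order at the end. Below is a self-contained sketch of that same pattern with a stub in place of the real embedding call; the stub function and its return shape are assumptions for illustration only.

from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Tuple

def fake_embed(text: str) -> List[float]:
    # Stand-in for the real embedding request; returns a dummy vector.
    return [float(len(text))]

def embed_batch_threaded(strs: List[str], chunk_size: int = 10, max_workers: int = 10) -> List[Tuple[str, List[float]]]:
    # Split into (start_index, chunk) pairs so the global order can be restored later.
    chunks = [(i, strs[i:i + chunk_size]) for i in range(0, len(strs), chunk_size)]
    results = {}

    def process_chunk(chunk):
        start_idx, chunk_strs = chunk
        return [(start_idx + i, s, fake_embed(s)) for i, s in enumerate(chunk_strs)]

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        futures = {executor.submit(process_chunk, c): c for c in chunks}
        for future in as_completed(futures):
            for idx, s, emb in future.result():
                results[idx] = (s, emb)

    # Rebuild in the original input order, filling any gap with an empty vector.
    return [results.get(i, (strs[i], [])) for i in range(len(strs))]

print(embed_batch_threaded(["a", "bb", "ccc"], chunk_size=2, max_workers=2))

One design consequence worth noting: the committed version constructs a new LLMRequest and calls asyncio.run() per string inside each worker thread, so every embedding request runs in its own short-lived event loop.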
@@ -145,29 +282,64 @@ class EmbeddingStore:
             return json.load(f)
 
     def check_embedding_model_consistency(self):
-        """校验当前模型与本地嵌入模型是否一致"""
+        """校验当前模型与本地嵌入模型是否一致(使用多线程优化)"""
         local_vectors = self.load_embedding_test_vectors()
         if local_vectors is None:
             logger.warning("未检测到本地嵌入模型测试文件,将保存当前模型的测试嵌入。")
             self.save_embedding_test_vectors()
             return True
-        for idx, s in enumerate(EMBEDDING_TEST_STRINGS):
-            local_emb = local_vectors.get(str(idx))
-            if local_emb is None:
+
+        # 检查本地向量完整性
+        for idx in range(len(EMBEDDING_TEST_STRINGS)):
+            if local_vectors.get(str(idx)) is None:
                 logger.warning("本地嵌入模型测试文件缺失部分测试字符串,将重新保存。")
                 self.save_embedding_test_vectors()
                 return True
-            new_emb = self._get_embedding(s)
+
+        logger.info("开始检验嵌入模型一致性...")
+
+        # 使用多线程批量获取当前模型的嵌入
+        embedding_results = self._get_embeddings_batch_threaded(
+            EMBEDDING_TEST_STRINGS,
+            chunk_size=min(self.chunk_size, len(EMBEDDING_TEST_STRINGS)),
+            max_workers=min(self.max_workers, len(EMBEDDING_TEST_STRINGS))
+        )
+
+        # 检查一致性
+        for idx, (s, new_emb) in enumerate(embedding_results):
+            local_emb = local_vectors.get(str(idx))
+            if not new_emb:
+                logger.error(f"获取测试字符串嵌入失败: {s}")
+                return False
+
             sim = cosine_similarity(local_emb, new_emb)
             if sim < EMBEDDING_SIM_THRESHOLD:
-                logger.error("嵌入模型一致性校验失败")
+                logger.error(f"嵌入模型一致性校验失败,字符串: {s}, 相似度: {sim:.4f}")
                 return False
 
         logger.info("嵌入模型一致性校验通过。")
         return True
 
     def batch_insert_strs(self, strs: List[str], times: int) -> None:
-        """向库中存入字符串"""
+        """向库中存入字符串(使用多线程优化)"""
+        if not strs:
+            return
+
         total = len(strs)
+
+        # 过滤已存在的字符串
+        new_strs = []
+        for s in strs:
+            item_hash = self.namespace + "-" + get_sha256(s)
+            if item_hash not in self.store:
+                new_strs.append(s)
+
+        if not new_strs:
+            logger.info(f"所有字符串已存在于{self.namespace}嵌入库中,跳过处理")
+            return
+
+        logger.info(f"需要处理 {len(new_strs)}/{total} 个新字符串")
+
         with Progress(
             SpinnerColumn(),
             TextColumn("[progress.description]{task.description}"),
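
The consistency check compares freshly computed test embeddings against locally cached ones via cosine similarity and a threshold. Assuming cosine_similarity is the standard dot-product formulation (and with a threshold value invented purely for illustration), the comparison reduces to the following sketch:

import numpy as np

EMBEDDING_SIM_THRESHOLD = 0.95  # assumed value for illustration

def cosine_similarity(a, b) -> float:
    # Standard cosine similarity between two vectors.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))

local_emb = [0.1, 0.2, 0.3]    # cached test vector (hypothetical)
new_emb = [0.1, 0.21, 0.29]    # freshly computed vector (hypothetical)

sim = cosine_similarity(local_emb, new_emb)
print(f"sim={sim:.4f}, consistent={sim >= EMBEDDING_SIM_THRESHOLD}")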
@@ -181,19 +353,38 @@ class EmbeddingStore:
             transient=False,
         ) as progress:
             task = progress.add_task(f"存入嵌入库:({times}/{TOTAL_EMBEDDING_TIMES})", total=total)
-            for s in strs:
-                # 计算hash去重
-                item_hash = self.namespace + "-" + get_sha256(s)
-                if item_hash in self.store:
-                    progress.update(task, advance=1)
-                    continue
 
-                # 获取embedding
-                embedding = self._get_embedding(s)
+            # 首先更新已存在项的进度
+            already_processed = total - len(new_strs)
+            if already_processed > 0:
+                progress.update(task, advance=already_processed)
 
-                # 存入
-                self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
-                progress.update(task, advance=1)
+            if new_strs:
+                # 使用实例配置的参数,智能调整分块和线程数
+                optimal_chunk_size = max(MIN_CHUNK_SIZE, min(self.chunk_size, len(new_strs) // self.max_workers if self.max_workers > 0 else self.chunk_size))
+                optimal_max_workers = min(self.max_workers, max(MIN_WORKERS, len(new_strs) // optimal_chunk_size if optimal_chunk_size > 0 else 1))
+
+                logger.debug(f"使用多线程处理: chunk_size={optimal_chunk_size}, max_workers={optimal_max_workers}")
+
+                # 定义进度更新回调函数
+                def update_progress(count):
+                    progress.update(task, advance=count)
+
+                # 批量获取嵌入,并实时更新进度
+                embedding_results = self._get_embeddings_batch_threaded(
+                    new_strs,
+                    chunk_size=optimal_chunk_size,
+                    max_workers=optimal_max_workers,
+                    progress_callback=update_progress
+                )
+
+                # 存入结果(不再需要在这里更新进度,因为已经在回调中更新了)
+                for s, embedding in embedding_results:
+                    item_hash = self.namespace + "-" + get_sha256(s)
+                    if embedding:  # 只有成功获取到嵌入才存入
+                        self.store[item_hash] = EmbeddingStoreItem(item_hash, embedding, s)
+                    else:
+                        logger.warning(f"跳过存储失败的嵌入: {s[:50]}...")
 
     def save_to_file(self) -> None:
         """保存到文件"""
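
The optimal_chunk_size / optimal_max_workers expressions adapt the instance settings to the batch size. As a worked example with assumed inputs: 95 new strings with chunk_size=10 and max_workers=10 yield a chunk size of max(1, min(10, 95 // 10)) = 9 and min(10, max(1, 95 // 9)) = 10 workers, while a tiny batch of 3 strings collapses to chunk size 1 and 3 workers. A standalone check of that arithmetic:

# Standalone check of the chunk/worker sizing arithmetic (inputs assumed).
MIN_CHUNK_SIZE, MIN_WORKERS = 1, 1

def plan(n_new: int, chunk_size: int, max_workers: int):
    optimal_chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, n_new // max_workers if max_workers > 0 else chunk_size))
    optimal_max_workers = min(max_workers, max(MIN_WORKERS, n_new // optimal_chunk_size if optimal_chunk_size > 0 else 1))
    return optimal_chunk_size, optimal_max_workers

print(plan(95, 10, 10))  # -> (9, 10)
print(plan(3, 10, 10))   # -> (1, 3)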
@@ -316,31 +507,37 @@ class EmbeddingStore:
 
 
 class EmbeddingManager:
-    def __init__(self):
+    def __init__(self, max_workers: int = DEFAULT_MAX_WORKERS, chunk_size: int = DEFAULT_CHUNK_SIZE):
+        """
+        初始化EmbeddingManager
+
+        Args:
+            max_workers: 最大线程数
+            chunk_size: 每个线程处理的数据块大小
+        """
         self.paragraphs_embedding_store = EmbeddingStore(
-            local_storage["pg_namespace"],  # type: ignore
+            "paragraph",  # type: ignore
             EMBEDDING_DATA_DIR_STR,
+            max_workers=max_workers,
+            chunk_size=chunk_size,
         )
         self.entities_embedding_store = EmbeddingStore(
-            local_storage["pg_namespace"],  # type: ignore
+            "entity",  # type: ignore
             EMBEDDING_DATA_DIR_STR,
+            max_workers=max_workers,
+            chunk_size=chunk_size,
        )
         self.relation_embedding_store = EmbeddingStore(
-            local_storage["pg_namespace"],  # type: ignore
+            "relation",  # type: ignore
             EMBEDDING_DATA_DIR_STR,
+            max_workers=max_workers,
+            chunk_size=chunk_size,
         )
         self.stored_pg_hashes = set()
 
     def check_all_embedding_model_consistency(self):
         """对所有嵌入库做模型一致性校验"""
-        for store in [
-            self.paragraphs_embedding_store,
-            self.entities_embedding_store,
-            self.relation_embedding_store,
-        ]:
-            if not store.check_embedding_model_consistency():
-                return False
-        return True
+        return self.paragraphs_embedding_store.check_embedding_model_consistency()
 
     def _store_pg_into_embedding(self, raw_paragraphs: Dict[str, str]):
         """将段落编码存入Embedding库"""
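
With the new constructor signature, callers can tune the threading behaviour when building the manager, and the store namespaces are now the fixed literals "paragraph", "entity" and "relation" rather than values pulled from local_storage. A minimal usage sketch relying only on the keyword arguments visible in this commit; the import path is an assumption:

# Module path assumed for illustration; adjust to wherever EmbeddingManager actually lives.
from src.chat.knowledge.embedding_store import EmbeddingManager

embed_manager = EmbeddingManager(max_workers=8, chunk_size=20)  # each store clamps these into range
ok = embed_manager.check_all_embedding_model_consistency()
print("embedding model consistent:", ok)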
@@ -6,7 +6,6 @@ from src.chat.knowledge.qa_manager import QAManager
 from src.chat.knowledge.kg_manager import KGManager
 from src.chat.knowledge.global_logger import logger
 from src.config.config import global_config as bot_global_config
-from src.manager.local_store_manager import local_storage
 import os
 
 INVALID_ENTITY = [
@@ -21,9 +20,6 @@ INVALID_ENTITY = [
     "她们",
     "它们",
 ]
-PG_NAMESPACE = "paragraph"
-ENT_NAMESPACE = "entity"
-REL_NAMESPACE = "relation"
 
 RAG_GRAPH_NAMESPACE = "rag-graph"
 RAG_ENT_CNT_NAMESPACE = "rag-ent-cnt"
@@ -34,54 +30,6 @@ ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..",
 DATA_PATH = os.path.join(ROOT_PATH, "data")
 
 
-def _initialize_knowledge_local_storage():
-    """
-    初始化知识库相关的本地存储配置
-    使用字典批量设置,避免重复的if判断
-    """
-    # 定义所有需要初始化的配置项
-    default_configs = {
-        # 路径配置
-        "root_path": ROOT_PATH,
-        "data_path": f"{ROOT_PATH}/data",
-        # 实体和命名空间配置
-        "lpmm_invalid_entity": INVALID_ENTITY,
-        "pg_namespace": PG_NAMESPACE,
-        "ent_namespace": ENT_NAMESPACE,
-        "rel_namespace": REL_NAMESPACE,
-        # RAG相关命名空间配置
-        "rag_graph_namespace": RAG_GRAPH_NAMESPACE,
-        "rag_ent_cnt_namespace": RAG_ENT_CNT_NAMESPACE,
-        "rag_pg_hash_namespace": RAG_PG_HASH_NAMESPACE,
-    }
-
-    # 日志级别映射:重要配置用info,其他用debug
-    important_configs = {"root_path", "data_path"}
-
-    # 批量设置配置项
-    initialized_count = 0
-    for key, default_value in default_configs.items():
-        if local_storage[key] is None:
-            local_storage[key] = default_value
-
-            # 根据重要性选择日志级别
-            if key in important_configs:
-                logger.info(f"设置{key}: {default_value}")
-            else:
-                logger.debug(f"设置{key}: {default_value}")
-
-            initialized_count += 1
-
-    if initialized_count > 0:
-        logger.info(f"知识库本地存储初始化完成,共设置 {initialized_count} 项配置")
-    else:
-        logger.debug("知识库本地存储配置已存在,跳过初始化")
-
-
-# 初始化本地存储路径
-# sourcery skip: dict-comprehension
-_initialize_knowledge_local_storage()
-
 qa_manager = None
 inspire_manager = None
 
@@ -120,7 +68,7 @@ if bot_global_config.lpmm_knowledge.enable:
 
     # 数据比对:Embedding库与KG的段落hash集合
     for pg_hash in kg_manager.stored_paragraph_hashes:
-        key = f"{PG_NAMESPACE}-{pg_hash}"
+        key = f"paragraph-{pg_hash}"
         if key not in embed_manager.stored_pg_hashes:
             logger.warning(f"KG中存在Embedding库中不存在的段落:{key}")
 
@@ -15,7 +15,7 @@ async def get_voice_text(voice_base64: str) -> str:
         logger.warning("语音识别未启用,无法处理语音消息")
         return "[语音]"
     try:
-        _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="voice")
+        _llm = LLMRequest(model_set=model_config.model_task_config.voice, request_type="audio")
         text = await _llm.generate_response_for_voice(voice_base64)
         if text is None:
             logger.warning("未能生成语音文本")
@@ -277,7 +277,7 @@ class LLMRequest:
                 extra_params=model_info.extra_params,
             )
         elif request_type == RequestType.AUDIO:
-            assert message_list is not None, "message_list cannot be None for audio requests"
+            assert audio_base64 is not None, "audio_base64 cannot be None for audio requests"
             return await client.get_audio_transcriptions(
                 model_info=model_info,
                 audio_base64=audio_base64,
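
The audio branch now asserts on audio_base64 instead of message_list, which matches how the voice path calls it: the caller hands over base64-encoded audio. A hedged sketch of preparing that payload with the standard library (the file name is hypothetical, and the downstream transcription call is project-specific):

import base64

# Hypothetical input file; only the base64 preparation is shown here.
with open("voice_message.wav", "rb") as f:
    voice_base64 = base64.b64encode(f.read()).decode("ascii")

# voice_base64 is the value consumed by get_voice_text(...) / the RequestType.AUDIO branch.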