Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev
This commit is contained in:
@@ -8,6 +8,7 @@ from typing import Any, Dict, Optional, Union
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config, model_config
|
||||
from src.common.config_helpers import resolve_embedding_dimension
|
||||
from src.common.database.sqlalchemy_models import CacheEntries
|
||||
from src.common.database.sqlalchemy_database_api import db_query, db_save
|
||||
from src.common.vector_db import vector_db_service
|
||||
@@ -40,7 +41,11 @@ class CacheManager:
|
||||
|
||||
# L1 缓存 (内存)
|
||||
self.l1_kv_cache: Dict[str, Dict[str, Any]] = {}
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
embedding_dim = resolve_embedding_dimension(global_config.lpmm_knowledge.embedding_dimension)
|
||||
if not embedding_dim:
|
||||
embedding_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
self.embedding_dimension = embedding_dim
|
||||
self.l1_vector_index = faiss.IndexFlatIP(embedding_dim)
|
||||
self.l1_vector_id_to_key: Dict[int, str] = {}
|
||||
|
||||
@@ -72,7 +77,7 @@ class CacheManager:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
# 检查维度是否符合预期
|
||||
expected_dim = global_config.lpmm_knowledge.embedding_dimension
|
||||
expected_dim = getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension
|
||||
if embedding_array.shape[0] != expected_dim:
|
||||
logger.warning(f"嵌入向量维度不匹配: 期望 {expected_dim}, 实际 {embedding_array.shape[0]}")
|
||||
return None
|
||||
|
||||
42
src/common/config_helpers.py
Normal file
42
src/common/config_helpers.py
Normal file
@@ -0,0 +1,42 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from src.config.config import global_config, model_config
|
||||
|
||||
|
||||
def resolve_embedding_dimension(fallback: Optional[int] = None, *, sync_global: bool = True) -> Optional[int]:
    """Resolve the embedding vector dimension from configuration.

    Priority order:
    1. ``model_task_config.embedding.embedding_dimension`` from the model config
    2. ``lpmm_knowledge.embedding_dimension`` from the bot config
    3. The caller-supplied ``fallback``

    Args:
        fallback: Dimension to use when neither config source yields a
            positive integer.
        sync_global: When ``True``, write the resolved value back into
            ``global_config.lpmm_knowledge.embedding_dimension`` so later
            readers observe a consistent dimension.

    Returns:
        The first candidate that converts to a positive ``int``, or ``None``
        when every candidate is missing or invalid.
    """
    candidates: list[Optional[int]] = []

    # Model config is the authoritative source; guard the attribute chain with
    # try/except because configs may be only partially initialized at startup.
    try:
        embedding_task = getattr(model_config.model_task_config, "embedding", None)
        if embedding_task is not None:
            candidates.append(getattr(embedding_task, "embedding_dimension", None))
    except Exception:
        candidates.append(None)

    try:
        candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
    except Exception:
        candidates.append(None)

    candidates.append(fallback)

    # Pick the first candidate that converts to a positive int. Convert once
    # per candidate and skip values that are not numeric (e.g. a stray string
    # in a hand-edited config) instead of raising mid-resolution.
    resolved: Optional[int] = None
    for dim in candidates:
        if not dim:
            continue
        try:
            value = int(dim)
        except (TypeError, ValueError):
            continue
        if value > 0:
            resolved = value
            break

    if resolved and sync_global:
        # Best-effort write-back: config objects may be frozen or validated,
        # so failures are swallowed rather than propagated to callers.
        try:
            if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
                global_config.lpmm_knowledge.embedding_dimension = resolved  # type: ignore[attr-defined]
        except Exception:
            pass

    return resolved
|
||||
Reference in New Issue
Block a user