diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index 2cb9fbdfb..3eb466d21 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -26,7 +26,7 @@ from rich.progress import ( TextColumn, ) from src.manager.local_store_manager import local_storage -from src.chat.utils.utils import get_embedding_sync +from src.chat.utils.utils import get_embedding from src.config.config import global_config @@ -99,7 +99,7 @@ class EmbeddingStore: self.idx2hash = None def _get_embedding(self, s: str) -> List[float]: - return get_embedding_sync(s) + return get_embedding(s) def get_test_file_path(self): return EMBEDDING_TEST_FILE diff --git a/src/chat/knowledge/qa_manager.py b/src/chat/knowledge/qa_manager.py index b4a0dc1fc..c83683b79 100644 --- a/src/chat/knowledge/qa_manager.py +++ b/src/chat/knowledge/qa_manager.py @@ -10,7 +10,7 @@ from .kg_manager import KGManager # from .lpmmconfig import global_config from .utils.dyn_topk import dyn_select_top_k from src.llm_models.utils_model import LLMRequest -from src.chat.utils.utils import get_embedding_sync +from src.chat.utils.utils import get_embedding from src.config.config import global_config MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度 @@ -36,7 +36,7 @@ class QAManager: # 生成问题的Embedding part_start_time = time.perf_counter() - question_embedding = await get_embedding_sync(question) + question_embedding = await get_embedding(question) if question_embedding is None: logger.error("生成问题Embedding失败") return None diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 045e9e911..a329b3548 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -122,18 +122,6 @@ async def get_embedding(text, request_type="embedding"): return embedding -def get_embedding_sync(text, request_type="embedding"): - """获取文本的embedding向量(同步版本)""" - # TODO: API-Adapter修改标记 - llm = LLMRequest(model=global_config.model.embedding, request_type=request_type) - try: - embedding = llm.get_embedding_sync(text) - except Exception as e: - logger.error(f"获取embedding失败: {str(e)}") - embedding = None - return embedding - - def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list: # 获取当前群聊记录内发言的人 filter_query = {"chat_id": chat_stream_id} diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index e2e37fdbd..1077cfa09 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -827,29 +827,6 @@ class LLMRequest: ) return embedding - def get_embedding_sync(self, text: str) -> Union[list, None]: - """同步方法:获取文本的embedding向量 - - Args: - text: 需要获取embedding的文本 - - Returns: - list: embedding向量,如果失败则返回None - """ - return asyncio.run(self.get_embedding(text)) - - def generate_response_sync(self, prompt: str, **kwargs) -> Union[str, Tuple]: - """同步方式根据输入的提示生成模型的响应 - - Args: - prompt: 输入的提示文本 - **kwargs: 额外的参数 - - Returns: - Union[str, Tuple]: 模型响应内容,如果有工具调用则返回元组 - """ - return asyncio.run(self.generate_response_async(prompt, **kwargs)) - def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str: """压缩base64格式的图片到指定大小