feat: 添加同步获取embedding向量和生成响应的方法
This commit is contained in:
@@ -26,7 +26,7 @@ from rich.progress import (
|
|||||||
TextColumn,
|
TextColumn,
|
||||||
)
|
)
|
||||||
from src.manager.local_store_manager import local_storage
|
from src.manager.local_store_manager import local_storage
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding_sync
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
|
||||||
@@ -99,7 +99,7 @@ class EmbeddingStore:
|
|||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
def _get_embedding(self, s: str) -> List[float]:
|
def _get_embedding(self, s: str) -> List[float]:
|
||||||
return get_embedding(s)
|
return get_embedding_sync(s)
|
||||||
|
|
||||||
def get_test_file_path(self):
|
def get_test_file_path(self):
|
||||||
return EMBEDDING_TEST_FILE
|
return EMBEDDING_TEST_FILE
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ def _extract_json_from_text(text: str) -> dict:
|
|||||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
response, (reasoning_content, model_name) = llm_req.generate_response_sync(entity_extract_context)
|
||||||
|
|
||||||
entity_extract_result = _extract_json_from_text(response)
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
# 尝试load JSON数据
|
# 尝试load JSON数据
|
||||||
@@ -50,7 +50,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
|||||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||||
)
|
)
|
||||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
response, (reasoning_content, model_name) = llm_req.generate_response_sync(rdf_extract_context)
|
||||||
|
|
||||||
entity_extract_result = _extract_json_from_text(response)
|
entity_extract_result = _extract_json_from_text(response)
|
||||||
# 尝试load JSON数据
|
# 尝试load JSON数据
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from .kg_manager import KGManager
|
|||||||
# from .lpmmconfig import global_config
|
# from .lpmmconfig import global_config
|
||||||
from .utils.dyn_topk import dyn_select_top_k
|
from .utils.dyn_topk import dyn_select_top_k
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.chat.utils.utils import get_embedding
|
from src.chat.utils.utils import get_embedding_sync
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||||
@@ -36,7 +36,7 @@ class QAManager:
|
|||||||
|
|
||||||
# 生成问题的Embedding
|
# 生成问题的Embedding
|
||||||
part_start_time = time.perf_counter()
|
part_start_time = time.perf_counter()
|
||||||
question_embedding = await get_embedding(question)
|
question_embedding = await get_embedding_sync(question)
|
||||||
if question_embedding is None:
|
if question_embedding is None:
|
||||||
logger.error("生成问题Embedding失败")
|
logger.error("生成问题Embedding失败")
|
||||||
return None
|
return None
|
||||||
|
|||||||
@@ -122,6 +122,18 @@ async def get_embedding(text, request_type="embedding"):
|
|||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedding_sync(text, request_type="embedding"):
|
||||||
|
"""获取文本的embedding向量(同步版本)"""
|
||||||
|
# TODO: API-Adapter修改标记
|
||||||
|
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
||||||
|
try:
|
||||||
|
embedding = llm.get_embedding_sync(text)
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"获取embedding失败: {str(e)}")
|
||||||
|
embedding = None
|
||||||
|
return embedding
|
||||||
|
|
||||||
|
|
||||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||||
# 获取当前群聊记录内发言的人
|
# 获取当前群聊记录内发言的人
|
||||||
filter_query = {"chat_id": chat_stream_id}
|
filter_query = {"chat_id": chat_stream_id}
|
||||||
|
|||||||
@@ -827,6 +827,29 @@ class LLMRequest:
|
|||||||
)
|
)
|
||||||
return embedding
|
return embedding
|
||||||
|
|
||||||
|
def get_embedding_sync(self, text: str) -> Union[list, None]:
|
||||||
|
"""同步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
return asyncio.run(self.get_embedding(text))
|
||||||
|
|
||||||
|
def generate_response_sync(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||||
|
"""同步方式根据输入的提示生成模型的响应
|
||||||
|
|
||||||
|
Args:
|
||||||
|
prompt: 输入的提示文本
|
||||||
|
**kwargs: 额外的参数
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Union[str, Tuple]: 模型响应内容,如果有工具调用则返回元组
|
||||||
|
"""
|
||||||
|
return asyncio.run(self.generate_response_async(prompt, **kwargs))
|
||||||
|
|
||||||
|
|
||||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||||
"""压缩base64格式的图片到指定大小
|
"""压缩base64格式的图片到指定大小
|
||||||
|
|||||||
Reference in New Issue
Block a user