SengokuCola
2025-07-16 11:02:43 +08:00
8 changed files with 56 additions and 23 deletions

View File

@@ -26,7 +26,7 @@ from rich.progress import (
     TextColumn,
 )
 from src.manager.local_store_manager import local_storage
-from src.chat.utils.utils import get_embedding
+from src.chat.utils.utils import get_embedding_sync
 from src.config.config import global_config
@@ -99,7 +99,7 @@ class EmbeddingStore:
         self.idx2hash = None

     def _get_embedding(self, s: str) -> List[float]:
-        return get_embedding(s)
+        return get_embedding_sync(s)

     def get_test_file_path(self):
         return EMBEDDING_TEST_FILE
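Why this swap is needed: `EmbeddingStore._get_embedding` is a plain synchronous method, so it cannot `await` the async `get_embedding`; calling an async function without `await` only creates a coroutine object. A minimal sketch of the two behaviors (stand-in function bodies, not the real implementation):

import asyncio

async def get_embedding(text: str) -> list:
    return [0.0, 0.1]  # stand-in for the real async embedding request

def get_embedding_sync(text: str) -> list:
    return asyncio.run(get_embedding(text))  # the sync bridge this commit adds

def broken(text: str) -> list:
    return get_embedding(text)  # no await: returns a coroutine, not a vector

print(get_embedding_sync("hello"))  # [0.0, 0.1]
print(type(broken("hello")))        # <class 'coroutine'> (plus a never-awaited warning)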

View File

@@ -28,7 +28,7 @@ def _extract_json_from_text(text: str) -> dict:
 def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
     """Extract entities from a paragraph; returns the list of extracted entities (JSON format)"""
     entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
-    response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
+    response, (reasoning_content, model_name) = llm_req.generate_response_sync(entity_extract_context)
     entity_extract_result = _extract_json_from_text(response)
     # Try to load the JSON data
@@ -50,7 +50,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
     rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
         paragraph, entities=json.dumps(entities, ensure_ascii=False)
     )
-    response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
+    response, (reasoning_content, model_name) = llm_req.generate_response_sync(rdf_extract_context)
     entity_extract_result = _extract_json_from_text(response)
     # Try to load the JSON data
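The old lines here were broken in a specific way: they unpacked the result of an unawaited `generate_response_async` call, and a coroutine cannot be unpacked. A runnable sketch of that failure and the sync fix, using a toy `LLMRequest` that only mirrors the `(content, (reasoning, model_name))` return shape implied by the diff:

import asyncio

class LLMRequest:
    async def generate_response_async(self, prompt: str):
        return "{}", ("reasoning", "model-name")  # toy response, same shape as above

    def generate_response_sync(self, prompt: str):
        return asyncio.run(self.generate_response_async(prompt))

req = LLMRequest()
try:
    response, meta = req.generate_response_async("p")  # the old, unawaited call
except TypeError as err:
    print(err)  # cannot unpack non-iterable coroutine object
response, (reasoning, model_name) = req.generate_response_sync("p")
print(response, model_name)  # {} model-name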

View File

@@ -10,7 +10,7 @@ from .kg_manager import KGManager
 # from .lpmmconfig import global_config
 from .utils.dyn_topk import dyn_select_top_k
 from src.llm_models.utils_model import LLMRequest
-from src.chat.utils.utils import get_embedding
+from src.chat.utils.utils import get_embedding_sync
 from src.config.config import global_config

 MAX_KNOWLEDGE_LENGTH = 10000  # maximum knowledge length
@@ -36,7 +36,7 @@ class QAManager:
         # Generate the embedding for the question
         part_start_time = time.perf_counter()
-        question_embedding = await get_embedding(question)
+        question_embedding = await get_embedding_sync(question)
         if question_embedding is None:
             logger.error("Failed to generate the question embedding")
             return None

View File

@@ -1,7 +1,7 @@
 import json
 import time
 import traceback
-from typing import Dict, Any, Optional
+from typing import Dict, Any, Optional, Tuple
 from rich.traceback import install
 from datetime import datetime
 from json_repair import repair_json
@@ -81,6 +81,7 @@ class ActionPlanner:
         self.last_obs_time_mark = 0.0

     def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
+        # sourcery skip: use-next
         """
         Find the original message in message_id_list by its message_id
@@ -98,7 +99,7 @@ class ActionPlanner:
     async def plan(
         self, mode: ChatMode = ChatMode.FOCUS
-    ) -> Dict[str, Dict[str, Any] | str]:  # sourcery skip: dict-comprehension
+    ) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]:  # sourcery skip: dict-comprehension
         """
         Planner: uses the LLM to decide which action to take based on the context.
         """
@@ -107,7 +108,8 @@
         reasoning = "Planner initialization default"
         action_data = {}
         current_available_actions: Dict[str, ActionInfo] = {}
-        target_message = None  # initialize the target_message variable
+        target_message: Optional[Dict[str, Any]] = None  # initialize the target_message variable
+        prompt: str = ""

         try:
             is_group_chat = True
@@ -128,10 +130,7 @@
             # If there are no available actions, or only the no_reply action, return no_reply directly
             if not current_available_actions:
-                if mode == ChatMode.FOCUS:
-                    action = "no_reply"
-                else:
-                    action = "no_action"
+                action = "no_reply" if mode == ChatMode.FOCUS else "no_action"
                 reasoning = "No available actions"
                 logger.info(f"{self.log_prefix}{reasoning}")
                 return {
@@ -140,7 +139,7 @@
                         "action_data": action_data,
                         "reasoning": reasoning,
                     },
-                }
+                }, None

             # --- Build the prompt (calls the modified PromptBuilder method) ---
             prompt, message_id_list = await self.build_planner_prompt(
@@ -196,8 +195,7 @@
             # In FOCUS mode, actions other than no_reply require a target_message_id
             if mode == ChatMode.FOCUS and action != "no_reply":
-                target_message_id = parsed_json.get("target_message_id")
-                if target_message_id:
+                if target_message_id := parsed_json.get("target_message_id"):
                     # Look up the original message by its target_message_id
                     target_message = self.find_message_by_id(target_message_id, message_id_list)
                 else:
@@ -278,7 +276,7 @@
         if mode == ChatMode.FOCUS:
             by_what = "chat content"
-            target_prompt = "\n \"target_message_id\":\"id of the message that triggered the action\""
+            target_prompt = '\n "target_message_id":"id of the message that triggered the action"'
             no_action_block = """Important note 1:
 - 'no_reply' means do not reply for now and wait for a suitable moment to respond
 - choose 'no_reply' when you have just sent a message and no one has replied yet
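Because `plan()` now returns a `(plan_dict, target_message)` pair instead of a bare dict, every caller has to unpack two values. A toy sketch of the new contract (only the tuple shape comes from the diff; the `"action_result"` key and field names are assumptions):

import asyncio
from typing import Any, Dict, Optional, Tuple

async def plan() -> Tuple[Dict[str, Any], Optional[Dict[str, Any]]]:
    result = {"action_result": {"action_type": "no_reply", "reasoning": "no available actions"}}
    return result, None  # second element: the resolved target message, if any

plan_result, target_message = asyncio.run(plan())
print(plan_result["action_result"]["action_type"], target_message)  # no_reply None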

View File

@@ -122,6 +122,18 @@ async def get_embedding(text, request_type="embedding"):
     return embedding

+def get_embedding_sync(text, request_type="embedding"):
+    """Get the embedding vector for a text (synchronous version)"""
+    # TODO: API-Adapter modification marker
+    llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
+    try:
+        embedding = llm.get_embedding_sync(text)
+    except Exception as e:
+        logger.error(f"Failed to get embedding: {str(e)}")
+        embedding = None
+    return embedding

 def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
     # Get the people who have spoken recently in this group chat
     filter_query = {"chat_id": chat_stream_id}

View File

@@ -827,6 +827,29 @@ class LLMRequest:
         )
         return embedding

+    def get_embedding_sync(self, text: str) -> Union[list, None]:
+        """Synchronously get the embedding vector for a text
+
+        Args:
+            text: the text to embed
+
+        Returns:
+            list: the embedding vector, or None on failure
+        """
+        return asyncio.run(self.get_embedding(text))
+
+    def generate_response_sync(self, prompt: str, **kwargs) -> Union[str, Tuple]:
+        """Synchronously generate the model's response to a prompt
+
+        Args:
+            prompt: the input prompt text
+            **kwargs: extra parameters
+
+        Returns:
+            Union[str, Tuple]: the model's response content; a tuple when tool calls are involved
+        """
+        return asyncio.run(self.generate_response_async(prompt, **kwargs))

 def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
     """Compress a base64-encoded image to a target size
View File

@@ -24,8 +24,8 @@ class NoReplyAction(BaseAction):
     2. End the wait once the number of accumulated new messages reaches a random threshold (5-10 by default)
     """

-    focus_activation_type = ActionActivationType.NEVER
-    normal_activation_type = ActionActivationType.NEVER
+    focus_activation_type = ActionActivationType.ALWAYS
+    normal_activation_type = ActionActivationType.ALWAYS
     mode_enable = ChatMode.FOCUS
     parallel_action = False

View File

@@ -36,8 +36,8 @@ class ReplyAction(BaseAction):
     """Reply action - participate in the chat by replying"""

     # Activation settings
-    focus_activation_type = ActionActivationType.NEVER
-    normal_activation_type = ActionActivationType.NEVER
+    focus_activation_type = ActionActivationType.ALWAYS
+    normal_activation_type = ActionActivationType.ALWAYS
     mode_enable = ChatMode.FOCUS
     parallel_action = False