Merge branch 'dev' of https://github.com/MaiM-with-u/MaiBot into dev
This commit is contained in:
@@ -26,7 +26,7 @@ from rich.progress import (
|
||||
TextColumn,
|
||||
)
|
||||
from src.manager.local_store_manager import local_storage
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.utils.utils import get_embedding_sync
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ class EmbeddingStore:
|
||||
self.idx2hash = None
|
||||
|
||||
def _get_embedding(self, s: str) -> List[float]:
|
||||
return get_embedding(s)
|
||||
return get_embedding_sync(s)
|
||||
|
||||
def get_test_file_path(self):
|
||||
return EMBEDDING_TEST_FILE
|
||||
|
||||
@@ -28,7 +28,7 @@ def _extract_json_from_text(text: str) -> dict:
|
||||
def _entity_extract(llm_req: LLMRequest, paragraph: str) -> List[str]:
|
||||
"""对段落进行实体提取,返回提取出的实体列表(JSON格式)"""
|
||||
entity_extract_context = prompt_template.build_entity_extract_context(paragraph)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(entity_extract_context)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_sync(entity_extract_context)
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
@@ -50,7 +50,7 @@ def _rdf_triple_extract(llm_req: LLMRequest, paragraph: str, entities: list) ->
|
||||
rdf_extract_context = prompt_template.build_rdf_triple_extract_context(
|
||||
paragraph, entities=json.dumps(entities, ensure_ascii=False)
|
||||
)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_async(rdf_extract_context)
|
||||
response, (reasoning_content, model_name) = llm_req.generate_response_sync(rdf_extract_context)
|
||||
|
||||
entity_extract_result = _extract_json_from_text(response)
|
||||
# 尝试load JSON数据
|
||||
|
||||
@@ -10,7 +10,7 @@ from .kg_manager import KGManager
|
||||
# from .lpmmconfig import global_config
|
||||
from .utils.dyn_topk import dyn_select_top_k
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.utils.utils import get_embedding
|
||||
from src.chat.utils.utils import get_embedding_sync
|
||||
from src.config.config import global_config
|
||||
|
||||
MAX_KNOWLEDGE_LENGTH = 10000 # 最大知识长度
|
||||
@@ -36,7 +36,7 @@ class QAManager:
|
||||
|
||||
# 生成问题的Embedding
|
||||
part_start_time = time.perf_counter()
|
||||
question_embedding = await get_embedding(question)
|
||||
question_embedding = await get_embedding_sync(question)
|
||||
if question_embedding is None:
|
||||
logger.error("生成问题Embedding失败")
|
||||
return None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, Any, Optional
|
||||
from typing import Dict, Any, Optional, Tuple
|
||||
from rich.traceback import install
|
||||
from datetime import datetime
|
||||
from json_repair import repair_json
|
||||
@@ -81,6 +81,7 @@ class ActionPlanner:
|
||||
self.last_obs_time_mark = 0.0
|
||||
|
||||
def find_message_by_id(self, message_id: str, message_id_list: list) -> Optional[Dict[str, Any]]:
|
||||
# sourcery skip: use-next
|
||||
"""
|
||||
根据message_id从message_id_list中查找对应的原始消息
|
||||
|
||||
@@ -98,7 +99,7 @@ class ActionPlanner:
|
||||
|
||||
async def plan(
|
||||
self, mode: ChatMode = ChatMode.FOCUS
|
||||
) -> Dict[str, Dict[str, Any] | str]: # sourcery skip: dict-comprehension
|
||||
) -> Tuple[Dict[str, Dict[str, Any] | str], Optional[Dict[str, Any]]]: # sourcery skip: dict-comprehension
|
||||
"""
|
||||
规划器 (Planner): 使用LLM根据上下文决定做出什么动作。
|
||||
"""
|
||||
@@ -107,7 +108,8 @@ class ActionPlanner:
|
||||
reasoning = "规划器初始化默认"
|
||||
action_data = {}
|
||||
current_available_actions: Dict[str, ActionInfo] = {}
|
||||
target_message = None # 初始化target_message变量
|
||||
target_message: Optional[Dict[str, Any]] = None # 初始化target_message变量
|
||||
prompt: str = ""
|
||||
|
||||
try:
|
||||
is_group_chat = True
|
||||
@@ -128,10 +130,7 @@ class ActionPlanner:
|
||||
|
||||
# 如果没有可用动作或只有no_reply动作,直接返回no_reply
|
||||
if not current_available_actions:
|
||||
if mode == ChatMode.FOCUS:
|
||||
action = "no_reply"
|
||||
else:
|
||||
action = "no_action"
|
||||
action = "no_reply" if mode == ChatMode.FOCUS else "no_action"
|
||||
reasoning = "没有可用的动作"
|
||||
logger.info(f"{self.log_prefix}{reasoning}")
|
||||
return {
|
||||
@@ -140,7 +139,7 @@ class ActionPlanner:
|
||||
"action_data": action_data,
|
||||
"reasoning": reasoning,
|
||||
},
|
||||
}
|
||||
}, None
|
||||
|
||||
# --- 构建提示词 (调用修改后的 PromptBuilder 方法) ---
|
||||
prompt, message_id_list = await self.build_planner_prompt(
|
||||
@@ -196,8 +195,7 @@ class ActionPlanner:
|
||||
|
||||
# 在FOCUS模式下,非no_reply动作需要target_message_id
|
||||
if mode == ChatMode.FOCUS and action != "no_reply":
|
||||
target_message_id = parsed_json.get("target_message_id")
|
||||
if target_message_id:
|
||||
if target_message_id := parsed_json.get("target_message_id"):
|
||||
# 根据target_message_id查找原始消息
|
||||
target_message = self.find_message_by_id(target_message_id, message_id_list)
|
||||
else:
|
||||
@@ -278,7 +276,7 @@ class ActionPlanner:
|
||||
|
||||
if mode == ChatMode.FOCUS:
|
||||
by_what = "聊天内容"
|
||||
target_prompt = "\n \"target_message_id\":\"触发action的消息id\""
|
||||
target_prompt = '\n "target_message_id":"触发action的消息id"'
|
||||
no_action_block = """重要说明1:
|
||||
- 'no_reply' 表示只进行不进行回复,等待合适的回复时机
|
||||
- 当你刚刚发送了消息,没有人回复时,选择no_reply
|
||||
|
||||
@@ -122,6 +122,18 @@ async def get_embedding(text, request_type="embedding"):
|
||||
return embedding
|
||||
|
||||
|
||||
def get_embedding_sync(text, request_type="embedding"):
|
||||
"""获取文本的embedding向量(同步版本)"""
|
||||
# TODO: API-Adapter修改标记
|
||||
llm = LLMRequest(model=global_config.model.embedding, request_type=request_type)
|
||||
try:
|
||||
embedding = llm.get_embedding_sync(text)
|
||||
except Exception as e:
|
||||
logger.error(f"获取embedding失败: {str(e)}")
|
||||
embedding = None
|
||||
return embedding
|
||||
|
||||
|
||||
def get_recent_group_speaker(chat_stream_id: str, sender, limit: int = 12) -> list:
|
||||
# 获取当前群聊记录内发言的人
|
||||
filter_query = {"chat_id": chat_stream_id}
|
||||
|
||||
@@ -827,6 +827,29 @@ class LLMRequest:
|
||||
)
|
||||
return embedding
|
||||
|
||||
def get_embedding_sync(self, text: str) -> Union[list, None]:
|
||||
"""同步方法:获取文本的embedding向量
|
||||
|
||||
Args:
|
||||
text: 需要获取embedding的文本
|
||||
|
||||
Returns:
|
||||
list: embedding向量,如果失败则返回None
|
||||
"""
|
||||
return asyncio.run(self.get_embedding(text))
|
||||
|
||||
def generate_response_sync(self, prompt: str, **kwargs) -> Union[str, Tuple]:
|
||||
"""同步方式根据输入的提示生成模型的响应
|
||||
|
||||
Args:
|
||||
prompt: 输入的提示文本
|
||||
**kwargs: 额外的参数
|
||||
|
||||
Returns:
|
||||
Union[str, Tuple]: 模型响应内容,如果有工具调用则返回元组
|
||||
"""
|
||||
return asyncio.run(self.generate_response_async(prompt, **kwargs))
|
||||
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 1024 * 1024) -> str:
|
||||
"""压缩base64格式的图片到指定大小
|
||||
|
||||
@@ -24,8 +24,8 @@ class NoReplyAction(BaseAction):
|
||||
2. 累计新消息数量达到随机阈值 (默认5-10条) 则结束等待
|
||||
"""
|
||||
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
normal_activation_type = ActionActivationType.NEVER
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
mode_enable = ChatMode.FOCUS
|
||||
parallel_action = False
|
||||
|
||||
|
||||
@@ -36,8 +36,8 @@ class ReplyAction(BaseAction):
|
||||
"""回复动作 - 参与聊天回复"""
|
||||
|
||||
# 激活设置
|
||||
focus_activation_type = ActionActivationType.NEVER
|
||||
normal_activation_type = ActionActivationType.NEVER
|
||||
focus_activation_type = ActionActivationType.ALWAYS
|
||||
normal_activation_type = ActionActivationType.ALWAYS
|
||||
mode_enable = ChatMode.FOCUS
|
||||
parallel_action = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user