SengokuCola
2025-04-24 14:19:26 +08:00
parent f8450f705a
commit 3075664480
13 changed files with 224 additions and 225 deletions

View File

@@ -159,7 +159,9 @@ class ToolUser:
            tool_calls_str = ""
            for tool_call in tool_calls:
                tool_calls_str += f"{tool_call['function']['name']}\n"
-            logger.info(f"根据:\n{prompt}\n\n内容:{content}\n\n模型请求调用{len(tool_calls)}个工具: {tool_calls_str}")
+            logger.info(
+                f"根据:\n{prompt}\n\n内容:{content}\n\n模型请求调用{len(tool_calls)}个工具: {tool_calls_str}"
+            )
            tool_results = []
            structured_info = {}  # keys are generated dynamically

View File

@@ -82,29 +82,25 @@ class ChattingObservation(Observation):
        new_messages_list = get_raw_msg_by_timestamp_with_chat(
            chat_id=self.chat_id,
            timestamp_start=self.last_observe_time,
            timestamp_end=datetime.now().timestamp(),
            limit=self.max_now_obs_len,
            limit_mode="latest",
        )
        last_obs_time_mark = self.last_observe_time
        if new_messages_list:
            self.last_observe_time = new_messages_list[-1]["time"]
            self.talking_message.extend(new_messages_list)
            if len(self.talking_message) > self.max_now_obs_len:
                # Work out how many messages to remove, keeping the newest max_now_obs_len
                messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
                oldest_messages = self.talking_message[:messages_to_remove_count]
                self.talking_message = self.talking_message[messages_to_remove_count:]  # keep the tail, i.e. the newest
                oldest_messages_str = await build_readable_messages(
-                    messages=oldest_messages,
-                    timestamp_mode="normal",
-                    read_mark=0
+                    messages=oldest_messages, timestamp_mode="normal", read_mark=0
                )
                # Ask the LLM to summarize the topic
                prompt = (
@@ -145,7 +141,7 @@ class ChattingObservation(Observation):
            messages=self.talking_message,
            timestamp_mode="normal",
            read_mark=last_obs_time_mark,
        )
        logger.trace(
            f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"

View File

@@ -6,12 +6,10 @@ from src.config.config import global_config
import time
from typing import Optional, List, Dict, Callable
import traceback
-from src.plugins.chat.utils import parse_text_timestamps
import enum
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG  # noqa: E402
from src.individuality.individuality import Individuality
import random
-from src.plugins.person_info.relationship_manager import relationship_manager
from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager
from src.plugins.chat.message import MessageRecv
from src.plugins.chat.chat_stream import chat_manager
@@ -20,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_chat import HeartFChatting
from src.plugins.heartFC_chat.normal_chat import NormalChat
from src.do_tool.tool_use import ToolUser
from src.heart_flow.mai_state_manager import MaiStateInfo
-from src.plugins.utils.json_utils import safe_json_dumps, process_llm_tool_response, normalize_llm_response, process_llm_tool_calls
+from src.plugins.utils.json_utils import safe_json_dumps, normalize_llm_response, process_llm_tool_calls

# Constants (moved over from interest.py)
MAX_INTEREST = 15.0
@@ -114,8 +112,6 @@ class InterestChatting:
        self.above_threshold = False
        self.start_hfc_probability = 0.0

    def add_interest_dict(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
        self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
@@ -293,7 +289,7 @@ class SubHeartflow:
        )
        self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id
        self.structured_info = {}

    async def add_time_current_state(self, add_time: float):
@@ -484,36 +480,36 @@ class SubHeartflow:
    async def do_thinking_before_reply(self):
        """
        Think before replying: generate an inner thought and collect tool-call results

        Returns:
            tuple: (current_mind, past_mind), the current thought and the list of past thoughts
        """
        # Refresh the active timestamp
        self.last_active_time = time.time()

        # ---------- 1. Prepare base data ----------
        # Get the existing thought and mood state
        current_thinking_info = self.current_mind
        mood_info = self.chat_state.mood

        # Get the observation object
        observation = self._get_primary_observation()
        if not observation:
            logger.error(f"[{self.subheartflow_id}] 无法获取观察对象")
            self.update_current_mind("(我没看到任何聊天内容...)")
            return self.current_mind, self.past_mind

        # Get the observed content
        chat_observe_info = observation.get_observe_info()

        # ---------- 2. Prepare tools and personality data ----------
        # Initialize tools
        tool_instance = ToolUser()
        tools = tool_instance._define_tools()

        # Get personality info
        individuality = Individuality.get_instance()

        # Build the personality section
        prompt_personality = f"你的名字是{individuality.personality.bot_nickname},你"
        prompt_personality += individuality.personality.personality_core
@@ -547,9 +543,7 @@ class SubHeartflow:
        # Weighted random choice of thinking guidance
        hf_do_next = local_random.choices(
-            [option[0] for option in hf_options],
-            weights=[option[1] for option in hf_options],
-            k=1
+            [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1
        )[0]

        # ---------- 4. Build the final prompt ----------
@@ -570,16 +564,16 @@ class SubHeartflow:
        # ---------- 5. Run the LLM request and handle the response ----------
        content = ""  # initialize the content variable
        reasoning_content = ""  # initialize the reasoning-content variable

        try:
            # Call the LLM to generate a response
            response = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)

            # Normalize the response format
            success, normalized_response, error_msg = normalize_llm_response(
                response, log_prefix=f"[{self.subheartflow_id}] "
            )

            if not success:
                # Handle normalization failure
                logger.warning(f"[{self.subheartflow_id}] {error_msg}")
@@ -588,23 +582,24 @@ class SubHeartflow:
            # Extract content from the normalized response
            if len(normalized_response) >= 2:
                content = normalized_response[0]
-                reasoning_content = normalized_response[1] if len(normalized_response) > 1 else ""
+                _reasoning_content = normalized_response[1] if len(normalized_response) > 1 else ""

                # Handle possible tool calls
                if len(normalized_response) == 3:
                    # Extract and validate the tool calls
                    success, valid_tool_calls, error_msg = process_llm_tool_calls(
                        normalized_response, log_prefix=f"[{self.subheartflow_id}] "
                    )
                    if success and valid_tool_calls:
                        # Log the tool-call info
-                        tool_calls_str = ", ".join([
-                            call.get("function", {}).get("name", "未知工具")
-                            for call in valid_tool_calls
-                        ])
-                        logger.info(f"[{self.subheartflow_id}] 模型请求调用{len(valid_tool_calls)}个工具: {tool_calls_str}")
+                        tool_calls_str = ", ".join(
+                            [call.get("function", {}).get("name", "未知工具") for call in valid_tool_calls]
+                        )
+                        logger.info(
+                            f"[{self.subheartflow_id}] 模型请求调用{len(valid_tool_calls)}个工具: {tool_calls_str}"
+                        )
                        # Collect the tool execution results
                        await self._execute_tool_calls(valid_tool_calls, tool_instance)
                    elif not success:
@@ -628,37 +623,34 @@ class SubHeartflow:
        self.update_current_mind(content)
        return self.current_mind, self.past_mind

    async def _execute_tool_calls(self, tool_calls, tool_instance):
        """
        Execute a batch of tool calls and collect the results

        Args:
            tool_calls: the list of tool calls
            tool_instance: the ToolUser instance
        """
        tool_results = []
        structured_info = {}  # keys are generated dynamically

        # Run every tool call
        for tool_call in tool_calls:
            try:
                result = await tool_instance._execute_tool_call(tool_call)
                if result:
                    tool_results.append(result)
                    # Use the tool name as the key
                    tool_name = result["name"]
                    if tool_name not in structured_info:
                        structured_info[tool_name] = []
-                    structured_info[tool_name].append({
-                        "name": result["name"],
-                        "content": result["content"]
-                    })
+                    structured_info[tool_name].append({"name": result["name"], "content": result["content"]})
            except Exception as tool_e:
                logger.error(f"[{self.subheartflow_id}] 工具执行失败: {tool_e}")

        # If there are tool results, log and update the structured info
        if structured_info:
            logger.debug(f"工具调用收集到结构化信息: {safe_json_dumps(structured_info, ensure_ascii=False)}")

View File

@@ -290,9 +290,9 @@ class SubHeartflowManager:
            log_prefix_flow = f"[{stream_name}]"

            # Only handle subheartflows in the CHAT state
            # The code snippet is checking if the `chat_status` attribute of `sub_hf.chat_state` is not equal to
            # `ChatState.CHAT`. If the condition is met, the code will continue to the next iteration of the loop
            # or block of code where this snippet is located.
            # if sub_hf.chat_state.chat_status != ChatState.CHAT:
            #     continue

View File

@@ -78,7 +78,6 @@ class ChatBot:
        groupinfo = message.message_info.group_info
        userinfo = message.message_info.user_info
        if userinfo.user_id in global_config.ban_user_id:
            logger.debug(f"用户{userinfo.user_id}被禁止回复")
            return

View File

@@ -328,7 +328,9 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
    final_sentences = [content for content, sep in merged_segments if content]  # keep only segments with content

    # Clean out empty strings and whitespace-only strings that may have been introduced
-    final_sentences = [s for s in final_sentences if s.strip()]  # filter out empty and whitespace-only (newline, space) strings
+    final_sentences = [
+        s for s in final_sentences if s.strip()
+    ]  # filter out empty and whitespace-only (newline, space) strings

    logger.debug(f"分割并合并后的句子: {final_sentences}")
    return final_sentences

View File

@@ -2,6 +2,7 @@ import asyncio
import time
import traceback
from typing import List, Optional, Dict, Any, TYPE_CHECKING

# import json  # removed since json_utils is used instead
from src.plugins.chat.message import MessageRecv, BaseMessageInfo, MessageThinking, MessageSending
from src.plugins.chat.message import MessageSet, Seg  # Local import needed after move
@@ -17,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator
from src.do_tool.tool_use import ToolUser
from ..chat.message_sender import message_manager  # <-- Import the global manager
from src.plugins.chat.emoji_manager import emoji_manager
-from src.plugins.utils.json_utils import extract_tool_call_arguments, safe_json_dumps, process_llm_tool_response  # import the new JSON utilities
+from src.plugins.utils.json_utils import process_llm_tool_response  # import the new JSON utilities

# --- End import ---
@@ -37,7 +38,7 @@ if TYPE_CHECKING:
    # Keep this if HeartFCController methods are still needed elsewhere,
    # but the instance variable will be removed from HeartFChatting
    # from .heartFC_controler import HeartFCController
-    from src.heart_flow.heartflow import SubHeartflow, heartflow  # <-- also import the heartflow instance for type checking
+    from src.heart_flow.heartflow import SubHeartflow  # <-- also import the heartflow instance for type checking

PLANNER_TOOL_DEFINITION = [
    {
@@ -327,7 +328,6 @@ class HeartFChatting:
                with Timer("Wait New Msg", cycle_timers):  # <--- Start Wait timer
                    wait_start_time = time.monotonic()
                    while True:
                        # Check for new messages
                        has_new = await observation.has_new_messages_since(planner_start_db_time)
                        if has_new:
@@ -424,7 +424,7 @@ class HeartFChatting:
        observed_messages: List[dict] = []
        current_mind: Optional[str] = None
        llm_error = False

        try:
            observation = self.sub_hf._get_primary_observation()
@@ -434,19 +434,17 @@ class HeartFChatting:
        except Exception as e:
            logger.error(f"{log_prefix}[Planner] 获取观察信息时出错: {e}")

        try:
            current_mind, _past_mind = await self.sub_hf.do_thinking_before_reply()
        except Exception as e_subhf:
            logger.error(f"{log_prefix}[Planner] SubHeartflow 思考失败: {e_subhf}")
            current_mind = "[思考时出错]"

        # --- Use the LLM to make the decision --- #
        action = "no_reply"  # default action
        emoji_query = ""  # default emoji query
        reasoning = "默认决策或获取决策失败"
        llm_error = False  # LLM error flag

        try:
            prompt = await self._build_planner_prompt(observed_messages_str, current_mind, self.sub_hf.structured_info)
@@ -475,21 +473,17 @@ class HeartFChatting:
            # Use the helper to process the tool-call response
            success, arguments, error_msg = process_llm_tool_response(
-                response,
-                expected_tool_name="decide_reply_action",
-                log_prefix=f"{log_prefix}[Planner] "
+                response, expected_tool_name="decide_reply_action", log_prefix=f"{log_prefix}[Planner] "
            )

            if success:
                # Extract the decision parameters
                action = arguments.get("action", "no_reply")
                reasoning = arguments.get("reasoning", "未提供理由")
                emoji_query = arguments.get("emoji_query", "")

                # Log the decision result
-                logger.debug(
-                    f"{log_prefix}[Planner] 决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'"
-                )
+                logger.debug(f"{log_prefix}[Planner] 决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'")
            else:
                # Handle tool-call failure
                logger.warning(f"{log_prefix}[Planner] {error_msg}")
@@ -584,7 +578,7 @@ class HeartFChatting:
        """Gracefully shut down the HeartFChatting instance, cancelling the active loop task"""
        log_prefix = self._get_log_prefix()
        logger.info(f"{log_prefix} 正在关闭HeartFChatting...")

        # Cancel the loop task
        if self._loop_task and not self._loop_task.done():
            logger.info(f"{log_prefix} 正在取消HeartFChatting循环任务")
@@ -605,17 +599,19 @@ class HeartFChatting:
        if self._processing_lock.locked():
            self._processing_lock.release()
            logger.warning(f"{log_prefix} 已释放处理锁")
        logger.info(f"{log_prefix} HeartFChatting关闭完成")

-    async def _build_planner_prompt(self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any]) -> str:
+    async def _build_planner_prompt(
+        self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any]
+    ) -> str:
        """Build the prompt for the Planner LLM"""
        prompt = f"你的名字是 {global_config.BOT_NICKNAME}。是一个女大二学生正在QQ聊天正在决定是否以及如何回应当前的聊天。\n"
        if structured_info:
            prompt += f"以下是一些额外的信息:\n{structured_info}\n"
        if observed_messages_str:
            prompt += "观察到的最新聊天内容如下 (最近的消息在最后)\n---\n"
            prompt += observed_messages_str
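The planner above funnels the raw LLM response through process_llm_tool_response and falls back to safe defaults when extraction fails. A hedged sketch of that defensive pattern against a hand-built response (process_llm_tool_response is the real helper from json_utils, shown later in this commit; the fake_response below is invented test data, not an actual LLM reply, and the import assumes you are running inside this repo):

# Defensive extraction of planner arguments, mirroring the code above.
import json
from src.plugins.utils.json_utils import process_llm_tool_response

fake_response = [
    "",  # content
    "",  # reasoning
    [
        {
            "type": "function",
            "function": {
                "name": "decide_reply_action",
                "arguments": json.dumps({"action": "text_reply", "reasoning": "a question was asked"}),
            },
        }
    ],
]

success, arguments, error_msg = process_llm_tool_response(
    fake_response, expected_tool_name="decide_reply_action", log_prefix="[demo] "
)
# Fall back to the same defaults the planner uses when extraction fails.
action = arguments.get("action", "no_reply") if success else "no_reply"
assert action == "text_reply"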

View File

@@ -72,7 +72,13 @@ class HeartFCGenerator:
        return None

    async def _generate_response_with_model(
-        self, structured_info: str, current_mind_info: str, reason: str, message: MessageRecv, model: LLMRequest, thinking_id: str
+        self,
+        structured_info: str,
+        current_mind_info: str,
+        reason: str,
+        message: MessageRecv,
+        model: LLMRequest,
+        thinking_id: str,
    ) -> str:
        sender_name = ""

View File

@@ -81,13 +81,22 @@ class PromptBuilder:
        self.activate_messages = ""

    async def build_prompt(
-        self, build_mode, reason, current_mind_info, structured_info, message_txt: str, sender_name: str = "某人", chat_stream=None
+        self,
+        build_mode,
+        reason,
+        current_mind_info,
+        structured_info,
+        message_txt: str,
+        sender_name: str = "某人",
+        chat_stream=None,
    ) -> Optional[tuple[str, str]]:
        if build_mode == "normal":
            return await self._build_prompt_normal(chat_stream, message_txt, sender_name)
        elif build_mode == "focus":
-            return await self._build_prompt_focus(reason, current_mind_info, structured_info, chat_stream, message_txt, sender_name)
+            return await self._build_prompt_focus(
+                reason, current_mind_info, structured_info, chat_stream, message_txt, sender_name
+            )
        return None

    async def _build_prompt_focus(

View File

@@ -711,7 +711,7 @@ class LLMRequest:
        reasoning_content = ""
        content = ""
        tool_calls = None  # initialize the tool-call variable

        async for line_bytes in response.content:
            try:
                line = line_bytes.decode("utf-8").strip()
@@ -733,7 +733,7 @@ class LLMRequest:
                        if delta_content is None:
                            delta_content = ""
                        accumulated_content += delta_content

                        # Extract tool-call info
                        if "tool_calls" in delta:
                            if tool_calls is None:
@@ -741,7 +741,7 @@ class LLMRequest:
                            else:
                                # Merge the tool-call info
                                tool_calls.extend(delta["tool_calls"])

                        # Detect whether the streamed text output has finished
                        finish_reason = chunk["choices"][0].get("finish_reason")
                        if delta.get("reasoning_content", None):
@@ -774,23 +774,19 @@ class LLMRequest:
            if think_match:
                reasoning_content = think_match.group(1).strip()
                content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip()

        # Build the message object
        message = {
            "content": content,
            "reasoning_content": reasoning_content,
        }

        # If there are tool calls, add them to the message
        if tool_calls:
            message["tool_calls"] = tool_calls

        result = {
-            "choices": [
-                {
-                    "message": message
-                }
-            ],
+            "choices": [{"message": message}],
            "usage": usage,
        }
        return result
@@ -1128,9 +1124,9 @@ class LLMRequest:
        response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt)
        # Return the response as-is, without processing
        return response

    async def generate_response_tool_async(self, prompt: str, tools: list, **kwargs) -> Union[str, Tuple]:
        """Asynchronously generate the model's response to the given prompt"""
        # Build the request body without hard-coding max_tokens
@@ -1139,7 +1135,7 @@ class LLMRequest:
            "messages": [{"role": "user", "content": prompt}],
            **self.params,
            **kwargs,
-            "tools": tools
+            "tools": tools,
        }
        logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具")
@@ -1150,7 +1146,7 @@ class LLMRequest:
            logger.debug(f"收到工具调用响应,包含 {len(tool_calls) if tool_calls else 0} 个工具调用")
            return content, reasoning_content, tool_calls
        else:
-            logger.debug(f"收到普通响应,无工具调用")
+            logger.debug("收到普通响应,无工具调用")
        return response

    async def get_embedding(self, text: str) -> Union[list, None]:
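Note the asymmetric return of generate_response_tool_async: when tool calls are present it yields a 3-tuple (content, reasoning_content, tool_calls), otherwise the plain 2-element response passes through. Callers throughout this commit branch on len(response); a small self-contained sketch of that normalization (the helper name is hypothetical, not part of the commit):

# Normalize the 2- or 3-element return into (content, reasoning, tool_calls).
from typing import Any, List, Optional, Tuple


def unpack_tool_response(response: Tuple) -> Tuple[str, str, Optional[List[Any]]]:
    """Return (content, reasoning_content, tool_calls); tool_calls is None without tool use."""
    if len(response) == 3:
        content, reasoning_content, tool_calls = response
        return content, reasoning_content, tool_calls
    content, reasoning_content = response
    return content, reasoning_content, None


assert unpack_tool_response(("hi", "because")) == ("hi", "because", None)
assert unpack_tool_response(("hi", "because", []))[2] == []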

View File

@@ -303,7 +303,9 @@ async def build_readable_messages(
    )
    readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
-    read_mark_line = f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n"
+    read_mark_line = (
+        f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n"
+    )

    # Combine the parts, making sure empty sections introduce no stray markers or newlines
    if formatted_before and formatted_after:

View File

@@ -1,27 +1,28 @@
import json
import logging
-from typing import Any, Dict, Optional, TypeVar, Generic, List, Union, Callable, Tuple
+from typing import Any, Dict, TypeVar, List, Union, Callable, Tuple

# Type variable for generic type hints
-T = TypeVar('T')
+T = TypeVar("T")

# Get the logger
logger = logging.getLogger("json_utils")


def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
    """
    Safely parse a JSON string, returning a default value on error

    Args:
        json_str: the JSON string to parse
        default_value: the default returned when parsing fails

    Returns:
        The parsed Python object, or default_value when parsing fails
    """
    if not json_str:
        return default_value
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
@@ -31,66 +32,67 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
        logger.error(f"JSON解析过程中发生意外错误: {e}")
        return default_value


-def extract_tool_call_arguments(tool_call: Dict[str, Any],
-                                default_value: Dict[str, Any] = None) -> Dict[str, Any]:
+def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Extract the arguments from an LLM tool-call object

    Args:
        tool_call: the tool-call object (a dict)
        default_value: the default returned when parsing fails

    Returns:
        The parsed argument dict, or default_value when parsing fails
    """
    default_result = default_value or {}
    if not tool_call or not isinstance(tool_call, dict):
        logger.error(f"无效的工具调用对象: {tool_call}")
        return default_result
    try:
        # Extract the function field
        function_data = tool_call.get("function", {})
        if not function_data or not isinstance(function_data, dict):
            logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
            return default_result
        # Extract the arguments
        arguments_str = function_data.get("arguments", "{}")
        if not arguments_str:
            return default_result
        # Parse the JSON
        return safe_json_loads(arguments_str, default_result)
    except Exception as e:
        logger.error(f"提取工具调用参数时出错: {e}")
        return default_result


-def get_json_value(json_obj: Dict[str, Any], key_path: str,
-                   default_value: T = None,
-                   transform_func: Callable[[Any], T] = None) -> Union[Any, T]:
+def get_json_value(
+    json_obj: Dict[str, Any], key_path: str, default_value: T = None, transform_func: Callable[[Any], T] = None
+) -> Union[Any, T]:
    """
    Extract a value from a JSON object by path, supporting dot notation such as "data.items.0.name"

    Args:
        json_obj: the JSON object (an already-parsed dict)
        key_path: the key path in dot notation, e.g. "data.items.0.name"
        default_value: the default returned when the lookup fails
        transform_func: optional transform applied to the fetched value

    Returns:
        The value at the path, or default_value when the lookup fails
    """
    if not json_obj or not key_path:
        return default_value
    try:
        # Split the path
        keys = key_path.split(".")
        current = json_obj
        # Walk the path
        for key in keys:
            # Handle array indices
@@ -108,7 +110,7 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str,
                    return default_value
            else:
                return default_value
        # Apply the transform function (if provided)
        if transform_func and current is not None:
            return transform_func(current)
@@ -117,17 +119,17 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str,
        logger.error(f"从JSON获取值时出错: {e}, 路径: {key_path}")
        return default_value


-def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False,
-                    pretty: bool = False) -> str:
+def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
    """
    Safely serialize a Python object to a JSON string

    Args:
        obj: the Python object to serialize
        default_value: the default returned when serialization fails
        ensure_ascii: whether to force ASCII output (default False, allowing Chinese and other non-ASCII characters)
        pretty: whether to pretty-print the JSON

    Returns:
        The serialized JSON string, or default_value when serialization fails
    """
@@ -141,13 +143,14 @@ def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = Fa
        logger.error(f"JSON序列化过程中发生意外错误: {e}")
        return default_value


def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge multiple JSON objects (dicts)

    Args:
        *objects: the JSON objects (dicts) to merge

    Returns:
        The merged dict; later objects overwrite identical keys from earlier ones
    """
@@ -157,109 +160,110 @@ def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
        result.update(obj)
    return result


def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
    """
    Normalize the LLM response format, converting variants (such as tuples) into a uniform list format

    Args:
        response: the raw LLM response
        log_prefix: log prefix

    Returns:
        Tuple of (success flag, normalized response list, error message)
    """
    # Check for None
    if response is None:
        return False, [], "LLM响应为None"

    # Log the original type
    logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")

    # Convert a tuple to a list
    if isinstance(response, tuple):
        logger.debug(f"{log_prefix}将元组响应转换为列表")
        response = list(response)

    # Make sure it is a list
    if not isinstance(response, list):
        return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"

    # Handle the tool-call part (if present)
    if len(response) == 3:
        content, reasoning, tool_calls = response
        # Convert the tool-call part to a list (if it is a tuple)
        if isinstance(tool_calls, tuple):
            logger.debug(f"{log_prefix}将工具调用元组转换为列表")
            tool_calls = list(tool_calls)
            response[2] = tool_calls

    return True, response, ""


def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]:
    """
    Process and extract the list of tool calls from an LLM response

    Args:
        response: the normalized LLM response list
        log_prefix: log prefix

    Returns:
        Tuple of (success flag, tool-call list, error message)
    """
    # Make sure the response format is correct
    if len(response) != 3:
        return False, [], f"LLM响应元素数量不正确: 预期3个元素实际{len(response)}"

    # Extract the tool-call part
    tool_calls = response[2]

    # Check that the tool calls are valid
    if tool_calls is None:
        return False, [], "工具调用部分为None"
    if not isinstance(tool_calls, list):
        return False, [], f"工具调用部分不是列表: {type(tool_calls).__name__}"
    if len(tool_calls) == 0:
        return False, [], "工具调用列表为空"

    # Check that each tool call is well-formed
    valid_tool_calls = []
    for i, tool_call in enumerate(tool_calls):
        if not isinstance(tool_call, dict):
            logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}")
            continue
        if tool_call.get("type") != "function":
            logger.warning(f"{log_prefix}工具调用[{i}]不是函数类型: {tool_call.get('type', '未知')}")
            continue
        if "function" not in tool_call or not isinstance(tool_call["function"], dict):
            logger.warning(f"{log_prefix}工具调用[{i}]缺少function字段或格式不正确")
            continue
        valid_tool_calls.append(tool_call)

    # Check whether any valid tool calls remain
    if not valid_tool_calls:
        return False, [], "没有找到有效的工具调用"

    return True, valid_tool_calls, ""


def process_llm_tool_response(
-    response: Any,
-    expected_tool_name: str = None,
-    log_prefix: str = ""
+    response: Any, expected_tool_name: str = None, log_prefix: str = ""
) -> Tuple[bool, Dict[str, Any], str]:
    """
    Process the tool-call response returned by the LLM, run common error checks, and extract the arguments

    Args:
        response: the LLM response, expected to be a list or tuple of the form [content, reasoning, tool_calls]
        expected_tool_name: the expected tool name; if omitted, the name is not checked
        log_prefix: log prefix identifying the log source

    Returns:
        Triple of (success flag, argument dict, error description)
        - On a successful parse, returns (True, argument dict, "")
@@ -269,29 +273,29 @@ def process_llm_tool_response(
    success, normalized_response, error_msg = normalize_llm_response(response, log_prefix)
    if not success:
        return False, {}, error_msg

    # Use the new tool-call processing function
    success, valid_tool_calls, error_msg = process_llm_tool_calls(normalized_response, log_prefix)
    if not success:
        return False, {}, error_msg

    # Check whether there are tool calls
    if not valid_tool_calls:
        return False, {}, "没有有效的工具调用"

    # Take the first tool call
    tool_call = valid_tool_calls[0]

    # Check the tool name (if an expected name was given)
    if expected_tool_name:
        actual_name = tool_call.get("function", {}).get("name")
        if actual_name != expected_tool_name:
            return False, {}, f"工具名称不匹配: 预期'{expected_tool_name}',实际'{actual_name}'"

    # Extract and parse the arguments
    try:
        arguments = extract_tool_call_arguments(tool_call, {})
        return True, arguments, ""
    except Exception as e:
        logger.error(f"{log_prefix}解析工具参数时出错: {e}")
        return False, {}, f"解析参数失败: {str(e)}"

View File

@@ -6,24 +6,25 @@ from src.do_tool.tool_use import ToolUser
import statistics
import json


async def run_test(test_name, test_function, iterations=5):
    """
    Run a test the given number of times and compute the average response time

    Args:
        test_name: name of the test
        test_function: the test function to execute
        iterations: number of test iterations

    Returns:
        The test result statistics
    """
    print(f"开始 {test_name} 测试({iterations}次迭代)...")
    times = []
    responses = []

    for i in range(iterations):
-        print(f"  运行第 {i+1}/{iterations} 次测试...")
+        print(f"  运行第 {i + 1}/{iterations} 次测试...")
        start_time = time.time()
        response = await test_function()
        end_time = time.time()
@@ -31,18 +32,19 @@ async def run_test(test_name, test_function, iterations=5):
        times.append(elapsed)
        responses.append(response)
        print(f"  - 耗时: {elapsed:.2f}秒")

    results = {
        "平均耗时": statistics.mean(times),
        "最短耗时": min(times),
        "最长耗时": max(times),
        "标准差": statistics.stdev(times) if len(times) > 1 else 0,
        "所有耗时": times,
-        "响应结果": responses
+        "响应结果": responses,
    }

    return results


async def test_with_tool_calls():
    """LLM request test with tool calls"""
    # Create the LLM model instance
@@ -53,14 +55,14 @@ async def test_with_tool_calls():
        max_tokens=800,
        request_type="benchmark_test",
    )

    # Create the tool instance
    tool_instance = ToolUser()
    tools = tool_instance._define_tools()

    # A simple test prompt
    prompt = "请分析当前天气情况并查询今日历史上的重要事件。并且3.9和3.11谁比较大?请使用适当的工具来获取这些信息。"
-    prompt = '''
+    prompt = """
你的名字是麦麦,你包容开放,情绪敏感,有时候有些搞怪幽默, 是一个学习心理学和脑科学的女大学生,现在在读大二,你会刷贴吧,有时候会想瑟瑟,喜欢刷小红书
-----------------------------------
现在是2025-04-24 12:37:00你正在上网和qq群里的网友们聊天群里正在聊的话题是
@@ -89,52 +91,47 @@ async def test_with_tool_calls():
回复的要求是:平淡一些,简短一些,说中文,如果你要回复,最好只回复一个人的一个话题
请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写。不要回复自己的发言,尽量不要说你说过的话。
现在请你继续生成你在这个聊天中的想法,在原来想法的基础上继续思考,不要分点输出,生成内心想法,文字不要浮夸
-在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。'''
+在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。"""

    # Send the request with tool calls
    response = await llm_model.generate_response_tool_async(prompt=prompt, tools=tools)

    result_info = {}

    # Lightly process the tool-call result
    if len(response) == 3:
        content, reasoning_content, tool_calls = response
        tool_calls_count = len(tool_calls) if tool_calls else 0
        print(f"  工具调用请求生成了 {tool_calls_count} 个工具调用")

        # Print the content and tool-call details
        print("\n  生成的内容:")
        print(f"  {content[:200]}..." if len(content) > 200 else f"  {content}")

        if tool_calls:
            print("\n  工具调用详情:")
            for i, tool_call in enumerate(tool_calls):
-                tool_name = tool_call['function']['name']
-                tool_params = tool_call['function'].get('arguments', {})
-                print(f"  - 工具 {i+1}: {tool_name}")
-                print(f"    参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..."
-                      if len(json.dumps(tool_params, ensure_ascii=False)) > 100
-                      else f"    参数: {json.dumps(tool_params, ensure_ascii=False)}")
-
-        result_info = {
-            "内容": content,
-            "推理内容": reasoning_content,
-            "工具调用": tool_calls
-        }
+                tool_name = tool_call["function"]["name"]
+                tool_params = tool_call["function"].get("arguments", {})
+                print(f"  - 工具 {i + 1}: {tool_name}")
+                print(
+                    f"    参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..."
+                    if len(json.dumps(tool_params, ensure_ascii=False)) > 100
+                    else f"    参数: {json.dumps(tool_params, ensure_ascii=False)}"
+                )
+
+        result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": tool_calls}
    else:
        content, reasoning_content = response
        print("  工具调用请求未生成任何工具调用")
        print("\n  生成的内容:")
        print(f"  {content[:200]}..." if len(content) > 200 else f"  {content}")

-        result_info = {
-            "内容": content,
-            "推理内容": reasoning_content,
-            "工具调用": []
-        }
+        result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": []}

    return result_info


async def test_without_tool_calls():
    """LLM request test without tool calls"""
    # Create the LLM model instance
@@ -144,9 +141,9 @@ async def test_without_tool_calls():
        max_tokens=800,
        request_type="benchmark_test",
    )

    # The same test prompt as the tool-call case, for a fair comparison
-    prompt = '''
+    prompt = """
你的名字是麦麦,你包容开放,情绪敏感,有时候有些搞怪幽默, 是一个学习心理学和脑科学的女大学生,现在在读大二,你会刷贴吧,有时候会想瑟瑟,喜欢刷小红书
刚刚你的想法是:
我是麦麦,我想,('小千石问3.8和3.11谁大已经简单回答了3.11大,现在可以继续聊猫猫头表情包,毕竟大家好像对版本问题兴趣不大,而且猫猫头的话题更轻松有趣。', '')
@@ -181,45 +178,42 @@ async def test_without_tool_calls():
回复的要求是:平淡一些,简短一些,说中文,如果你要回复,最好只回复一个人的一个话题
请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写。不要回复自己的发言,尽量不要说你说过的话。
现在请你继续生成你在这个聊天中的想法,在原来想法的基础上继续思考,不要分点输出,生成内心想法,文字不要浮夸
-在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。'''
+在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。"""

    # Send the request without tool calls
    response, reasoning_content = await llm_model.generate_response_async(prompt)

    # Print the generated content
    print("\n  生成的内容:")
    print(f"  {response[:200]}..." if len(response) > 200 else f"  {response}")

-    result_info = {
-        "内容": response,
-        "推理内容": reasoning_content,
-        "工具调用": []
-    }
+    result_info = {"内容": response, "推理内容": reasoning_content, "工具调用": []}

    return result_info


async def main():
    """Main test function"""
    print("=" * 50)
    print("LLM工具调用与普通请求性能比较测试")
    print("=" * 50)

    # Number of test iterations
    iterations = 3

    # Test without tool calls
    results_without_tools = await run_test("不使用工具调用", test_without_tool_calls, iterations)

    print("\n" + "-" * 50 + "\n")

    # Test with tool calls
    results_with_tools = await run_test("使用工具调用", test_with_tool_calls, iterations)

    # Show the result comparison
    print("\n" + "=" * 50)
    print("测试结果比较")
    print("=" * 50)

    print("\n不使用工具调用:")
    for key, value in results_without_tools.items():
        if key == "所有耗时":
@@ -228,7 +222,7 @@ async def main():
            print(f"  {key}: [内容已省略,详见结果文件]")
        else:
            print(f"  {key}: {value:.2f}秒")

    print("\n使用工具调用:")
    for key, value in results_with_tools.items():
        if key == "所有耗时":
@@ -239,29 +233,30 @@ async def main():
            print(f"  工具调用数量: {tool_calls_counts}")
        else:
            print(f"  {key}: {value:.2f}秒")

    # Compute the percentage difference
    diff_percent = ((results_with_tools["平均耗时"] / results_without_tools["平均耗时"]) - 1) * 100
    print(f"\n工具调用比普通请求平均耗时相差: {diff_percent:.2f}%")

    # Save the results to a JSON file
    results = {
        "测试时间": time.strftime("%Y-%m-%d %H:%M:%S"),
        "测试迭代次数": iterations,
        "不使用工具调用": {
            k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v])
            for k, v in results_without_tools.items()
            if k != "响应结果"
        },
        "不使用工具调用_详细响应": [
            {
                "内容摘要": resp["内容"][:200] + "..." if len(resp["内容"]) > 200 else resp["内容"],
-                "推理内容摘要": resp["推理内容"][:200] + "..." if len(resp["推理内容"]) > 200 else resp["推理内容"]
-            } for resp in results_without_tools["响应结果"]
+                "推理内容摘要": resp["推理内容"][:200] + "..." if len(resp["推理内容"]) > 200 else resp["推理内容"],
+            }
+            for resp in results_without_tools["响应结果"]
        ],
        "使用工具调用": {
            k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v])
            for k, v in results_with_tools.items()
            if k != "响应结果"
        },
        "使用工具调用_详细响应": [
@@ -270,20 +265,20 @@ async def main():
                "推理内容摘要": resp["推理内容"][:200] + "..." if len(resp["推理内容"]) > 200 else resp["推理内容"],
                "工具调用数量": len(resp["工具调用"]),
                "工具调用详情": [
-                    {
-                        "工具名称": tool["function"]["name"],
-                        "参数": tool["function"].get("arguments", {})
-                    } for tool in resp["工具调用"]
-                ]
-            } for resp in results_with_tools["响应结果"]
+                    {"工具名称": tool["function"]["name"], "参数": tool["function"].get("arguments", {})}
+                    for tool in resp["工具调用"]
+                ],
+            }
+            for resp in results_with_tools["响应结果"]
        ],
-        "差异百分比": float(f"{diff_percent:.2f}")
+        "差异百分比": float(f"{diff_percent:.2f}"),
    }

    with open("llm_tool_benchmark_results.json", "w", encoding="utf-8") as f:
        json.dump(results, f, ensure_ascii=False, indent=2)

-    print(f"\n测试结果已保存到 llm_tool_benchmark_results.json")
+    print("\n测试结果已保存到 llm_tool_benchmark_results.json")


if __name__ == "__main__":
    asyncio.run(main())
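run_test aggregates per-iteration wall-clock times with the standard statistics module; the derived numbers are plain mean/min/max/stdev. A tiny standalone check of that aggregation (the sample timings are invented):

# The same aggregation run_test performs, on made-up timings.
import statistics

times = [1.91, 2.04, 1.87]
summary = {
    "mean": statistics.mean(times),
    "min": min(times),
    "max": max(times),
    "stdev": statistics.stdev(times) if len(times) > 1 else 0,  # stdev needs >= 2 samples
}
assert round(summary["mean"], 2) == 1.94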