From 30756644806c4cafe5edb7e3dcabd321915d3b26 Mon Sep 17 00:00:00 2001 From: SengokuCola <1026294844@qq.com> Date: Thu, 24 Apr 2025 14:19:26 +0800 Subject: [PATCH] fix:FFUF --- src/do_tool/tool_use.py | 4 +- src/heart_flow/observation.py | 14 +- src/heart_flow/sub_heartflow.py | 68 ++++----- src/heart_flow/subheartflow_manager.py | 6 +- src/plugins/chat/bot.py | 1 - src/plugins/chat/utils.py | 4 +- src/plugins/heartFC_chat/heartFC_chat.py | 38 +++-- src/plugins/heartFC_chat/heartFC_generator.py | 8 +- .../heartFC_chat/heartflow_prompt_builder.py | 13 +- src/plugins/models/utils_model.py | 26 ++-- src/plugins/utils/chat_message_builder.py | 4 +- src/plugins/utils/json_utils.py | 122 +++++++-------- tool_call_benchmark.py | 141 +++++++++--------- 13 files changed, 224 insertions(+), 225 deletions(-) diff --git a/src/do_tool/tool_use.py b/src/do_tool/tool_use.py index 1f625a586..8087cedab 100644 --- a/src/do_tool/tool_use.py +++ b/src/do_tool/tool_use.py @@ -159,7 +159,9 @@ class ToolUser: tool_calls_str = "" for tool_call in tool_calls: tool_calls_str += f"{tool_call['function']['name']}\n" - logger.info(f"根据:\n{prompt}\n\n内容:{content}\n\n模型请求调用{len(tool_calls)}个工具: {tool_calls_str}") + logger.info( + f"根据:\n{prompt}\n\n内容:{content}\n\n模型请求调用{len(tool_calls)}个工具: {tool_calls_str}" + ) tool_results = [] structured_info = {} # 动态生成键 diff --git a/src/heart_flow/observation.py b/src/heart_flow/observation.py index 0f61f6082..9391a660a 100644 --- a/src/heart_flow/observation.py +++ b/src/heart_flow/observation.py @@ -82,29 +82,25 @@ class ChattingObservation(Observation): new_messages_list = get_raw_msg_by_timestamp_with_chat( chat_id=self.chat_id, timestamp_start=self.last_observe_time, - timestamp_end=datetime.now().timestamp(), + timestamp_end=datetime.now().timestamp(), limit=self.max_now_obs_len, limit_mode="latest", ) - + last_obs_time_mark = self.last_observe_time if new_messages_list: self.last_observe_time = new_messages_list[-1]["time"] self.talking_message.extend(new_messages_list) - if len(self.talking_message) > self.max_now_obs_len: # 计算需要移除的消息数量,保留最新的 max_now_obs_len 条 messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len oldest_messages = self.talking_message[:messages_to_remove_count] self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的 - + oldest_messages_str = await build_readable_messages( - messages=oldest_messages, - timestamp_mode="normal", - read_mark=0 + messages=oldest_messages, timestamp_mode="normal", read_mark=0 ) - # 调用 LLM 总结主题 prompt = ( @@ -145,7 +141,7 @@ class ChattingObservation(Observation): messages=self.talking_message, timestamp_mode="normal", read_mark=last_obs_time_mark, - ) + ) logger.trace( f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}" diff --git a/src/heart_flow/sub_heartflow.py b/src/heart_flow/sub_heartflow.py index f0a448866..1aa6f9027 100644 --- a/src/heart_flow/sub_heartflow.py +++ b/src/heart_flow/sub_heartflow.py @@ -6,12 +6,10 @@ from src.config.config import global_config import time from typing import Optional, List, Dict, Callable import traceback -from src.plugins.chat.utils import parse_text_timestamps import enum from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402 from src.individuality.individuality import Individuality import random -from src.plugins.person_info.relationship_manager import relationship_manager from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager from 
src.plugins.chat.message import MessageRecv from src.plugins.chat.chat_stream import chat_manager @@ -20,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_chat import HeartFChatting from src.plugins.heartFC_chat.normal_chat import NormalChat from src.do_tool.tool_use import ToolUser from src.heart_flow.mai_state_manager import MaiStateInfo -from src.plugins.utils.json_utils import safe_json_dumps, process_llm_tool_response, normalize_llm_response, process_llm_tool_calls +from src.plugins.utils.json_utils import safe_json_dumps, normalize_llm_response, process_llm_tool_calls # 定义常量 (从 interest.py 移动过来) MAX_INTEREST = 15.0 @@ -114,8 +112,6 @@ class InterestChatting: self.above_threshold = False self.start_hfc_probability = 0.0 - - def add_interest_dict(self, message: MessageRecv, interest_value: float, is_mentioned: bool): self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned) @@ -293,7 +289,7 @@ class SubHeartflow: ) self.log_prefix = chat_manager.get_stream_name(self.subheartflow_id) or self.subheartflow_id - + self.structured_info = {} async def add_time_current_state(self, add_time: float): @@ -484,36 +480,36 @@ class SubHeartflow: async def do_thinking_before_reply(self): """ 在回复前进行思考,生成内心想法并收集工具调用结果 - + 返回: tuple: (current_mind, past_mind) 当前想法和过去的想法列表 """ # 更新活跃时间 self.last_active_time = time.time() - + # ---------- 1. 准备基础数据 ---------- # 获取现有想法和情绪状态 current_thinking_info = self.current_mind mood_info = self.chat_state.mood - + # 获取观察对象 observation = self._get_primary_observation() if not observation: logger.error(f"[{self.subheartflow_id}] 无法获取观察对象") self.update_current_mind("(我没看到任何聊天内容...)") return self.current_mind, self.past_mind - + # 获取观察内容 chat_observe_info = observation.get_observe_info() - + # ---------- 2. 准备工具和个性化数据 ---------- # 初始化工具 tool_instance = ToolUser() tools = tool_instance._define_tools() - + # 获取个性化信息 individuality = Individuality.get_instance() - + # 构建个性部分 prompt_personality = f"你的名字是{individuality.personality.bot_nickname},你" prompt_personality += individuality.personality.personality_core @@ -547,9 +543,7 @@ class SubHeartflow: # 加权随机选择思考指导 hf_do_next = local_random.choices( - [option[0] for option in hf_options], - weights=[option[1] for option in hf_options], - k=1 + [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1 )[0] # ---------- 4. 构建最终提示词 ---------- @@ -570,16 +564,16 @@ class SubHeartflow: # ---------- 5. 
执行LLM请求并处理响应 ---------- content = "" # 初始化内容变量 reasoning_content = "" # 初始化推理内容变量 - + try: # 调用LLM生成响应 response = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools) - + # 标准化响应格式 success, normalized_response, error_msg = normalize_llm_response( response, log_prefix=f"[{self.subheartflow_id}] " ) - + if not success: # 处理标准化失败情况 logger.warning(f"[{self.subheartflow_id}] {error_msg}") @@ -588,23 +582,24 @@ class SubHeartflow: # 从标准化响应中提取内容 if len(normalized_response) >= 2: content = normalized_response[0] - reasoning_content = normalized_response[1] if len(normalized_response) > 1 else "" - + _reasoning_content = normalized_response[1] if len(normalized_response) > 1 else "" + # 处理可能的工具调用 if len(normalized_response) == 3: # 提取并验证工具调用 success, valid_tool_calls, error_msg = process_llm_tool_calls( normalized_response, log_prefix=f"[{self.subheartflow_id}] " ) - + if success and valid_tool_calls: # 记录工具调用信息 - tool_calls_str = ", ".join([ - call.get("function", {}).get("name", "未知工具") - for call in valid_tool_calls - ]) - logger.info(f"[{self.subheartflow_id}] 模型请求调用{len(valid_tool_calls)}个工具: {tool_calls_str}") - + tool_calls_str = ", ".join( + [call.get("function", {}).get("name", "未知工具") for call in valid_tool_calls] + ) + logger.info( + f"[{self.subheartflow_id}] 模型请求调用{len(valid_tool_calls)}个工具: {tool_calls_str}" + ) + # 收集工具执行结果 await self._execute_tool_calls(valid_tool_calls, tool_instance) elif not success: @@ -628,37 +623,34 @@ class SubHeartflow: self.update_current_mind(content) return self.current_mind, self.past_mind - + async def _execute_tool_calls(self, tool_calls, tool_instance): """ 执行一组工具调用并收集结果 - + 参数: tool_calls: 工具调用列表 tool_instance: 工具使用器实例 """ tool_results = [] structured_info = {} # 动态生成键 - + # 执行所有工具调用 for tool_call in tool_calls: try: result = await tool_instance._execute_tool_call(tool_call) if result: tool_results.append(result) - + # 使用工具名称作为键 tool_name = result["name"] if tool_name not in structured_info: structured_info[tool_name] = [] - - structured_info[tool_name].append({ - "name": result["name"], - "content": result["content"] - }) + + structured_info[tool_name].append({"name": result["name"], "content": result["content"]}) except Exception as tool_e: logger.error(f"[{self.subheartflow_id}] 工具执行失败: {tool_e}") - + # 如果有工具结果,记录并更新结构化信息 if structured_info: logger.debug(f"工具调用收集到结构化信息: {safe_json_dumps(structured_info, ensure_ascii=False)}") diff --git a/src/heart_flow/subheartflow_manager.py b/src/heart_flow/subheartflow_manager.py index 1e64027cb..bf473b781 100644 --- a/src/heart_flow/subheartflow_manager.py +++ b/src/heart_flow/subheartflow_manager.py @@ -290,9 +290,9 @@ class SubHeartflowManager: log_prefix_flow = f"[{stream_name}]" # 只处理 CHAT 状态的子心流 -# The code snippet is checking if the `chat_status` attribute of `sub_hf.chat_state` is not equal to -# `ChatState.CHAT`. If the condition is met, the code will continue to the next iteration of the loop -# or block of code where this snippet is located. + # The code snippet is checking if the `chat_status` attribute of `sub_hf.chat_state` is not equal to + # `ChatState.CHAT`. If the condition is met, the code will continue to the next iteration of the loop + # or block of code where this snippet is located. 
# if sub_hf.chat_state.chat_status != ChatState.CHAT: # continue diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 5c1ce6f81..b6584dcd3 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -78,7 +78,6 @@ class ChatBot: groupinfo = message.message_info.group_info userinfo = message.message_info.user_info - if userinfo.user_id in global_config.ban_user_id: logger.debug(f"用户{userinfo.user_id}被禁止回复") return diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 386d6ac7a..aed0025b8 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -328,7 +328,9 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: final_sentences = [content for content, sep in merged_segments if content] # 只保留有内容的段 # 清理可能引入的空字符串和仅包含空白的字符串 - final_sentences = [s for s in final_sentences if s.strip()] # 过滤掉空字符串以及仅包含空白(如换行符、空格)的字符串 + final_sentences = [ + s for s in final_sentences if s.strip() + ] # 过滤掉空字符串以及仅包含空白(如换行符、空格)的字符串 logger.debug(f"分割并合并后的句子: {final_sentences}") return final_sentences diff --git a/src/plugins/heartFC_chat/heartFC_chat.py b/src/plugins/heartFC_chat/heartFC_chat.py index 494ddeb09..41ea2711b 100644 --- a/src/plugins/heartFC_chat/heartFC_chat.py +++ b/src/plugins/heartFC_chat/heartFC_chat.py @@ -2,6 +2,7 @@ import asyncio import time import traceback from typing import List, Optional, Dict, Any, TYPE_CHECKING + # import json # 移除,因为使用了json_utils from src.plugins.chat.message import MessageRecv, BaseMessageInfo, MessageThinking, MessageSending from src.plugins.chat.message import MessageSet, Seg # Local import needed after move @@ -17,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator from src.do_tool.tool_use import ToolUser from ..chat.message_sender import message_manager # <-- Import the global manager from src.plugins.chat.emoji_manager import emoji_manager -from src.plugins.utils.json_utils import extract_tool_call_arguments, safe_json_dumps, process_llm_tool_response # 导入新的JSON工具 +from src.plugins.utils.json_utils import process_llm_tool_response # 导入新的JSON工具 # --- End import --- @@ -37,7 +38,7 @@ if TYPE_CHECKING: # Keep this if HeartFCController methods are still needed elsewhere, # but the instance variable will be removed from HeartFChatting # from .heartFC_controler import HeartFCController - from src.heart_flow.heartflow import SubHeartflow, heartflow # <-- 同时导入 heartflow 实例用于类型检查 + from src.heart_flow.heartflow import SubHeartflow # <-- 同时导入 heartflow 实例用于类型检查 PLANNER_TOOL_DEFINITION = [ { @@ -327,7 +328,6 @@ class HeartFChatting: with Timer("Wait New Msg", cycle_timers): # <--- Start Wait timer wait_start_time = time.monotonic() while True: - # 检查是否有新消息 has_new = await observation.has_new_messages_since(planner_start_db_time) if has_new: @@ -424,7 +424,7 @@ class HeartFChatting: observed_messages: List[dict] = [] current_mind: Optional[str] = None - llm_error = False + llm_error = False try: observation = self.sub_hf._get_primary_observation() @@ -434,19 +434,17 @@ class HeartFChatting: except Exception as e: logger.error(f"{log_prefix}[Planner] 获取观察信息时出错: {e}") - try: current_mind, _past_mind = await self.sub_hf.do_thinking_before_reply() except Exception as e_subhf: logger.error(f"{log_prefix}[Planner] SubHeartflow 思考失败: {e_subhf}") current_mind = "[思考时出错]" - # --- 使用 LLM 进行决策 --- # action = "no_reply" # 默认动作 - emoji_query = "" # 默认表情查询 - reasoning = "默认决策或获取决策失败" - llm_error = False # LLM错误标志 + emoji_query = "" # 默认表情查询 + reasoning = "默认决策或获取决策失败" + llm_error = 
False # LLM错误标志 try: prompt = await self._build_planner_prompt(observed_messages_str, current_mind, self.sub_hf.structured_info) @@ -475,21 +473,17 @@ class HeartFChatting: # 使用辅助函数处理工具调用响应 success, arguments, error_msg = process_llm_tool_response( - response, - expected_tool_name="decide_reply_action", - log_prefix=f"{log_prefix}[Planner] " + response, expected_tool_name="decide_reply_action", log_prefix=f"{log_prefix}[Planner] " ) - + if success: # 提取决策参数 action = arguments.get("action", "no_reply") reasoning = arguments.get("reasoning", "未提供理由") emoji_query = arguments.get("emoji_query", "") - + # 记录决策结果 - logger.debug( - f"{log_prefix}[Planner] 决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'" - ) + logger.debug(f"{log_prefix}[Planner] 决策结果: {action}, 理由: {reasoning}, 表情查询: '{emoji_query}'") else: # 处理工具调用失败 logger.warning(f"{log_prefix}[Planner] {error_msg}") @@ -584,7 +578,7 @@ class HeartFChatting: """优雅关闭HeartFChatting实例,取消活动循环任务""" log_prefix = self._get_log_prefix() logger.info(f"{log_prefix} 正在关闭HeartFChatting...") - + # 取消循环任务 if self._loop_task and not self._loop_task.done(): logger.info(f"{log_prefix} 正在取消HeartFChatting循环任务") @@ -605,17 +599,19 @@ class HeartFChatting: if self._processing_lock.locked(): self._processing_lock.release() logger.warning(f"{log_prefix} 已释放处理锁") - + logger.info(f"{log_prefix} HeartFChatting关闭完成") - async def _build_planner_prompt(self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any]) -> str: + async def _build_planner_prompt( + self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any] + ) -> str: """构建 Planner LLM 的提示词""" prompt = f"你的名字是 {global_config.BOT_NICKNAME}。是一个女大二学生,正在QQ聊天,正在决定是否以及如何回应当前的聊天。\n" if structured_info: prompt += f"以下是一些额外的信息:\n{structured_info}\n" - + if observed_messages_str: prompt += "观察到的最新聊天内容如下 (最近的消息在最后):\n---\n" prompt += observed_messages_str diff --git a/src/plugins/heartFC_chat/heartFC_generator.py b/src/plugins/heartFC_chat/heartFC_generator.py index 0ed6229e6..cbf050bd9 100644 --- a/src/plugins/heartFC_chat/heartFC_generator.py +++ b/src/plugins/heartFC_chat/heartFC_generator.py @@ -72,7 +72,13 @@ class HeartFCGenerator: return None async def _generate_response_with_model( - self, structured_info: str, current_mind_info: str, reason: str, message: MessageRecv, model: LLMRequest, thinking_id: str + self, + structured_info: str, + current_mind_info: str, + reason: str, + message: MessageRecv, + model: LLMRequest, + thinking_id: str, ) -> str: sender_name = "" diff --git a/src/plugins/heartFC_chat/heartflow_prompt_builder.py b/src/plugins/heartFC_chat/heartflow_prompt_builder.py index 33baad371..c5b04ed93 100644 --- a/src/plugins/heartFC_chat/heartflow_prompt_builder.py +++ b/src/plugins/heartFC_chat/heartflow_prompt_builder.py @@ -81,13 +81,22 @@ class PromptBuilder: self.activate_messages = "" async def build_prompt( - self, build_mode, reason, current_mind_info, structured_info, message_txt: str, sender_name: str = "某人", chat_stream=None + self, + build_mode, + reason, + current_mind_info, + structured_info, + message_txt: str, + sender_name: str = "某人", + chat_stream=None, ) -> Optional[tuple[str, str]]: if build_mode == "normal": return await self._build_prompt_normal(chat_stream, message_txt, sender_name) elif build_mode == "focus": - return await self._build_prompt_focus(reason, current_mind_info, structured_info, chat_stream, message_txt, sender_name) + return await self._build_prompt_focus( + reason, current_mind_info, 
structured_info, chat_stream, message_txt, sender_name + ) return None async def _build_prompt_focus( diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index bdc408aba..2cab7b629 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -711,7 +711,7 @@ class LLMRequest: reasoning_content = "" content = "" tool_calls = None # 初始化工具调用变量 - + async for line_bytes in response.content: try: line = line_bytes.decode("utf-8").strip() @@ -733,7 +733,7 @@ if delta_content is None: delta_content = "" accumulated_content += delta_content - + # 提取工具调用信息 if "tool_calls" in delta: if tool_calls is None: @@ -741,7 +741,7 @@ else: # 合并工具调用信息 tool_calls.extend(delta["tool_calls"]) - + # 检测流式输出文本是否结束 finish_reason = chunk["choices"][0].get("finish_reason") if delta.get("reasoning_content", None): @@ -774,23 +774,19 @@ if think_match: reasoning_content = think_match.group(1).strip() content = re.sub(r"<think>.*?</think>", "", content, flags=re.DOTALL).strip() - + # 构建消息对象 message = { "content": content, "reasoning_content": reasoning_content, } - + # 如果有工具调用,添加到消息中 if tool_calls: message["tool_calls"] = tool_calls - + result = { - "choices": [ - { - "message": message - } - ], + "choices": [{"message": message}], "usage": usage, } return result @@ -1128,9 +1124,9 @@ class LLMRequest: response = await self._execute_request(endpoint="/chat/completions", payload=data, prompt=prompt) # 原样返回响应,不做处理 - + return response - + async def generate_response_tool_async(self, prompt: str, tools: list, **kwargs) -> Union[str, Tuple]: """异步方式根据输入的提示生成模型的响应""" # 构建请求体,不硬编码max_tokens @@ -1139,7 +1135,7 @@ "messages": [{"role": "user", "content": prompt}], **self.params, **kwargs, - "tools": tools + "tools": tools, } logger.debug(f"向模型 {self.model_name} 发送工具调用请求,包含 {len(tools)} 个工具") @@ -1150,7 +1146,7 @@ logger.debug(f"收到工具调用响应,包含 {len(tool_calls) if tool_calls else 0} 个工具调用") return content, reasoning_content, tool_calls else: - logger.debug(f"收到普通响应,无工具调用") + logger.debug("收到普通响应,无工具调用") return response async def get_embedding(self, text: str) -> Union[list, None]: diff --git a/src/plugins/utils/chat_message_builder.py b/src/plugins/utils/chat_message_builder.py index 6a5e4e8e1..6ae6ccc32 100644 --- a/src/plugins/utils/chat_message_builder.py +++ b/src/plugins/utils/chat_message_builder.py @@ -303,7 +303,9 @@ async def build_readable_messages( ) readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode) - read_mark_line = f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n" + read_mark_line = ( + f"\n\n--- 以上消息已读 (标记时间: {readable_read_mark}) ---\n--- 请关注你上次思考之后以下的新消息---\n" + ) # 组合结果,确保空部分不引入多余的标记或换行 if formatted_before and formatted_after: diff --git a/src/plugins/utils/json_utils.py b/src/plugins/utils/json_utils.py index 962901b55..bf4b08398 100644 --- a/src/plugins/utils/json_utils.py +++ b/src/plugins/utils/json_utils.py @@ -1,27 +1,28 @@ import json import logging -from typing import Any, Dict, Optional, TypeVar, Generic, List, Union, Callable, Tuple +from typing import Any, Dict, TypeVar, List, Union, Callable, Tuple # 定义类型变量用于泛型类型提示 -T = TypeVar('T') +T = TypeVar("T") # 获取logger logger = logging.getLogger("json_utils") + def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: """ 安全地解析JSON字符串,出错时返回默认值 - + 参数: json_str: 要解析的JSON字符串 default_value: 解析失败时返回的默认值 - + 返回: 解析后的Python对象,或在解析失败时返回default_value
""" if not json_str: return default_value - + try: return json.loads(json_str) except json.JSONDecodeError as e: @@ -31,66 +32,67 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]: logger.error(f"JSON解析过程中发生意外错误: {e}") return default_value -def extract_tool_call_arguments(tool_call: Dict[str, Any], - default_value: Dict[str, Any] = None) -> Dict[str, Any]: + +def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]: """ 从LLM工具调用对象中提取参数 - + 参数: tool_call: 工具调用对象字典 default_value: 解析失败时返回的默认值 - + 返回: 解析后的参数字典,或在解析失败时返回default_value """ default_result = default_value or {} - + if not tool_call or not isinstance(tool_call, dict): logger.error(f"无效的工具调用对象: {tool_call}") return default_result - + try: # 提取function参数 function_data = tool_call.get("function", {}) if not function_data or not isinstance(function_data, dict): logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}") return default_result - + # 提取arguments arguments_str = function_data.get("arguments", "{}") if not arguments_str: return default_result - + # 解析JSON return safe_json_loads(arguments_str, default_result) - + except Exception as e: logger.error(f"提取工具调用参数时出错: {e}") return default_result -def get_json_value(json_obj: Dict[str, Any], key_path: str, - default_value: T = None, - transform_func: Callable[[Any], T] = None) -> Union[Any, T]: + +def get_json_value( + json_obj: Dict[str, Any], key_path: str, default_value: T = None, transform_func: Callable[[Any], T] = None +) -> Union[Any, T]: """ 从JSON对象中按照路径提取值,支持点表示法路径,如"data.items.0.name" - + 参数: json_obj: JSON对象(已解析的字典) key_path: 键路径,使用点表示法,如"data.items.0.name" default_value: 获取失败时返回的默认值 transform_func: 可选的转换函数,用于对获取的值进行转换 - + 返回: 路径指向的值,或在获取失败时返回default_value """ if not json_obj or not key_path: return default_value - + try: # 分割路径 keys = key_path.split(".") current = json_obj - + # 遍历路径 for key in keys: # 处理数组索引 @@ -108,7 +110,7 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str, return default_value else: return default_value - + # 应用转换函数(如果提供) if transform_func and current is not None: return transform_func(current) @@ -117,17 +119,17 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str, logger.error(f"从JSON获取值时出错: {e}, 路径: {key_path}") return default_value -def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, - pretty: bool = False) -> str: + +def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str: """ 安全地将Python对象序列化为JSON字符串 - + 参数: obj: 要序列化的Python对象 default_value: 序列化失败时返回的默认值 ensure_ascii: 是否确保ASCII编码(默认False,允许中文等非ASCII字符) pretty: 是否美化输出JSON - + 返回: 序列化后的JSON字符串,或在序列化失败时返回default_value """ @@ -141,13 +143,14 @@ def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = Fa logger.error(f"JSON序列化过程中发生意外错误: {e}") return default_value + def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]: """ 合并多个JSON对象(字典) - + 参数: *objects: 要合并的JSON对象(字典) - + 返回: 合并后的字典,后面的对象会覆盖前面对象的相同键 """ @@ -157,109 +160,110 @@ def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]: result.update(obj) return result + def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]: """ 标准化LLM响应格式,将各种格式(如元组)转换为统一的列表格式 - + 参数: response: 原始LLM响应 log_prefix: 日志前缀 - + 返回: 元组 (成功标志, 标准化后的响应列表, 错误消息) """ # 检查是否为None if response is None: return False, [], "LLM响应为None" - + # 记录原始类型 logger.debug(f"{log_prefix}LLM响应原始类型: 
{type(response).__name__}") - + # 将元组转换为列表 if isinstance(response, tuple): logger.debug(f"{log_prefix}将元组响应转换为列表") response = list(response) - + # 确保是列表类型 if not isinstance(response, list): return False, [], f"无法处理的LLM响应类型: {type(response).__name__}" - + # 处理工具调用部分(如果存在) if len(response) == 3: content, reasoning, tool_calls = response - + # 将工具调用部分转换为列表(如果是元组) if isinstance(tool_calls, tuple): logger.debug(f"{log_prefix}将工具调用元组转换为列表") tool_calls = list(tool_calls) response[2] = tool_calls - + return True, response, "" + def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]: """ 处理并提取LLM响应中的工具调用列表 - + 参数: response: 标准化后的LLM响应列表 log_prefix: 日志前缀 - + 返回: 元组 (成功标志, 工具调用列表, 错误消息) """ # 确保响应格式正确 if len(response) != 3: return False, [], f"LLM响应元素数量不正确: 预期3个元素,实际{len(response)}个" - + # 提取工具调用部分 tool_calls = response[2] - + # 检查工具调用是否有效 if tool_calls is None: return False, [], "工具调用部分为None" - + if not isinstance(tool_calls, list): return False, [], f"工具调用部分不是列表: {type(tool_calls).__name__}" - + if len(tool_calls) == 0: return False, [], "工具调用列表为空" - + # 检查工具调用是否格式正确 valid_tool_calls = [] for i, tool_call in enumerate(tool_calls): if not isinstance(tool_call, dict): logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}") continue - + if tool_call.get("type") != "function": logger.warning(f"{log_prefix}工具调用[{i}]不是函数类型: {tool_call.get('type', '未知')}") continue - + if "function" not in tool_call or not isinstance(tool_call["function"], dict): logger.warning(f"{log_prefix}工具调用[{i}]缺少function字段或格式不正确") continue - + valid_tool_calls.append(tool_call) - + # 检查是否有有效的工具调用 if not valid_tool_calls: return False, [], "没有找到有效的工具调用" - + return True, valid_tool_calls, "" + def process_llm_tool_response( - response: Any, - expected_tool_name: str = None, - log_prefix: str = "" + response: Any, expected_tool_name: str = None, log_prefix: str = "" ) -> Tuple[bool, Dict[str, Any], str]: """ 处理LLM返回的工具调用响应,进行常见错误检查并提取参数 - + 参数: response: LLM的响应,预期是[content, reasoning, tool_calls]格式的列表或元组 expected_tool_name: 预期的工具名称,如不指定则不检查 log_prefix: 日志前缀,用于标识日志来源 - + 返回: 三元组(成功标志, 参数字典, 错误描述) - 如果成功解析,返回(True, 参数字典, "") @@ -269,29 +273,29 @@ def process_llm_tool_response( success, normalized_response, error_msg = normalize_llm_response(response, log_prefix) if not success: return False, {}, error_msg - + # 使用新的工具调用处理函数 success, valid_tool_calls, error_msg = process_llm_tool_calls(normalized_response, log_prefix) if not success: return False, {}, error_msg - + # 检查是否有工具调用 if not valid_tool_calls: return False, {}, "没有有效的工具调用" - + # 获取第一个工具调用 tool_call = valid_tool_calls[0] - + # 检查工具名称(如果提供了预期名称) if expected_tool_name: actual_name = tool_call.get("function", {}).get("name") if actual_name != expected_tool_name: return False, {}, f"工具名称不匹配: 预期'{expected_tool_name}',实际'{actual_name}'" - + # 提取并解析参数 try: arguments = extract_tool_call_arguments(tool_call, {}) return True, arguments, "" except Exception as e: logger.error(f"{log_prefix}解析工具参数时出错: {e}") - return False, {}, f"解析参数失败: {str(e)}" \ No newline at end of file + return False, {}, f"解析参数失败: {str(e)}" diff --git a/tool_call_benchmark.py b/tool_call_benchmark.py index 691aeb7c5..e756d1da3 100644 --- a/tool_call_benchmark.py +++ b/tool_call_benchmark.py @@ -6,24 +6,25 @@ from src.do_tool.tool_use import ToolUser import statistics import json + async def run_test(test_name, test_function, iterations=5): """ 运行指定次数的测试并计算平均响应时间 - + 参数: test_name: 测试名称 test_function: 要执行的测试函数 iterations: 测试迭代次数 - + 返回: 测试结果统计 """ 
print(f"开始 {test_name} 测试({iterations}次迭代)...") times = [] responses = [] - + for i in range(iterations): - print(f" 运行第 {i+1}/{iterations} 次测试...") + print(f" 运行第 {i + 1}/{iterations} 次测试...") start_time = time.time() response = await test_function() end_time = time.time() @@ -31,18 +32,19 @@ async def run_test(test_name, test_function, iterations=5): times.append(elapsed) responses.append(response) print(f" - 耗时: {elapsed:.2f}秒") - + results = { "平均耗时": statistics.mean(times), "最短耗时": min(times), "最长耗时": max(times), "标准差": statistics.stdev(times) if len(times) > 1 else 0, "所有耗时": times, - "响应结果": responses + "响应结果": responses, } - + return results + async def test_with_tool_calls(): """使用工具调用的LLM请求测试""" # 创建LLM模型实例 @@ -53,14 +55,14 @@ async def test_with_tool_calls(): max_tokens=800, request_type="benchmark_test", ) - + # 创建工具实例 tool_instance = ToolUser() tools = tool_instance._define_tools() - + # 简单的测试提示词 prompt = "请分析当前天气情况,并查询今日历史上的重要事件。并且3.9和3.11谁比较大?请使用适当的工具来获取这些信息。" - prompt = ''' + prompt = """ 你的名字是麦麦,你包容开放,情绪敏感,有时候有些搞怪幽默, 是一个学习心理学和脑科学的女大学生,现在在读大二,你会刷贴吧,有时候会想瑟瑟,喜欢刷小红书 ----------------------------------- 现在是2025-04-24 12:37:00,你正在上网,和qq群里的网友们聊天,群里正在聊的话题是: @@ -89,52 +91,47 @@ async def test_with_tool_calls(): 回复的要求是:平淡一些,简短一些,说中文,如果你要回复,最好只回复一个人的一个话题 请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写。不要回复自己的发言,尽量不要说你说过的话。 现在请你继续生成你在这个聊天中的想法,在原来想法的基础上继续思考,不要分点输出,生成内心想法,文字不要浮夸 -在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。''' - +在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。""" + # 发送带有工具调用的请求 response = await llm_model.generate_response_tool_async(prompt=prompt, tools=tools) - + result_info = {} - + # 简单处理工具调用结果 if len(response) == 3: content, reasoning_content, tool_calls = response tool_calls_count = len(tool_calls) if tool_calls else 0 print(f" 工具调用请求生成了 {tool_calls_count} 个工具调用") - + # 输出内容和工具调用详情 print("\n 生成的内容:") print(f" {content[:200]}..." if len(content) > 200 else f" {content}") - + if tool_calls: print("\n 工具调用详情:") for i, tool_call in enumerate(tool_calls): - tool_name = tool_call['function']['name'] - tool_params = tool_call['function'].get('arguments', {}) - print(f" - 工具 {i+1}: {tool_name}") - print(f" 参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..." - if len(json.dumps(tool_params, ensure_ascii=False)) > 100 - else f" 参数: {json.dumps(tool_params, ensure_ascii=False)}") - - result_info = { - "内容": content, - "推理内容": reasoning_content, - "工具调用": tool_calls - } + tool_name = tool_call["function"]["name"] + tool_params = tool_call["function"].get("arguments", {}) + print(f" - 工具 {i + 1}: {tool_name}") + print( + f" 参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..." + if len(json.dumps(tool_params, ensure_ascii=False)) > 100 + else f" 参数: {json.dumps(tool_params, ensure_ascii=False)}" + ) + + result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": tool_calls} else: content, reasoning_content = response print(" 工具调用请求未生成任何工具调用") print("\n 生成的内容:") print(f" {content[:200]}..." 
if len(content) > 200 else f" {content}") - - result_info = { - "内容": content, - "推理内容": reasoning_content, - "工具调用": [] - } - + + result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": []} + return result_info + async def test_without_tool_calls(): """不使用工具调用的LLM请求测试""" # 创建LLM模型实例 @@ -144,9 +141,9 @@ async def test_without_tool_calls(): max_tokens=800, request_type="benchmark_test", ) - + # 简单的测试提示词(与工具调用相同,以便公平比较) - prompt = ''' + prompt = """ 你的名字是麦麦,你包容开放,情绪敏感,有时候有些搞怪幽默, 是一个学习心理学和脑科学的女大学生,现在在读大二,你会刷贴吧,有时候会想瑟瑟,喜欢刷小红书 刚刚你的想法是: 我是麦麦,我想,('小千石问3.8和3.11谁大,已经简单回答了3.11大,现在可以继续聊猫猫头表情包,毕竟大家好像对版本问题兴趣不大,而且猫猫头的话题更轻松有趣。', '') @@ -181,45 +178,42 @@ async def test_without_tool_calls(): 回复的要求是:平淡一些,简短一些,说中文,如果你要回复,最好只回复一个人的一个话题 请注意不要输出多余内容(包括前后缀,冒号和引号,括号, 表情,等),不要带有括号和动作描写。不要回复自己的发言,尽量不要说你说过的话。 现在请你继续生成你在这个聊天中的想法,在原来想法的基础上继续思考,不要分点输出,生成内心想法,文字不要浮夸 -在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。''' - +在输出完想法后,请你思考应该使用什么工具,如果你需要做某件事,来对消息和你的回复进行处理,请使用工具。""" + # 发送不带工具调用的请求 response, reasoning_content = await llm_model.generate_response_async(prompt) - + # 输出生成的内容 print("\n 生成的内容:") print(f" {response[:200]}..." if len(response) > 200 else f" {response}") - - result_info = { - "内容": response, - "推理内容": reasoning_content, - "工具调用": [] - } - + + result_info = {"内容": response, "推理内容": reasoning_content, "工具调用": []} + return result_info + async def main(): """主测试函数""" print("=" * 50) print("LLM工具调用与普通请求性能比较测试") print("=" * 50) - + # 设置测试迭代次数 iterations = 3 - + # 测试不使用工具调用 results_without_tools = await run_test("不使用工具调用", test_without_tool_calls, iterations) - + print("\n" + "-" * 50 + "\n") - + # 测试使用工具调用 results_with_tools = await run_test("使用工具调用", test_with_tool_calls, iterations) - + # 显示结果比较 print("\n" + "=" * 50) print("测试结果比较") print("=" * 50) - + print("\n不使用工具调用:") for key, value in results_without_tools.items(): if key == "所有耗时": @@ -228,7 +222,7 @@ async def main(): print(f" {key}: [内容已省略,详见结果文件]") else: print(f" {key}: {value:.2f}秒") - + print("\n使用工具调用:") for key, value in results_with_tools.items(): if key == "所有耗时": @@ -239,29 +233,30 @@ async def main(): print(f" 工具调用数量: {tool_calls_counts}") else: print(f" {key}: {value:.2f}秒") - + # 计算差异百分比 diff_percent = ((results_with_tools["平均耗时"] / results_without_tools["平均耗时"]) - 1) * 100 print(f"\n工具调用比普通请求平均耗时相差: {diff_percent:.2f}%") - + # 保存结果到JSON文件 results = { "测试时间": time.strftime("%Y-%m-%d %H:%M:%S"), "测试迭代次数": iterations, "不使用工具调用": { - k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v]) - for k, v in results_without_tools.items() + k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v]) + for k, v in results_without_tools.items() if k != "响应结果" }, "不使用工具调用_详细响应": [ { "内容摘要": resp["内容"][:200] + "..." if len(resp["内容"]) > 200 else resp["内容"], - "推理内容摘要": resp["推理内容"][:200] + "..." if len(resp["推理内容"]) > 200 else resp["推理内容"] - } for resp in results_without_tools["响应结果"] + "推理内容摘要": resp["推理内容"][:200] + "..." if len(resp["推理内容"]) > 200 else resp["推理内容"], + } + for resp in results_without_tools["响应结果"] ], "使用工具调用": { - k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v]) - for k, v in results_with_tools.items() + k: (v if k != "所有耗时" else [float(f"{t:.2f}") for t in v]) + for k, v in results_with_tools.items() if k != "响应结果" }, "使用工具调用_详细响应": [ @@ -270,20 +265,20 @@ async def main(): "推理内容摘要": resp["推理内容"][:200] + "..." 
if len(resp["推理内容"]) > 200 else resp["推理内容"], "工具调用数量": len(resp["工具调用"]), "工具调用详情": [ - { - "工具名称": tool["function"]["name"], - "参数": tool["function"].get("arguments", {}) - } for tool in resp["工具调用"] - ] - } for resp in results_with_tools["响应结果"] + {"工具名称": tool["function"]["name"], "参数": tool["function"].get("arguments", {})} + for tool in resp["工具调用"] + ], + } + for resp in results_with_tools["响应结果"] ], - "差异百分比": float(f"{diff_percent:.2f}") + "差异百分比": float(f"{diff_percent:.2f}"), } - + with open("llm_tool_benchmark_results.json", "w", encoding="utf-8") as f: json.dump(results, f, ensure_ascii=False, indent=2) - - print(f"\n测试结果已保存到 llm_tool_benchmark_results.json") + + print("\n测试结果已保存到 llm_tool_benchmark_results.json") + if __name__ == "__main__": - asyncio.run(main()) \ No newline at end of file + asyncio.run(main())
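A minimal usage sketch of the json_utils helpers this patch consolidates. The function names, signatures, and return shapes below are taken from the diff itself; the fake_response payload and the "[demo] " log prefix are hypothetical values invented purely for illustration:

    from src.plugins.utils.json_utils import (
        normalize_llm_response,
        process_llm_tool_response,
        safe_json_dumps,
    )

    # A hypothetical (content, reasoning_content, tool_calls) tuple, in the
    # shape generate_response_tool_async returns when the model calls a tool.
    fake_response = (
        "thinking about how to reply",
        "",
        [
            {
                "type": "function",
                "function": {
                    "name": "decide_reply_action",
                    "arguments": '{"action": "text_reply", "reasoning": "someone asked a question"}',
                },
            }
        ],
    )

    # Tuples are converted to lists and the overall response shape is validated.
    ok, normalized, err = normalize_llm_response(fake_response, log_prefix="[demo] ")
    assert ok, err

    # One call validates the tool call list, checks the expected tool name,
    # and safely parses the JSON arguments string into a dict.
    ok, arguments, err = process_llm_tool_response(
        fake_response, expected_tool_name="decide_reply_action", log_prefix="[demo] "
    )
    print(safe_json_dumps(arguments, ensure_ascii=False) if ok else err)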