fix:FFUF
@@ -159,7 +159,9 @@ class ToolUser:
         tool_calls_str = ""
         for tool_call in tool_calls:
             tool_calls_str += f"{tool_call['function']['name']}\n"
-        logger.info(f"Based on:\n{prompt}\n\nContent: {content}\n\nThe model requested {len(tool_calls)} tool call(s): {tool_calls_str}")
+        logger.info(
+            f"Based on:\n{prompt}\n\nContent: {content}\n\nThe model requested {len(tool_calls)} tool call(s): {tool_calls_str}"
+        )
         tool_results = []
         structured_info = {}  # keys are generated dynamically

@@ -92,7 +92,6 @@ class ChattingObservation(Observation):
         self.last_observe_time = new_messages_list[-1]["time"]
         self.talking_message.extend(new_messages_list)
-

         if len(self.talking_message) > self.max_now_obs_len:
             # Work out how many messages to remove, keeping the newest max_now_obs_len
             messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len

@@ -100,12 +99,9 @@ class ChattingObservation(Observation):
             self.talking_message = self.talking_message[messages_to_remove_count:]  # keep the later part, i.e. the newest

             oldest_messages_str = await build_readable_messages(
-                messages=oldest_messages,
-                timestamp_mode="normal",
-                read_mark=0
+                messages=oldest_messages, timestamp_mode="normal", read_mark=0
             )
-

             # Ask the LLM to summarize the topic
             prompt = (
                 f"Please summarize the topic of the following chat log:\n{oldest_messages_str}\nSum it up in one sentence covering the people, events and key information; do not use bullet points:"

@@ -145,7 +141,7 @@ class ChattingObservation(Observation):
                 messages=self.talking_message,
                 timestamp_mode="normal",
                 read_mark=last_obs_time_mark,
             )
         )

         logger.trace(
             f"Chat {self.chat_id} - compressed early memory: {self.mid_memory_info}\nCurrent chat content: {self.talking_message_str}"

@@ -6,12 +6,10 @@ from src.config.config import global_config
 import time
 from typing import Optional, List, Dict, Callable
 import traceback
 from src.plugins.chat.utils import parse_text_timestamps
 import enum
 from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG  # noqa: E402
 from src.individuality.individuality import Individuality
 import random
 from src.plugins.person_info.relationship_manager import relationship_manager
 from ..plugins.utils.prompt_builder import Prompt, global_prompt_manager
 from src.plugins.chat.message import MessageRecv
 from src.plugins.chat.chat_stream import chat_manager

@@ -20,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_chat import HeartFChatting
 from src.plugins.heartFC_chat.normal_chat import NormalChat
 from src.do_tool.tool_use import ToolUser
 from src.heart_flow.mai_state_manager import MaiStateInfo
-from src.plugins.utils.json_utils import safe_json_dumps, process_llm_tool_response, normalize_llm_response, process_llm_tool_calls
+from src.plugins.utils.json_utils import safe_json_dumps, normalize_llm_response, process_llm_tool_calls

 # Constants (moved over from interest.py)
 MAX_INTEREST = 15.0

@@ -115,8 +113,6 @@ class InterestChatting:
         self.above_threshold = False
         self.start_hfc_probability = 0.0
-
-

     def add_interest_dict(self, message: MessageRecv, interest_value: float, is_mentioned: bool):
         self.interest_dict[message.message_info.message_id] = (message, interest_value, is_mentioned)
         self.last_interaction_time = time.time()

@@ -547,9 +543,7 @@ class SubHeartflow:

         # Weighted random choice of the thinking guidance
         hf_do_next = local_random.choices(
-            [option[0] for option in hf_options],
-            weights=[option[1] for option in hf_options],
-            k=1
+            [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1
         )[0]

         # ---------- 4. Build the final prompt ----------

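The selection above is plain weighted sampling. A minimal standalone sketch of the same pattern (the options and weights below are invented for illustration, not taken from this commit):

    import random

    # Each option pairs a guidance string with its selection weight.
    hf_options = [("keep chatting", 0.6), ("change the topic", 0.3), ("stay quiet", 0.1)]

    local_random = random.Random(42)  # seeded so the sketch is reproducible
    hf_do_next = local_random.choices(
        [option[0] for option in hf_options], weights=[option[1] for option in hf_options], k=1
    )[0]
    print(hf_do_next)  # one of the three strings, drawn according to the weights
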
@@ -588,7 +582,7 @@ class SubHeartflow:
         # Extract the content from the normalized response
         if len(normalized_response) >= 2:
             content = normalized_response[0]
-            reasoning_content = normalized_response[1] if len(normalized_response) > 1 else ""
+            _reasoning_content = normalized_response[1] if len(normalized_response) > 1 else ""

             # Handle possible tool calls
             if len(normalized_response) == 3:

@@ -599,11 +593,12 @@ class SubHeartflow:

             if success and valid_tool_calls:
                 # Log the tool-call information
-                tool_calls_str = ", ".join([
-                    call.get("function", {}).get("name", "unknown tool")
-                    for call in valid_tool_calls
-                ])
-                logger.info(f"[{self.subheartflow_id}] The model requested {len(valid_tool_calls)} tool call(s): {tool_calls_str}")
+                tool_calls_str = ", ".join(
+                    [call.get("function", {}).get("name", "unknown tool") for call in valid_tool_calls]
+                )
+                logger.info(
+                    f"[{self.subheartflow_id}] The model requested {len(valid_tool_calls)} tool call(s): {tool_calls_str}"
+                )

                 # Collect the tool execution results
                 await self._execute_tool_calls(valid_tool_calls, tool_instance)

@@ -652,10 +647,7 @@ class SubHeartflow:
                 if tool_name not in structured_info:
                     structured_info[tool_name] = []

-                structured_info[tool_name].append({
-                    "name": result["name"],
-                    "content": result["content"]
-                })
+                structured_info[tool_name].append({"name": result["name"], "content": result["content"]})
             except Exception as tool_e:
                 logger.error(f"[{self.subheartflow_id}] Tool execution failed: {tool_e}")

@@ -290,9 +290,9 @@ class SubHeartflowManager:
             log_prefix_flow = f"[{stream_name}]"

             # Only process sub-heartflows in the CHAT state
-            # The code snippet is checking if the `chat_status` attribute of `sub_hf.chat_state` is not equal to
-            # `ChatState.CHAT`. If the condition is met, the code will continue to the next iteration of the loop
-            # or block of code where this snippet is located.
+            # The code snippet is checking if the `chat_status` attribute of `sub_hf.chat_state` is not equal to
+            # `ChatState.CHAT`. If the condition is met, the code will continue to the next iteration of the loop
+            # or block of code where this snippet is located.
             # if sub_hf.chat_state.chat_status != ChatState.CHAT:
             #     continue

@@ -78,7 +78,6 @@ class ChatBot:
         groupinfo = message.message_info.group_info
         userinfo = message.message_info.user_info
-

         if userinfo.user_id in global_config.ban_user_id:
             logger.debug(f"User {userinfo.user_id} is banned; not replying")
             return

@@ -328,7 +328,9 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
     final_sentences = [content for content, sep in merged_segments if content]  # keep only segments with content

     # Clean out empty strings and whitespace-only strings that may have been introduced
-    final_sentences = [s for s in final_sentences if s.strip()]  # filter out empty strings and whitespace-only ones (e.g. newlines, spaces)
+    final_sentences = [
+        s for s in final_sentences if s.strip()
+    ]  # filter out empty strings and whitespace-only ones (e.g. newlines, spaces)

     logger.debug(f"Sentences after splitting and merging: {final_sentences}")
     return final_sentences

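For illustration, the whitespace filter in isolation (the sample segments are made up):

    segments = ["hello", "", "  \n", "bye"]
    final_sentences = [s for s in segments if s.strip()]
    assert final_sentences == ["hello", "bye"]
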
@@ -2,6 +2,7 @@ import asyncio
 import time
 import traceback
 from typing import List, Optional, Dict, Any, TYPE_CHECKING

+# import json  # removed, json_utils is used instead
 from src.plugins.chat.message import MessageRecv, BaseMessageInfo, MessageThinking, MessageSending
 from src.plugins.chat.message import MessageSet, Seg  # Local import needed after move

@@ -17,7 +18,7 @@ from src.plugins.heartFC_chat.heartFC_generator import HeartFCGenerator
 from src.do_tool.tool_use import ToolUser
 from ..chat.message_sender import message_manager  # <-- Import the global manager
 from src.plugins.chat.emoji_manager import emoji_manager
-from src.plugins.utils.json_utils import extract_tool_call_arguments, safe_json_dumps, process_llm_tool_response  # import the new JSON utilities
+from src.plugins.utils.json_utils import process_llm_tool_response  # import the new JSON utilities

 # --- End import ---

@@ -37,7 +38,7 @@ if TYPE_CHECKING:
     # Keep this if HeartFCController methods are still needed elsewhere,
     # but the instance variable will be removed from HeartFChatting
     # from .heartFC_controler import HeartFCController
-    from src.heart_flow.heartflow import SubHeartflow, heartflow  # <-- also import the heartflow instance for type checking
+    from src.heart_flow.heartflow import SubHeartflow  # <-- also import the heartflow instance for type checking

 PLANNER_TOOL_DEFINITION = [
     {

@@ -327,7 +328,6 @@ class HeartFChatting:
             with Timer("Wait New Msg", cycle_timers):  # <--- Start Wait timer
                 wait_start_time = time.monotonic()
                 while True:
-
                     # Check whether there are new messages
                     has_new = await observation.has_new_messages_since(planner_start_db_time)
                     if has_new:

@@ -434,19 +434,17 @@ class HeartFChatting:
             except Exception as e:
                 logger.error(f"{log_prefix}[Planner] Error while fetching observation info: {e}")
-

             try:
                 current_mind, _past_mind = await self.sub_hf.do_thinking_before_reply()
             except Exception as e_subhf:
                 logger.error(f"{log_prefix}[Planner] SubHeartflow thinking failed: {e_subhf}")
                 current_mind = "[error while thinking]"
-

             # --- Decide with the LLM --- #
             action = "no_reply"  # default action
-            emoji_query = "" # default emoji query
+            emoji_query = ""  # default emoji query
             reasoning = "Default decision, or failed to obtain a decision"
-            llm_error = False # LLM error flag
+            llm_error = False  # LLM error flag

             try:
                 prompt = await self._build_planner_prompt(observed_messages_str, current_mind, self.sub_hf.structured_info)

@@ -475,9 +473,7 @@ class HeartFChatting:

             # Use the helper to process the tool-call response
             success, arguments, error_msg = process_llm_tool_response(
-                response,
-                expected_tool_name="decide_reply_action",
-                log_prefix=f"{log_prefix}[Planner] "
+                response, expected_tool_name="decide_reply_action", log_prefix=f"{log_prefix}[Planner] "
             )

             if success:

@@ -487,9 +483,7 @@ class HeartFChatting:
                 emoji_query = arguments.get("emoji_query", "")

                 # Log the decision result
-                logger.debug(
-                    f"{log_prefix}[Planner] Decision: {action}, reasoning: {reasoning}, emoji query: '{emoji_query}'"
-                )
+                logger.debug(f"{log_prefix}[Planner] Decision: {action}, reasoning: {reasoning}, emoji query: '{emoji_query}'")
             else:
                 # The tool call failed; handle it
                 logger.warning(f"{log_prefix}[Planner] {error_msg}")

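For context, a hedged sketch of how this planner consumes process_llm_tool_response; the response object below is a fabricated stand-in shaped like an OpenAI-style tool call, not real model output:

    from src.plugins.utils.json_utils import process_llm_tool_response

    fake_response = (
        "",  # content
        "",  # reasoning content
        [{"function": {"name": "decide_reply_action", "arguments": '{"action": "no_reply", "reasoning": "nothing new"}'}}],
    )
    success, arguments, error_msg = process_llm_tool_response(
        fake_response, expected_tool_name="decide_reply_action", log_prefix="[Planner] "
    )
    if success:
        action = arguments.get("action", "no_reply")
        emoji_query = arguments.get("emoji_query", "")
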
@@ -608,7 +602,9 @@ class HeartFChatting:

         logger.info(f"{log_prefix} HeartFChatting shutdown complete")

-    async def _build_planner_prompt(self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any]) -> str:
+    async def _build_planner_prompt(
+        self, observed_messages_str: str, current_mind: Optional[str], structured_info: Dict[str, Any]
+    ) -> str:
         """Build the prompt for the planner LLM"""

         prompt = f"Your name is {global_config.BOT_NICKNAME}. You are a sophomore female university student chatting on QQ, deciding whether and how to respond to the current conversation.\n"

@@ -72,7 +72,13 @@ class HeartFCGenerator:
         return None

     async def _generate_response_with_model(
-        self, structured_info: str, current_mind_info: str, reason: str, message: MessageRecv, model: LLMRequest, thinking_id: str
+        self,
+        structured_info: str,
+        current_mind_info: str,
+        reason: str,
+        message: MessageRecv,
+        model: LLMRequest,
+        thinking_id: str,
     ) -> str:
         sender_name = ""

@@ -81,13 +81,22 @@ class PromptBuilder:
         self.activate_messages = ""

     async def build_prompt(
-        self, build_mode, reason, current_mind_info, structured_info, message_txt: str, sender_name: str = "someone", chat_stream=None
+        self,
+        build_mode,
+        reason,
+        current_mind_info,
+        structured_info,
+        message_txt: str,
+        sender_name: str = "someone",
+        chat_stream=None,
     ) -> Optional[tuple[str, str]]:
         if build_mode == "normal":
             return await self._build_prompt_normal(chat_stream, message_txt, sender_name)

         elif build_mode == "focus":
-            return await self._build_prompt_focus(reason, current_mind_info, structured_info, chat_stream, message_txt, sender_name)
+            return await self._build_prompt_focus(
+                reason, current_mind_info, structured_info, chat_stream, message_txt, sender_name
+            )
         return None

     async def _build_prompt_focus(

@@ -786,11 +786,7 @@ class LLMRequest:
             message["tool_calls"] = tool_calls

         result = {
-            "choices": [
-                {
-                    "message": message
-                }
-            ],
+            "choices": [{"message": message}],
             "usage": usage,
         }
         return result

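The reshaped result keeps the OpenAI-style layout, so downstream code can index it unchanged. A small self-contained illustration with hypothetical values:

    message = {"role": "assistant", "content": "hi"}
    usage = {"prompt_tokens": 10, "completion_tokens": 2}

    result = {
        "choices": [{"message": message}],
        "usage": usage,
    }
    assert result["choices"][0]["message"]["content"] == "hi"
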
@@ -1139,7 +1135,7 @@ class LLMRequest:
             "messages": [{"role": "user", "content": prompt}],
             **self.params,
             **kwargs,
-            "tools": tools
+            "tools": tools,
         }

         logger.debug(f"Sending a tool-call request to model {self.model_name} with {len(tools)} tool(s)")

@@ -1150,7 +1146,7 @@ class LLMRequest:
             logger.debug(f"Received a tool-call response with {len(tool_calls) if tool_calls else 0} tool call(s)")
             return content, reasoning_content, tool_calls
         else:
-            logger.debug(f"Received a plain response with no tool calls")
+            logger.debug("Received a plain response with no tool calls")
             return response

     async def get_embedding(self, text: str) -> Union[list, None]:

@@ -303,7 +303,9 @@ async def build_readable_messages(
     )

     readable_read_mark = translate_timestamp_to_human_readable(read_mark, mode=timestamp_mode)
-    read_mark_line = f"\n\n--- The messages above have been read (mark time: {readable_read_mark}) ---\n--- Please focus on the new messages below, received since you last thought ---\n"
+    read_mark_line = (
+        f"\n\n--- The messages above have been read (mark time: {readable_read_mark}) ---\n--- Please focus on the new messages below, received since you last thought ---\n"
+    )

     # Combine the parts, making sure empty sections do not introduce extra markers or newlines
     if formatted_before and formatted_after:

@@ -1,13 +1,14 @@
 import json
 import logging
-from typing import Any, Dict, Optional, TypeVar, Generic, List, Union, Callable, Tuple
+from typing import Any, Dict, TypeVar, List, Union, Callable, Tuple

 # Type variable for generic type hints
-T = TypeVar('T')
+T = TypeVar("T")

 # Get the logger
 logger = logging.getLogger("json_utils")

+
 def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
     """
     Safely parse a JSON string, returning the default value on error

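A minimal usage sketch of safe_json_loads, assuming the repository root is on the import path; the fallback behavior follows the docstring:

    from src.plugins.utils.json_utils import safe_json_loads

    assert safe_json_loads('{"a": 1}') == {"a": 1}
    assert safe_json_loads("not json", default_value={}) == {}
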
@@ -31,8 +32,8 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
         logger.error(f"Unexpected error during JSON parsing: {e}")
         return default_value

-def extract_tool_call_arguments(tool_call: Dict[str, Any],
-                                default_value: Dict[str, Any] = None) -> Dict[str, Any]:
+
+def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
     """
     Extract the arguments from an LLM tool-call object

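A hedged usage sketch; the tool_call shape (a function dict whose arguments field is a JSON string) is assumed from the call sites earlier in this commit:

    from src.plugins.utils.json_utils import extract_tool_call_arguments

    tool_call = {"function": {"name": "decide_reply_action", "arguments": '{"action": "text_reply"}'}}
    args = extract_tool_call_arguments(tool_call, default_value={})
    # Expected: {"action": "text_reply"}; the default {} if the arguments field is malformed.
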
@@ -68,9 +69,10 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any],
         logger.error(f"Error while extracting tool-call arguments: {e}")
         return default_result

-def get_json_value(json_obj: Dict[str, Any], key_path: str,
-                   default_value: T = None,
-                   transform_func: Callable[[Any], T] = None) -> Union[Any, T]:
+
+def get_json_value(
+    json_obj: Dict[str, Any], key_path: str, default_value: T = None, transform_func: Callable[[Any], T] = None
+) -> Union[Any, T]:
     """
     Extract a value from a JSON object by path, supporting dot-notation paths such as "data.items.0.name"

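The dot-notation lookup from the docstring, exercised on made-up data (this assumes transform_func is applied to the value that was found):

    from src.plugins.utils.json_utils import get_json_value

    data = {"data": {"items": [{"name": "first"}, {"name": "second"}]}}
    assert get_json_value(data, "data.items.0.name") == "first"
    assert get_json_value(data, "data.missing", default_value="n/a") == "n/a"
    assert get_json_value(data, "data.items.1.name", transform_func=str.upper) == "SECOND"
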
@@ -117,8 +119,8 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str,
         logger.error(f"Error getting a value from JSON: {e}, path: {key_path}")
         return default_value

-def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False,
-                    pretty: bool = False) -> str:
+
+def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
     """
     Safely serialize a Python object to a JSON string

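A usage sketch; ensure_ascii=False is the default here, so non-ASCII text stays readable, and default_value covers unserializable objects:

    from src.plugins.utils.json_utils import safe_json_dumps

    print(safe_json_dumps({"名字": "麦麦"}))  # keeps the characters instead of \u escapes
    print(safe_json_dumps(object(), default_value="{}"))  # unserializable, falls back to "{}"
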
@@ -141,6 +143,7 @@ def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = Fa
         logger.error(f"Unexpected error during JSON serialization: {e}")
         return default_value

+
 def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
     """
     Merge multiple JSON objects (dicts)

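A usage sketch; later objects presumably win on key conflicts, which the result.update(obj) call in the next hunk confirms:

    from src.plugins.utils.json_utils import merge_json_objects

    merged = merge_json_objects({"a": 1, "b": 1}, {"b": 2}, {"c": 3})
    assert merged == {"a": 1, "b": 2, "c": 3}
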
@@ -157,6 +160,7 @@ def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
         result.update(obj)
     return result

+
 def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
     """
     Normalize the LLM response format, converting various shapes (such as tuples) into a uniform list format

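A hedged sketch of the contract suggested by the docstring and the (success, normalized_list, error) return type; the exact normalized contents are an assumption:

    from src.plugins.utils.json_utils import normalize_llm_response

    ok, normalized, err = normalize_llm_response(("content", "reasoning"), log_prefix="[demo] ")
    # Expected: ok is True and normalized is a list such as ["content", "reasoning"].
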
@@ -196,6 +200,7 @@ def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, L

     return True, response, ""

+
 def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]:
     """
     Process and extract the list of tool calls from an LLM response

@@ -247,10 +252,9 @@ def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[b

     return True, valid_tool_calls, ""

+
 def process_llm_tool_response(
-    response: Any,
-    expected_tool_name: str = None,
-    log_prefix: str = ""
+    response: Any, expected_tool_name: str = None, log_prefix: str = ""
 ) -> Tuple[bool, Dict[str, Any], str]:
     """
     Process a tool-call response returned by the LLM, running common error checks and extracting the arguments

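How the two layers are chained at call sites such as SubHeartflow above; the fabricated three-element response is an assumption modeled on those call sites:

    from src.plugins.utils.json_utils import normalize_llm_response, process_llm_tool_calls

    response = ("thought", "reasoning", [{"function": {"name": "lookup", "arguments": "{}"}}])
    ok, normalized, _err = normalize_llm_response(response)
    if ok and len(normalized) == 3:
        success, valid_tool_calls, error = process_llm_tool_calls(normalized, log_prefix="[demo] ")
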
@@ -6,6 +6,7 @@ from src.do_tool.tool_use import ToolUser
 import statistics
 import json

+
 async def run_test(test_name, test_function, iterations=5):
     """
     Run the test a given number of times and compute the average response time

@@ -23,7 +24,7 @@ async def run_test(test_name, test_function, iterations=5):
     responses = []

     for i in range(iterations):
-        print(f" Running test {i+1}/{iterations}...")
+        print(f" Running test {i + 1}/{iterations}...")
         start_time = time.time()
         response = await test_function()
         end_time = time.time()

@@ -38,11 +39,12 @@ async def run_test(test_name, test_function, iterations=5):
         "max_time": max(times),
         "std_dev": statistics.stdev(times) if len(times) > 1 else 0,
         "all_times": times,
-        "responses": responses
+        "responses": responses,
     }

     return results

+
 async def test_with_tool_calls():
     """LLM request test with tool calls"""
     # Create the LLM model instance

@@ -60,7 +62,7 @@ async def test_with_tool_calls():

     # Simple test prompt
     prompt = "Please analyze the current weather and look up the important events on this day in history. Also, which is larger, 3.9 or 3.11? Please use the appropriate tools to get this information."
-    prompt = '''
+    prompt = """
 Your name is 麦麦. You are accepting and open-minded, emotionally sensitive, and sometimes playfully goofy. You are a female university student studying psychology and brain science, now in your second year. You browse Tieba, sometimes have naughty thoughts, and like scrolling Xiaohongshu.
 -----------------------------------
 It is now 2025-04-24 12:37:00. You are online, chatting with people in a QQ group. The topic the group is currently discussing is:

@@ -89,7 +91,7 @@ async def test_with_tool_calls():
 Requirements for the reply: keep it plain and short, speak Chinese, and if you do reply, preferably reply to only one topic from one person.
 Take care not to output anything extraneous (including prefixes and suffixes, colons and quotation marks, parentheses, emoticons, etc.), and do not include parentheses or action descriptions. Do not reply to your own messages, and try not to repeat things you have already said.
 Now continue generating your thoughts in this chat, thinking further on the basis of your previous thoughts. Do not use bullet points; produce inner thoughts, and keep the wording restrained.
-After outputting your thoughts, think about which tools you should use; if you need to do something to process the messages or your reply, use a tool.'''
+After outputting your thoughts, think about which tools you should use; if you need to do something to process the messages or your reply, use a tool."""

     # Send the request with tool calls
     response = await llm_model.generate_response_tool_async(prompt=prompt, tools=tools)

|
||||
if tool_calls:
|
||||
print("\n 工具调用详情:")
|
||||
for i, tool_call in enumerate(tool_calls):
|
||||
tool_name = tool_call['function']['name']
|
||||
tool_params = tool_call['function'].get('arguments', {})
|
||||
print(f" - 工具 {i+1}: {tool_name}")
|
||||
print(f" 参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..."
|
||||
if len(json.dumps(tool_params, ensure_ascii=False)) > 100
|
||||
else f" 参数: {json.dumps(tool_params, ensure_ascii=False)}")
|
||||
tool_name = tool_call["function"]["name"]
|
||||
tool_params = tool_call["function"].get("arguments", {})
|
||||
print(f" - 工具 {i + 1}: {tool_name}")
|
||||
print(
|
||||
f" 参数: {json.dumps(tool_params, ensure_ascii=False)[:100]}..."
|
||||
if len(json.dumps(tool_params, ensure_ascii=False)) > 100
|
||||
else f" 参数: {json.dumps(tool_params, ensure_ascii=False)}"
|
||||
)
|
||||
|
||||
result_info = {
|
||||
"内容": content,
|
||||
"推理内容": reasoning_content,
|
||||
"工具调用": tool_calls
|
||||
}
|
||||
result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": tool_calls}
|
||||
else:
|
||||
content, reasoning_content = response
|
||||
print(" 工具调用请求未生成任何工具调用")
|
||||
print("\n 生成的内容:")
|
||||
print(f" {content[:200]}..." if len(content) > 200 else f" {content}")
|
||||
|
||||
result_info = {
|
||||
"内容": content,
|
||||
"推理内容": reasoning_content,
|
||||
"工具调用": []
|
||||
}
|
||||
result_info = {"内容": content, "推理内容": reasoning_content, "工具调用": []}
|
||||
|
||||
return result_info
|
||||
|
||||
|
||||
async def test_without_tool_calls():
|
||||
"""不使用工具调用的LLM请求测试"""
|
||||
# 创建LLM模型实例
|
||||
@@ -146,7 +143,7 @@ async def test_without_tool_calls():
     )

     # Simple test prompt (the same as the tool-call one, for a fair comparison)
-    prompt = '''
+    prompt = """
 Your name is 麦麦. You are accepting and open-minded, emotionally sensitive, and sometimes playfully goofy. You are a female university student studying psychology and brain science, now in your second year. You browse Tieba, sometimes have naughty thoughts, and like scrolling Xiaohongshu.
 Your thoughts just now were:
 I am 麦麦, and I think, ('小千石 asked which is larger, 3.8 or 3.11; I already gave the simple answer that 3.11 is larger. Now I can keep chatting about the cat-head memes, since everyone seems uninterested in the version question, and the cat-head topic is more relaxed and fun.', '')

@@ -181,7 +178,7 @@ async def test_without_tool_calls():
 Requirements for the reply: keep it plain and short, speak Chinese, and if you do reply, preferably reply to only one topic from one person.
 Take care not to output anything extraneous (including prefixes and suffixes, colons and quotation marks, parentheses, emoticons, etc.), and do not include parentheses or action descriptions. Do not reply to your own messages, and try not to repeat things you have already said.
 Now continue generating your thoughts in this chat, thinking further on the basis of your previous thoughts. Do not use bullet points; produce inner thoughts, and keep the wording restrained.
-After outputting your thoughts, think about which tools you should use; if you need to do something to process the messages or your reply, use a tool.'''
+After outputting your thoughts, think about which tools you should use; if you need to do something to process the messages or your reply, use a tool."""

     # Send the request without tool calls
     response, reasoning_content = await llm_model.generate_response_async(prompt)

@@ -190,14 +187,11 @@ async def test_without_tool_calls():
     print("\n Generated content:")
     print(f" {response[:200]}..." if len(response) > 200 else f" {response}")

-    result_info = {
-        "content": response,
-        "reasoning_content": reasoning_content,
-        "tool_calls": []
-    }
+    result_info = {"content": response, "reasoning_content": reasoning_content, "tool_calls": []}

     return result_info

+
 async def main():
     """Main test function"""
     print("=" * 50)

@@ -256,8 +250,9 @@ async def main():
         "without_tool_calls_detailed_responses": [
             {
                 "content_summary": resp["content"][:200] + "..." if len(resp["content"]) > 200 else resp["content"],
-                "reasoning_content_summary": resp["reasoning_content"][:200] + "..." if len(resp["reasoning_content"]) > 200 else resp["reasoning_content"]
-            } for resp in results_without_tools["responses"]
+                "reasoning_content_summary": resp["reasoning_content"][:200] + "..." if len(resp["reasoning_content"]) > 200 else resp["reasoning_content"],
+            }
+            for resp in results_without_tools["responses"]
         ],
         "with_tool_calls": {
             k: (v if k != "all_times" else [float(f"{t:.2f}") for t in v])

@@ -270,20 +265,20 @@ async def main():
             "reasoning_content_summary": resp["reasoning_content"][:200] + "..." if len(resp["reasoning_content"]) > 200 else resp["reasoning_content"],
             "tool_call_count": len(resp["tool_calls"]),
             "tool_call_details": [
-                {
-                    "tool_name": tool["function"]["name"],
-                    "arguments": tool["function"].get("arguments", {})
-                } for tool in resp["tool_calls"]
-            ]
-        } for resp in results_with_tools["responses"]
+                {"tool_name": tool["function"]["name"], "arguments": tool["function"].get("arguments", {})}
+                for tool in resp["tool_calls"]
+            ],
+        }
+        for resp in results_with_tools["responses"]
     ],
-    "diff_percent": float(f"{diff_percent:.2f}")
+    "diff_percent": float(f"{diff_percent:.2f}"),
 }

 with open("llm_tool_benchmark_results.json", "w", encoding="utf-8") as f:
     json.dump(results, f, ensure_ascii=False, indent=2)

-print(f"\nTest results saved to llm_tool_benchmark_results.json")
+print("\nTest results saved to llm_tool_benchmark_results.json")


 if __name__ == "__main__":
     asyncio.run(main())