fix:优化工具解析

This commit is contained in:
SengokuCola
2025-04-28 19:31:00 +08:00
parent 629cdb007b
commit f83e151d40
4 changed files with 159 additions and 262 deletions

View File

@@ -70,55 +70,6 @@ def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[s
return default_result
def get_json_value(
json_obj: Dict[str, Any], key_path: str, default_value: T = None, transform_func: Callable[[Any], T] = None
) -> Union[Any, T]:
"""
从JSON对象中按照路径提取值支持点表示法路径"data.items.0.name"
参数:
json_obj: JSON对象(已解析的字典)
key_path: 键路径,使用点表示法,如"data.items.0.name"
default_value: 获取失败时返回的默认值
transform_func: 可选的转换函数,用于对获取的值进行转换
返回:
路径指向的值或在获取失败时返回default_value
"""
if not json_obj or not key_path:
return default_value
try:
# 分割路径
keys = key_path.split(".")
current = json_obj
# 遍历路径
for key in keys:
# 处理数组索引
if key.isdigit() and isinstance(current, list):
index = int(key)
if 0 <= index < len(current):
current = current[index]
else:
return default_value
# 处理字典键
elif isinstance(current, dict):
if key in current:
current = current[key]
else:
return default_value
else:
return default_value
# 应用转换函数(如果提供)
if transform_func and current is not None:
return transform_func(current)
return current
except Exception as e:
logger.error(f"从JSON获取值时出错: {e}, 路径: {key_path}")
return default_value
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
"""
@@ -144,21 +95,6 @@ def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = Fa
return default_value
def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
"""
合并多个JSON对象(字典)
参数:
*objects: 要合并的JSON对象(字典)
返回:
合并后的字典,后面的对象会覆盖前面对象的相同键
"""
result = {}
for obj in objects:
if obj and isinstance(obj, dict):
result.update(obj)
return result
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
@@ -172,6 +108,9 @@ def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, L
返回:
元组 (成功标志, 标准化后的响应列表, 错误消息)
"""
logger.debug(f"{log_prefix}原始人 LLM响应: {response}")
# 检查是否为None
if response is None:
return False, [], "LLM响应为None"
@@ -201,114 +140,60 @@ def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, L
return True, response, ""
def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]:
def process_llm_tool_calls(tool_calls: List[Dict[str, Any]], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]:
"""
处理并提取LLM响应中的工具调用列表
处理并验证LLM响应中的工具调用列表
参数:
response: 标准化后的LLM响应列表
tool_calls: 从LLM响应中直接获取的工具调用列表
log_prefix: 日志前缀
返回:
元组 (成功标志, 工具调用列表, 错误消息)
元组 (成功标志, 验证后的工具调用列表, 错误消息)
"""
# 确保响应格式正确
print(response)
print(11111111111111111)
if len(response) != 3:
return False, [], f"LLM响应元素数量不正确: 预期3个元素实际{len(response)}"
# 如果列表为空,表示没有工具调用,这不是错误
if not tool_calls:
return True, [], "工具调用列表为空"
# 提取工具调用部分
tool_calls = response[2]
# 检查工具调用是否有效
if tool_calls is None:
return False, [], "工具调用部分为None"
if not isinstance(tool_calls, list):
return False, [], f"工具调用部分不是列表: {type(tool_calls).__name__}"
if len(tool_calls) == 0:
return False, [], "工具调用列表为空"
# 检查工具调用是否格式正确
# 验证每个工具调用的格式
valid_tool_calls = []
for i, tool_call in enumerate(tool_calls):
if not isinstance(tool_call, dict):
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}")
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}, 内容: {tool_call}")
continue
# 检查基本结构
if tool_call.get("type") != "function":
logger.warning(f"{log_prefix}工具调用[{i}]不是函数类型: {tool_call.get('type', '')}")
logger.warning(f"{log_prefix}工具调用[{i}]不是function类型: type={tool_call.get('type', '定义')}, 内容: {tool_call}")
continue
if "function" not in tool_call or not isinstance(tool_call["function"], dict):
logger.warning(f"{log_prefix}工具调用[{i}]缺少function字段或格式不正确")
if "function" not in tool_call or not isinstance(tool_call.get("function"), dict):
logger.warning(f"{log_prefix}工具调用[{i}]缺少'function'字段或其类型不正确: {tool_call}")
continue
func_details = tool_call["function"]
if "name" not in func_details or not isinstance(func_details.get("name"), str):
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'name'或类型不正确: {func_details}")
continue
if "arguments" not in func_details or not isinstance(func_details.get("arguments"), str): # 参数是字符串形式的JSON
logger.warning(f"{log_prefix}工具调用[{i}]的'function'字段缺少'arguments'或类型不正确: {func_details}")
continue
# 可选尝试解析参数JSON确保其有效
args_str = func_details["arguments"]
try:
json.loads(args_str) # 尝试解析,但不存储结果
except json.JSONDecodeError as e:
logger.warning(f"{log_prefix}工具调用[{i}]的'arguments'不是有效的JSON字符串: {e}, 内容: {args_str[:100]}...")
continue
except Exception as e:
logger.warning(f"{log_prefix}解析工具调用[{i}]的'arguments'时发生意外错误: {e}, 内容: {args_str[:100]}...")
continue
valid_tool_calls.append(tool_call)
# 检查是否有有效的工具调用
if not valid_tool_calls:
return False, [], "没有找到有效的工具调用"
if not valid_tool_calls and tool_calls: # 如果原始列表不为空,但验证后为空
return False, [], "所有工具调用格式均无效"
return True, valid_tool_calls, ""
def process_llm_tool_response(
response: Any, expected_tool_name: str = None, log_prefix: str = ""
) -> Tuple[bool, Dict[str, Any], str]:
"""
处理LLM返回的工具调用响应进行常见错误检查并提取参数
参数:
response: LLM的响应预期是[content, reasoning, tool_calls]格式的列表或元组
expected_tool_name: 预期的工具名称,如不指定则不检查
log_prefix: 日志前缀,用于标识日志来源
返回:
三元组(成功标志, 参数字典, 错误描述)
- 如果成功解析,返回(True, 参数字典, "")
- 如果解析失败,返回(False, {}, 错误描述)
"""
# 使用新的标准化函数
success, normalized_response, error_msg = normalize_llm_response(response, log_prefix)
if not success:
return False, {}, error_msg
# 新增检查:确保响应包含预期的工具调用部分
if len(normalized_response) != 3:
# 如果长度不为3说明LLM响应不包含工具调用部分这在期望工具调用的上下文中是错误的
error_msg = (
f"LLM响应未包含预期的工具调用部分: 元素数量{len(normalized_response)},响应内容:{normalized_response}"
)
logger.warning(f"{log_prefix}{error_msg}")
return False, {}, error_msg
# 使用新的工具调用处理函数
# 此时已知 normalized_response 长度必定为 3
success, valid_tool_calls, error_msg = process_llm_tool_calls(normalized_response, log_prefix)
if not success:
return False, {}, error_msg
# 检查是否有工具调用
if not valid_tool_calls:
return False, {}, "没有有效的工具调用"
# 获取第一个工具调用
tool_call = valid_tool_calls[0]
# 检查工具名称(如果提供了预期名称)
if expected_tool_name:
actual_name = tool_call.get("function", {}).get("name")
if actual_name != expected_tool_name:
return False, {}, f"工具名称不匹配: 预期'{expected_tool_name}',实际'{actual_name}'"
# 提取并解析参数
try:
arguments = extract_tool_call_arguments(tool_call, {})
return True, arguments, ""
except Exception as e:
logger.error(f"{log_prefix}解析工具参数时出错: {e}")
return False, {}, f"解析参数失败: {str(e)}"