This commit is contained in:
SengokuCola
2025-04-24 14:19:26 +08:00
parent f8450f705a
commit 3075664480
13 changed files with 224 additions and 225 deletions

View File

@@ -1,27 +1,28 @@
import json
import logging
from typing import Any, Dict, Optional, TypeVar, Generic, List, Union, Callable, Tuple
from typing import Any, Dict, TypeVar, List, Union, Callable, Tuple
# 定义类型变量用于泛型类型提示
T = TypeVar('T')
T = TypeVar("T")
# 获取logger
logger = logging.getLogger("json_utils")
def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
"""
安全地解析JSON字符串出错时返回默认值
参数:
json_str: 要解析的JSON字符串
default_value: 解析失败时返回的默认值
返回:
解析后的Python对象或在解析失败时返回default_value
"""
if not json_str:
return default_value
try:
return json.loads(json_str)
except json.JSONDecodeError as e:
@@ -31,66 +32,67 @@ def safe_json_loads(json_str: str, default_value: T = None) -> Union[Any, T]:
logger.error(f"JSON解析过程中发生意外错误: {e}")
return default_value
def extract_tool_call_arguments(tool_call: Dict[str, Any],
default_value: Dict[str, Any] = None) -> Dict[str, Any]:
def extract_tool_call_arguments(tool_call: Dict[str, Any], default_value: Dict[str, Any] = None) -> Dict[str, Any]:
"""
从LLM工具调用对象中提取参数
参数:
tool_call: 工具调用对象字典
default_value: 解析失败时返回的默认值
返回:
解析后的参数字典或在解析失败时返回default_value
"""
default_result = default_value or {}
if not tool_call or not isinstance(tool_call, dict):
logger.error(f"无效的工具调用对象: {tool_call}")
return default_result
try:
# 提取function参数
function_data = tool_call.get("function", {})
if not function_data or not isinstance(function_data, dict):
logger.error(f"工具调用缺少function字段或格式不正确: {tool_call}")
return default_result
# 提取arguments
arguments_str = function_data.get("arguments", "{}")
if not arguments_str:
return default_result
# 解析JSON
return safe_json_loads(arguments_str, default_result)
except Exception as e:
logger.error(f"提取工具调用参数时出错: {e}")
return default_result
def get_json_value(json_obj: Dict[str, Any], key_path: str,
default_value: T = None,
transform_func: Callable[[Any], T] = None) -> Union[Any, T]:
def get_json_value(
json_obj: Dict[str, Any], key_path: str, default_value: T = None, transform_func: Callable[[Any], T] = None
) -> Union[Any, T]:
"""
从JSON对象中按照路径提取值支持点表示法路径"data.items.0.name"
参数:
json_obj: JSON对象(已解析的字典)
key_path: 键路径,使用点表示法,如"data.items.0.name"
default_value: 获取失败时返回的默认值
transform_func: 可选的转换函数,用于对获取的值进行转换
返回:
路径指向的值或在获取失败时返回default_value
"""
if not json_obj or not key_path:
return default_value
try:
# 分割路径
keys = key_path.split(".")
current = json_obj
# 遍历路径
for key in keys:
# 处理数组索引
@@ -108,7 +110,7 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str,
return default_value
else:
return default_value
# 应用转换函数(如果提供)
if transform_func and current is not None:
return transform_func(current)
@@ -117,17 +119,17 @@ def get_json_value(json_obj: Dict[str, Any], key_path: str,
logger.error(f"从JSON获取值时出错: {e}, 路径: {key_path}")
return default_value
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False,
pretty: bool = False) -> str:
def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = False, pretty: bool = False) -> str:
"""
安全地将Python对象序列化为JSON字符串
参数:
obj: 要序列化的Python对象
default_value: 序列化失败时返回的默认值
ensure_ascii: 是否确保ASCII编码(默认False允许中文等非ASCII字符)
pretty: 是否美化输出JSON
返回:
序列化后的JSON字符串或在序列化失败时返回default_value
"""
@@ -141,13 +143,14 @@ def safe_json_dumps(obj: Any, default_value: str = "{}", ensure_ascii: bool = Fa
logger.error(f"JSON序列化过程中发生意外错误: {e}")
return default_value
def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
"""
合并多个JSON对象(字典)
参数:
*objects: 要合并的JSON对象(字典)
返回:
合并后的字典,后面的对象会覆盖前面对象的相同键
"""
@@ -157,109 +160,110 @@ def merge_json_objects(*objects: Dict[str, Any]) -> Dict[str, Any]:
result.update(obj)
return result
def normalize_llm_response(response: Any, log_prefix: str = "") -> Tuple[bool, List[Any], str]:
"""
标准化LLM响应格式将各种格式如元组转换为统一的列表格式
参数:
response: 原始LLM响应
log_prefix: 日志前缀
返回:
元组 (成功标志, 标准化后的响应列表, 错误消息)
"""
# 检查是否为None
if response is None:
return False, [], "LLM响应为None"
# 记录原始类型
logger.debug(f"{log_prefix}LLM响应原始类型: {type(response).__name__}")
# 将元组转换为列表
if isinstance(response, tuple):
logger.debug(f"{log_prefix}将元组响应转换为列表")
response = list(response)
# 确保是列表类型
if not isinstance(response, list):
return False, [], f"无法处理的LLM响应类型: {type(response).__name__}"
# 处理工具调用部分(如果存在)
if len(response) == 3:
content, reasoning, tool_calls = response
# 将工具调用部分转换为列表(如果是元组)
if isinstance(tool_calls, tuple):
logger.debug(f"{log_prefix}将工具调用元组转换为列表")
tool_calls = list(tool_calls)
response[2] = tool_calls
return True, response, ""
def process_llm_tool_calls(response: List[Any], log_prefix: str = "") -> Tuple[bool, List[Dict[str, Any]], str]:
"""
处理并提取LLM响应中的工具调用列表
参数:
response: 标准化后的LLM响应列表
log_prefix: 日志前缀
返回:
元组 (成功标志, 工具调用列表, 错误消息)
"""
# 确保响应格式正确
if len(response) != 3:
return False, [], f"LLM响应元素数量不正确: 预期3个元素实际{len(response)}"
# 提取工具调用部分
tool_calls = response[2]
# 检查工具调用是否有效
if tool_calls is None:
return False, [], "工具调用部分为None"
if not isinstance(tool_calls, list):
return False, [], f"工具调用部分不是列表: {type(tool_calls).__name__}"
if len(tool_calls) == 0:
return False, [], "工具调用列表为空"
# 检查工具调用是否格式正确
valid_tool_calls = []
for i, tool_call in enumerate(tool_calls):
if not isinstance(tool_call, dict):
logger.warning(f"{log_prefix}工具调用[{i}]不是字典: {type(tool_call).__name__}")
continue
if tool_call.get("type") != "function":
logger.warning(f"{log_prefix}工具调用[{i}]不是函数类型: {tool_call.get('type', '未知')}")
continue
if "function" not in tool_call or not isinstance(tool_call["function"], dict):
logger.warning(f"{log_prefix}工具调用[{i}]缺少function字段或格式不正确")
continue
valid_tool_calls.append(tool_call)
# 检查是否有有效的工具调用
if not valid_tool_calls:
return False, [], "没有找到有效的工具调用"
return True, valid_tool_calls, ""
def process_llm_tool_response(
response: Any,
expected_tool_name: str = None,
log_prefix: str = ""
response: Any, expected_tool_name: str = None, log_prefix: str = ""
) -> Tuple[bool, Dict[str, Any], str]:
"""
处理LLM返回的工具调用响应进行常见错误检查并提取参数
参数:
response: LLM的响应预期是[content, reasoning, tool_calls]格式的列表或元组
expected_tool_name: 预期的工具名称,如不指定则不检查
log_prefix: 日志前缀,用于标识日志来源
返回:
三元组(成功标志, 参数字典, 错误描述)
- 如果成功解析,返回(True, 参数字典, "")
@@ -269,29 +273,29 @@ def process_llm_tool_response(
success, normalized_response, error_msg = normalize_llm_response(response, log_prefix)
if not success:
return False, {}, error_msg
# 使用新的工具调用处理函数
success, valid_tool_calls, error_msg = process_llm_tool_calls(normalized_response, log_prefix)
if not success:
return False, {}, error_msg
# 检查是否有工具调用
if not valid_tool_calls:
return False, {}, "没有有效的工具调用"
# 获取第一个工具调用
tool_call = valid_tool_calls[0]
# 检查工具名称(如果提供了预期名称)
if expected_tool_name:
actual_name = tool_call.get("function", {}).get("name")
if actual_name != expected_tool_name:
return False, {}, f"工具名称不匹配: 预期'{expected_tool_name}',实际'{actual_name}'"
# 提取并解析参数
try:
arguments = extract_tool_call_arguments(tool_call, {})
return True, arguments, ""
except Exception as e:
logger.error(f"{log_prefix}解析工具参数时出错: {e}")
return False, {}, f"解析参数失败: {str(e)}"
return False, {}, f"解析参数失败: {str(e)}"