fix(llm): handle missing token usage fields from some models
API responses from some models (e.g. embedding models) may not carry the full set of usage fields such as `completion_tokens`. The previous direct attribute access raised `AttributeError`, which interrupted the usage-recording and statistics-update flow. Switching to `getattr(usage, "...", 0)` supplies a default of 0 for any missing field, making the code more robust and ensuring the system can reliably handle responses from different types of models.
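To illustrate the failure mode and the fix, here is a minimal, self-contained sketch; the `EmbeddingUsage` class below is hypothetical and only stands in for a provider response that omits `completion_tokens`:

from dataclasses import dataclass


@dataclass
class EmbeddingUsage:
    """Hypothetical usage payload from an embedding model: completion_tokens is absent."""
    prompt_tokens: int = 0
    total_tokens: int = 0


usage = EmbeddingUsage(prompt_tokens=12, total_tokens=12)

# Direct attribute access raises for the missing field:
#   usage.completion_tokens  ->  AttributeError
# getattr with a default degrades to 0 instead, so recording can continue:
completion_tokens = getattr(usage, "completion_tokens", 0) or 0
print(completion_tokens)  # 0
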
@@ -26,13 +26,13 @@ class UsageRecord:
     provider_name: str
     """提供商名称"""

-    prompt_tokens: int
+    prompt_tokens: int = 0
     """提示token数"""

-    completion_tokens: int
+    completion_tokens: int = 0
     """完成token数"""

-    total_tokens: int
+    total_tokens: int = 0
     """总token数"""

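For context, a small sketch of what the `= 0` defaults buy, assuming `UsageRecord` is a plain dataclass (only the fields visible in the hunk above are reproduced; the real class has more):

from dataclasses import dataclass


@dataclass
class UsageRecord:
    provider_name: str
    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


# Before the change every token field was required; with defaults, a record
# built from partial usage data (e.g. an embedding response) still constructs:
record = UsageRecord(provider_name="openai", prompt_tokens=12, total_tokens=12)
print(record.completion_tokens)  # 0
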
@@ -298,9 +298,9 @@ async def _default_stream_response_handler(
     if event.usage:
         # 如果有使用情况,则将其存储在APIResponse对象中
         _usage_record = (
-            event.usage.prompt_tokens or 0,
-            event.usage.completion_tokens or 0,
-            event.usage.total_tokens or 0,
+            getattr(event.usage, "prompt_tokens", 0) or 0,
+            getattr(event.usage, "completion_tokens", 0) or 0,
+            getattr(event.usage, "total_tokens", 0) or 0,
         )

     try:
@@ -368,9 +368,9 @@ def _default_normal_response_parser(
     # 提取Usage信息
     if resp.usage:
         _usage_record = (
-            resp.usage.prompt_tokens or 0,
-            resp.usage.completion_tokens or 0,
-            resp.usage.total_tokens or 0,
+            getattr(resp.usage, "prompt_tokens", 0) or 0,
+            getattr(resp.usage, "completion_tokens", 0) or 0,
+            getattr(resp.usage, "total_tokens", 0) or 0,
         )
     else:
         _usage_record = None
@@ -599,7 +599,7 @@ class OpenaiClient(BaseClient):
             model_name=model_info.name,
             provider_name=model_info.api_provider,
             prompt_tokens=raw_response.usage.prompt_tokens or 0,
-            completion_tokens=raw_response.usage.completion_tokens or 0,  # type: ignore
+            completion_tokens=getattr(raw_response.usage, "completion_tokens", 0) or 0,
             total_tokens=raw_response.usage.total_tokens or 0,
         )

@@ -155,8 +155,12 @@ class LLMUsageRecorder:
         endpoint: str,
         time_cost: float = 0.0,
     ):
-        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
-        output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
+        prompt_tokens = getattr(model_usage, "prompt_tokens", 0)
+        completion_tokens = getattr(model_usage, "completion_tokens", 0)
+        total_tokens = getattr(model_usage, "total_tokens", 0)
+
+        input_cost = (prompt_tokens / 1000000) * model_info.price_in
+        output_cost = (completion_tokens / 1000000) * model_info.price_out
         round(input_cost + output_cost, 6)

         session = None
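The division by 1000000 implies `price_in` and `price_out` are per-million-token prices. A worked sketch under that assumption, with made-up prices and token counts:

# Made-up per-million-token prices.
price_in = 0.5
price_out = 2.0

prompt_tokens = 1200
completion_tokens = 0  # e.g. an embedding model that reports no completion tokens

input_cost = (prompt_tokens / 1000000) * price_in
output_cost = (completion_tokens / 1000000) * price_out
print(round(input_cost + output_cost, 6))  # 0.0006
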
@@ -170,9 +174,9 @@ class LLMUsageRecorder:
             user_id=user_id,
             request_type=request_type,
             endpoint=endpoint,
-            prompt_tokens=model_usage.prompt_tokens or 0,
-            completion_tokens=model_usage.completion_tokens or 0,
-            total_tokens=model_usage.total_tokens or 0,
+            prompt_tokens=prompt_tokens,
+            completion_tokens=completion_tokens,
+            total_tokens=total_tokens,
             cost=1.0,
             time_cost=round(time_cost or 0.0, 3),
             status="success",
@@ -185,8 +189,8 @@ class LLMUsageRecorder:
         logger.debug(
             f"Token使用情况 - 模型: {model_usage.model_name}, "
             f"用户: {user_id}, 类型: {request_type}, "
-            f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
-            f"总计: {model_usage.total_tokens}"
+            f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
+            f"总计: {total_tokens}"
         )
     except Exception as e:
         logger.error(f"记录token使用情况失败: {e!s}")
@@ -1009,12 +1009,15 @@ class LLMRequest:
         # 步骤1: 更新内存中的统计数据,用于负载均衡
         stats = self.model_usage[model_info.name]

+        # 安全地获取 token 使用量, embedding 模型可能不返回 completion_tokens
+        total_tokens = getattr(usage, "total_tokens", 0)
+
         # 计算新的平均延迟
         new_request_count = stats.request_count + 1
         new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count

         self.model_usage[model_info.name] = stats._replace(
-            total_tokens=stats.total_tokens + usage.total_tokens,
+            total_tokens=stats.total_tokens + total_tokens,
             avg_latency=new_avg_latency,
             request_count=new_request_count,
         )
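The `stats._replace(...)` call suggests the per-model statistics are stored in an immutable NamedTuple. A minimal sketch of the update pattern under that assumption; the `ModelStats` fields and the stand-in `usage` object below are inferred for illustration, not taken from the actual definitions:

from types import SimpleNamespace
from typing import NamedTuple


class ModelStats(NamedTuple):
    total_tokens: int = 0
    avg_latency: float = 0.0
    request_count: int = 0


stats = ModelStats(total_tokens=500, avg_latency=0.8, request_count=4)
usage = SimpleNamespace(prompt_tokens=30, total_tokens=30)  # no completion_tokens field
time_cost = 1.2

# Safe lookup: falls back to 0 if the provider's usage object lacks the attribute.
total_tokens = getattr(usage, "total_tokens", 0)

new_request_count = stats.request_count + 1
new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count

# NamedTuples are immutable, so _replace returns a new instance with the updated fields.
stats = stats._replace(
    total_tokens=stats.total_tokens + total_tokens,
    avg_latency=new_avg_latency,
    request_count=new_request_count,
)
print(stats)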