diff --git a/src/llm_models/model_client/base_client.py b/src/llm_models/model_client/base_client.py
index 246b0618b..baab2897b 100644
--- a/src/llm_models/model_client/base_client.py
+++ b/src/llm_models/model_client/base_client.py
@@ -26,13 +26,13 @@ class UsageRecord:
     provider_name: str
     """提供商名称"""
 
-    prompt_tokens: int
+    prompt_tokens: int = 0
     """提示token数"""
 
-    completion_tokens: int
+    completion_tokens: int = 0
    """完成token数"""
 
-    total_tokens: int
+    total_tokens: int = 0
     """总token数"""
diff --git a/src/llm_models/model_client/openai_client.py b/src/llm_models/model_client/openai_client.py
index 95611d8da..6bf48acc0 100644
--- a/src/llm_models/model_client/openai_client.py
+++ b/src/llm_models/model_client/openai_client.py
@@ -298,9 +298,9 @@ async def _default_stream_response_handler(
         if event.usage:
             # 如果有使用情况,则将其存储在APIResponse对象中
             _usage_record = (
-                event.usage.prompt_tokens or 0,
-                event.usage.completion_tokens or 0,
-                event.usage.total_tokens or 0,
+                getattr(event.usage, "prompt_tokens", 0) or 0,
+                getattr(event.usage, "completion_tokens", 0) or 0,
+                getattr(event.usage, "total_tokens", 0) or 0,
             )
 
     try:
@@ -368,9 +368,9 @@ def _default_normal_response_parser(
     # 提取Usage信息
     if resp.usage:
         _usage_record = (
-            resp.usage.prompt_tokens or 0,
-            resp.usage.completion_tokens or 0,
-            resp.usage.total_tokens or 0,
+            getattr(resp.usage, "prompt_tokens", 0) or 0,
+            getattr(resp.usage, "completion_tokens", 0) or 0,
+            getattr(resp.usage, "total_tokens", 0) or 0,
         )
     else:
         _usage_record = None
@@ -599,7 +599,7 @@ class OpenaiClient(BaseClient):
                 model_name=model_info.name,
                 provider_name=model_info.api_provider,
                 prompt_tokens=raw_response.usage.prompt_tokens or 0,
-                completion_tokens=raw_response.usage.completion_tokens or 0,  # type: ignore
+                completion_tokens=getattr(raw_response.usage, "completion_tokens", 0) or 0,
                 total_tokens=raw_response.usage.total_tokens or 0,
             )
diff --git a/src/llm_models/utils.py b/src/llm_models/utils.py
index b2c7e57b0..9855b2446 100644
--- a/src/llm_models/utils.py
+++ b/src/llm_models/utils.py
@@ -155,8 +155,12 @@ class LLMUsageRecorder:
         endpoint: str,
         time_cost: float = 0.0,
     ):
-        input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
-        output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
+        prompt_tokens = getattr(model_usage, "prompt_tokens", 0)
+        completion_tokens = getattr(model_usage, "completion_tokens", 0)
+        total_tokens = getattr(model_usage, "total_tokens", 0)
+
+        input_cost = (prompt_tokens / 1000000) * model_info.price_in
+        output_cost = (completion_tokens / 1000000) * model_info.price_out
         round(input_cost + output_cost, 6)
 
         session = None
@@ -170,9 +174,9 @@ class LLMUsageRecorder:
                     user_id=user_id,
                     request_type=request_type,
                     endpoint=endpoint,
-                    prompt_tokens=model_usage.prompt_tokens or 0,
-                    completion_tokens=model_usage.completion_tokens or 0,
-                    total_tokens=model_usage.total_tokens or 0,
+                    prompt_tokens=prompt_tokens,
+                    completion_tokens=completion_tokens,
+                    total_tokens=total_tokens,
                     cost=1.0,
                     time_cost=round(time_cost or 0.0, 3),
                     status="success",
@@ -185,8 +189,8 @@ class LLMUsageRecorder:
                 logger.debug(
                     f"Token使用情况 - 模型: {model_usage.model_name}, "
                     f"用户: {user_id}, 类型: {request_type}, "
-                    f"提示词: {model_usage.prompt_tokens}, 完成: {model_usage.completion_tokens}, "
-                    f"总计: {model_usage.total_tokens}"
+                    f"提示词: {prompt_tokens}, 完成: {completion_tokens}, "
+                    f"总计: {total_tokens}"
                 )
         except Exception as e:
             logger.error(f"记录token使用情况失败: {e!s}")
diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py
index d4bb708a2..08e5bb934 100644
--- a/src/llm_models/utils_model.py
+++ b/src/llm_models/utils_model.py
@@ -1009,12 +1009,15 @@ class LLMRequest:
         # 步骤1: 更新内存中的统计数据,用于负载均衡
         stats = self.model_usage[model_info.name]
 
+        # 安全地获取 token 使用量, embedding 模型可能不返回 completion_tokens
+        total_tokens = getattr(usage, "total_tokens", 0)
+
         # 计算新的平均延迟
         new_request_count = stats.request_count + 1
         new_avg_latency = (stats.avg_latency * stats.request_count + time_cost) / new_request_count
 
         self.model_usage[model_info.name] = stats._replace(
-            total_tokens=stats.total_tokens + usage.total_tokens,
+            total_tokens=stats.total_tokens + total_tokens,
             avg_latency=new_avg_latency,
             request_count=new_request_count,
         )
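
The patch above is purely defensive: UsageRecord now defaults its token counters to 0, and every read of a provider usage object goes through getattr(..., 0) or 0 instead of direct attribute access, so a response that omits a field (for example an embedding endpoint that reports no completion_tokens) records zeros instead of raising AttributeError. A minimal sketch of that pattern, using hypothetical stand-in classes rather than the project's real response types:

    from dataclasses import dataclass

    @dataclass
    class UsageRecord:
        model_name: str
        provider_name: str
        prompt_tokens: int = 0        # defaults to 0 when the provider omits the field
        completion_tokens: int = 0
        total_tokens: int = 0

    class EmbeddingUsage:
        """Stand-in usage object that has no completion_tokens attribute."""
        prompt_tokens = 12
        total_tokens = 12

    usage = EmbeddingUsage()
    record = UsageRecord(
        model_name="text-embedding-demo",   # hypothetical model name
        provider_name="demo-provider",      # hypothetical provider name
        prompt_tokens=getattr(usage, "prompt_tokens", 0) or 0,
        completion_tokens=getattr(usage, "completion_tokens", 0) or 0,  # missing attribute -> 0
        total_tokens=getattr(usage, "total_tokens", 0) or 0,
    )
    assert record.completion_tokens == 0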