token统计部分
This commit is contained in:
@@ -49,6 +49,12 @@ class LLM_request:
|
||||
# 获取数据库实例
|
||||
self._init_database()
|
||||
|
||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||
self.request_type = kwargs.pop("request_type", "default")
|
||||
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库集合"""
|
||||
@@ -67,7 +73,7 @@ class LLM_request:
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat",
|
||||
request_type: str = None,
|
||||
endpoint: str = "/chat/completions",
|
||||
):
|
||||
"""记录模型使用情况到数据库
|
||||
@@ -76,9 +82,13 @@ class LLM_request:
|
||||
completion_tokens: 输出token数
|
||||
total_tokens: 总token数
|
||||
user_id: 用户ID,默认为system
|
||||
request_type: 请求类型(chat/embedding/image等)
|
||||
request_type: 请求类型(chat/embedding/image/topic/schedule)
|
||||
endpoint: API端点
|
||||
"""
|
||||
# 如果 request_type 为 None,则使用实例变量中的值
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
try:
|
||||
usage_data = {
|
||||
"model_name": self.model_name,
|
||||
@@ -128,7 +138,7 @@ class LLM_request:
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat",
|
||||
request_type: str = None,
|
||||
):
|
||||
"""统一请求执行入口
|
||||
Args:
|
||||
@@ -142,6 +152,10 @@ class LLM_request:
|
||||
user_id: 用户ID
|
||||
request_type: 请求类型
|
||||
"""
|
||||
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
# 合并重试策略
|
||||
default_retry = {
|
||||
"max_retries": 3,
|
||||
@@ -441,7 +455,7 @@ class LLM_request:
|
||||
return payload
|
||||
|
||||
def _default_response_handler(
|
||||
self, result: dict, user_id: str = "system", request_type: str = "chat", endpoint: str = "/chat/completions"
|
||||
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
|
||||
) -> Tuple:
|
||||
"""默认响应解析"""
|
||||
if "choices" in result and result["choices"]:
|
||||
@@ -465,7 +479,7 @@ class LLM_request:
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
request_type = request_type if request_type is not None else self.request_type,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
@@ -538,6 +552,22 @@ class LLM_request:
|
||||
def embedding_handler(result):
|
||||
"""处理响应"""
|
||||
if "data" in result and len(result["data"]) > 0:
|
||||
# 提取 token 使用信息
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
total_tokens = usage.get("total_tokens", 0)
|
||||
# 记录 token 使用情况
|
||||
self._record_usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
user_id="system", # 可以根据需要修改 user_id
|
||||
request_type="embedding", # 请求类型为 embedding
|
||||
endpoint="/embeddings" # API 端点
|
||||
)
|
||||
return result["data"][0].get("embedding", None)
|
||||
return result["data"][0].get("embedding", None)
|
||||
return None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user