Token usage statistics

Maple127667
2025-03-16 23:11:32 +08:00
parent 71132315d8
commit 4bc222ba6f
7 changed files with 43 additions and 13 deletions
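In summary, this commit tags every LLM_request with a request type so token usage can be attributed per purpose: call sites pass request_type ("image", "topic", "embedding", "scheduler"), the constructor pops the kwarg with a "default" fallback, the recording paths fall back to the instance tag when no explicit type is given, and the embedding handler now records token usage as well. A minimal, self-contained sketch of the pattern (a toy class for illustration, not the real LLM_request):

    # Toy sketch of the tagging pattern this commit introduces (illustrative only):
    class TaggedRequest:
        def __init__(self, **kwargs):
            # pop() keeps request_type out of the kwargs later forwarded to the API
            self.request_type = kwargs.pop("request_type", "default")

        def effective_type(self, request_type=None):
            # an explicit argument wins; otherwise fall back to the instance tag
            return request_type if request_type is not None else self.request_type

    print(TaggedRequest(request_type="image").effective_type())  # image
    print(TaggedRequest().effective_type())                      # default
    print(TaggedRequest().effective_type("embedding"))           # embedding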

View File

@@ -38,9 +38,9 @@ class EmojiManager:
     def __init__(self):
         self._scan_task = None
-        self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
+        self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000, request_type="image")
         self.llm_emotion_judge = LLM_request(
-            model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8
+            model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8, request_type="image"
         )  # Higher temperature, fewer tokens; the temperature could later be adjusted based on emotion

     def _ensure_emoji_dir(self):

View File

@@ -14,7 +14,7 @@ config = driver.config
 class TopicIdentifier:
     def __init__(self):
-        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge)
+        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, request_type="topic")

     async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
         """Identify the topics of a message and return them as a list"""

View File

@@ -54,7 +54,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
 async def get_embedding(text):
     """Get the embedding vector for a piece of text"""
-    llm = LLM_request(model=global_config.embedding)
+    llm = LLM_request(model=global_config.embedding, request_type="embedding")
     # return llm.get_embedding_sync(text)
     return await llm.get_embedding(text)

View File

@@ -37,7 +37,7 @@ class ImageManager:
         self._ensure_description_collection()
         self._ensure_image_dir()
         self._initialized = True
-        self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000)
+        self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000, request_type="image")

     def _ensure_image_dir(self):
         """Ensure the image storage directory exists"""

View File

@@ -156,8 +156,8 @@ class Memory_graph:
 class Hippocampus:
     def __init__(self, memory_graph: Memory_graph):
         self.memory_graph = memory_graph
-        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5)
-        self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5)
+        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5, request_type="topic")
+        self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5, request_type="topic")

     def get_all_node_names(self) -> list:
         """Get the list of names of all nodes in the memory graph

View File

@@ -49,6 +49,12 @@ class LLM_request:
         # Get the database instance
         self._init_database()

+        # Extract request_type from kwargs; default to "default" if not provided
+        self.request_type = kwargs.pop("request_type", "default")
+
     @staticmethod
     def _init_database():
         """Initialize the database collections"""
@@ -67,7 +73,7 @@ class LLM_request:
         completion_tokens: int,
         total_tokens: int,
         user_id: str = "system",
-        request_type: str = "chat",
+        request_type: str = None,
         endpoint: str = "/chat/completions",
     ):
         """Record model usage to the database
@@ -76,9 +82,13 @@ class LLM_request:
             completion_tokens: number of output tokens
             total_tokens: total number of tokens
             user_id: user ID, defaults to "system"
-            request_type: request type (chat/embedding/image)
+            request_type: request type (chat/embedding/image/topic/schedule)
             endpoint: API endpoint
         """
+        # If request_type is None, use the value stored on the instance
+        if request_type is None:
+            request_type = self.request_type
+
         try:
             usage_data = {
                 "model_name": self.model_name,
@@ -128,7 +138,7 @@ class LLM_request:
         retry_policy: dict = None,
         response_handler: callable = None,
         user_id: str = "system",
-        request_type: str = "chat",
+        request_type: str = None,
     ):
         """Unified entry point for request execution
         Args:
@@ -142,6 +152,10 @@ class LLM_request:
             user_id: user ID
             request_type: request type
         """
+        if request_type is None:
+            request_type = self.request_type
+
         # Merge retry policies
         default_retry = {
             "max_retries": 3,
@@ -441,7 +455,7 @@ class LLM_request:
         return payload

     def _default_response_handler(
-        self, result: dict, user_id: str = "system", request_type: str = "chat", endpoint: str = "/chat/completions"
+        self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
     ) -> Tuple:
         """Default response parsing"""
         if "choices" in result and result["choices"]:
@@ -465,7 +479,7 @@ class LLM_request:
             completion_tokens=completion_tokens,
             total_tokens=total_tokens,
             user_id=user_id,
-            request_type=request_type,
+            request_type=request_type if request_type is not None else self.request_type,
             endpoint=endpoint,
         )
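For orientation, a usage record written by _record_usage would plausibly look like the sketch below. Only model_name, the three token counts, user_id, request_type, and endpoint are visible in this diff; the concrete values and any further fields (e.g. a timestamp) are assumptions:

    usage_data = {
        "model_name": "qwen-max",          # hypothetical value
        "prompt_tokens": 512,
        "completion_tokens": 128,
        "total_tokens": 640,
        "user_id": "system",
        "request_type": "topic",           # the new per-call-site tag
        "endpoint": "/chat/completions",
    }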
@@ -538,6 +552,22 @@ class LLM_request:
         def embedding_handler(result):
             """Handle the response"""
             if "data" in result and len(result["data"]) > 0:
+                # Extract token usage info
+                usage = result.get("usage", {})
+                if usage:
+                    prompt_tokens = usage.get("prompt_tokens", 0)
+                    completion_tokens = usage.get("completion_tokens", 0)
+                    total_tokens = usage.get("total_tokens", 0)
+                    # Record token usage
+                    self._record_usage(
+                        prompt_tokens=prompt_tokens,
+                        completion_tokens=completion_tokens,
+                        total_tokens=total_tokens,
+                        user_id="system",  # user_id can be changed as needed
+                        request_type="embedding",  # request type is "embedding"
+                        endpoint="/embeddings"  # API endpoint
+                    )
+                    return result["data"][0].get("embedding", None)
                 return result["data"][0].get("embedding", None)
             return None
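With every record carrying a request_type, token totals can be rolled up per model and purpose. A hedged sketch using pymongo, assuming the usage documents land in a MongoDB collection; the connection string, database name, and the collection name llm_usage are illustrative placeholders, not taken from this diff:

    from pymongo import MongoClient

    db = MongoClient("mongodb://localhost:27017")["maibot"]  # hypothetical connection/db
    pipeline = [
        {"$group": {
            "_id": {"model": "$model_name", "type": "$request_type"},
            "total_tokens": {"$sum": "$total_tokens"},
            "calls": {"$sum": 1},
        }},
        {"$sort": {"total_tokens": -1}},
    ]
    for row in db.llm_usage.aggregate(pipeline):  # collection name is a placeholder
        print(row["_id"]["model"], row["_id"]["type"], row["total_tokens"], row["calls"])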

View File

@@ -23,7 +23,7 @@ class ScheduleGenerator:
     def __init__(self):
         # Select the model via the dict config global_config.llm_normal
         # self.llm_scheduler = LLMModel(model=global_config.llm_normal, temperature=0.9)
-        self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
+        self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9, request_type="scheduler")
         self.today_schedule_text = ""
         self.today_schedule = {}
         self.tomorrow_schedule_text = ""