From 4bc222ba6faa3a071576ef396e5c1dad19f28279 Mon Sep 17 00:00:00 2001 From: Maple127667 <98679702+Maple127667@users.noreply.github.com> Date: Sun, 16 Mar 2025 23:11:32 +0800 Subject: [PATCH] =?UTF-8?q?token=E7=BB=9F=E8=AE=A1=E9=83=A8=E5=88=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/emoji_manager.py | 4 +-- src/plugins/chat/topic_identifier.py | 2 +- src/plugins/chat/utils.py | 2 +- src/plugins/chat/utils_image.py | 2 +- src/plugins/memory_system/memory.py | 4 +-- src/plugins/models/utils_model.py | 40 +++++++++++++++++++--- src/plugins/schedule/schedule_generator.py | 2 +- 7 files changed, 43 insertions(+), 13 deletions(-) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 1d0573ccb..21ec1f71c 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -38,9 +38,9 @@ class EmojiManager: def __init__(self): self._scan_task = None - self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000) + self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000,request_type = 'image') self.llm_emotion_judge = LLM_request( - model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8 + model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8,request_type = 'image' ) # 更高的温度,更少的token(后续可以根据情绪来调整温度) def _ensure_emoji_dir(self): diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index 58069f131..71abf6bae 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -14,7 +14,7 @@ config = driver.config class TopicIdentifier: def __init__(self): - self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge) + self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge,request_type = 'topic') async def identify_topic_llm(self, text: str) -> Optional[List[str]]: """识别消息主题,返回主题列表""" diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 29f10fc20..05cc3ca06 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -54,7 +54,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool: async def get_embedding(text): """获取文本的embedding向量""" - llm = LLM_request(model=global_config.embedding) + llm = LLM_request(model=global_config.embedding,request_type = 'embedding') # return llm.get_embedding_sync(text) return await llm.get_embedding(text) diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 78b635df9..120aa104a 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -37,7 +37,7 @@ class ImageManager: self._ensure_description_collection() self._ensure_image_dir() self._initialized = True - self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000) + self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000,request_type = 'image') def _ensure_image_dir(self): """确保图像存储目录存在""" diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 6660fa152..d2f77e0f8 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -156,8 +156,8 @@ class Memory_graph: class Hippocampus: def __init__(self, memory_graph: Memory_graph): self.memory_graph = memory_graph - self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5) - self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5) + self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5,request_type = 'topic') + self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5,request_type = 'topic') def get_all_node_names(self) -> list: """获取记忆图中所有节点的名字列表 diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 7572460f7..0764a1949 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -49,6 +49,12 @@ class LLM_request: # 获取数据库实例 self._init_database() + # 从 kwargs 中提取 request_type,如果没有提供则默认为 "default" + self.request_type = kwargs.pop("request_type", "default") + + + + @staticmethod def _init_database(): """初始化数据库集合""" @@ -67,7 +73,7 @@ class LLM_request: completion_tokens: int, total_tokens: int, user_id: str = "system", - request_type: str = "chat", + request_type: str = None, endpoint: str = "/chat/completions", ): """记录模型使用情况到数据库 @@ -76,9 +82,13 @@ class LLM_request: completion_tokens: 输出token数 total_tokens: 总token数 user_id: 用户ID,默认为system - request_type: 请求类型(chat/embedding/image等) + request_type: 请求类型(chat/embedding/image/topic/schedule) endpoint: API端点 """ + # 如果 request_type 为 None,则使用实例变量中的值 + if request_type is None: + request_type = self.request_type + try: usage_data = { "model_name": self.model_name, @@ -128,7 +138,7 @@ class LLM_request: retry_policy: dict = None, response_handler: callable = None, user_id: str = "system", - request_type: str = "chat", + request_type: str = None, ): """统一请求执行入口 Args: @@ -142,6 +152,10 @@ class LLM_request: user_id: 用户ID request_type: 请求类型 """ + + if request_type is None: + request_type = self.request_type + # 合并重试策略 default_retry = { "max_retries": 3, @@ -441,7 +455,7 @@ class LLM_request: return payload def _default_response_handler( - self, result: dict, user_id: str = "system", request_type: str = "chat", endpoint: str = "/chat/completions" + self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions" ) -> Tuple: """默认响应解析""" if "choices" in result and result["choices"]: @@ -465,7 +479,7 @@ class LLM_request: completion_tokens=completion_tokens, total_tokens=total_tokens, user_id=user_id, - request_type=request_type, + request_type = request_type if request_type is not None else self.request_type, endpoint=endpoint, ) @@ -538,6 +552,22 @@ class LLM_request: def embedding_handler(result): """处理响应""" if "data" in result and len(result["data"]) > 0: + # 提取 token 使用信息 + usage = result.get("usage", {}) + if usage: + prompt_tokens = usage.get("prompt_tokens", 0) + completion_tokens = usage.get("completion_tokens", 0) + total_tokens = usage.get("total_tokens", 0) + # 记录 token 使用情况 + self._record_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + user_id="system", # 可以根据需要修改 user_id + request_type="embedding", # 请求类型为 embedding + endpoint="/embeddings" # API 端点 + ) + return result["data"][0].get("embedding", None) return result["data"][0].get("embedding", None) return None diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index a28e24999..d35c7f11f 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -23,7 +23,7 @@ class ScheduleGenerator: def __init__(self): # 根据global_config.llm_normal这一字典配置指定模型 # self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9) - self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9) + self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9,request_type = 'scheduler') self.today_schedule_text = "" self.today_schedule = {} self.tomorrow_schedule_text = ""