@@ -38,9 +38,9 @@ class EmojiManager:
|
||||
|
||||
def __init__(self):
|
||||
self._scan_task = None
|
||||
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
|
||||
self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000,request_type = 'image')
|
||||
self.llm_emotion_judge = LLM_request(
|
||||
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8
|
||||
model=global_config.llm_emotion_judge, max_tokens=600, temperature=0.8,request_type = 'image'
|
||||
) # 更高的温度,更少的token(后续可以根据情绪来调整温度)
|
||||
|
||||
def _ensure_emoji_dir(self):
|
||||
|
||||
@@ -14,7 +14,7 @@ config = driver.config
|
||||
|
||||
class TopicIdentifier:
|
||||
def __init__(self):
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge)
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge,request_type = 'topic')
|
||||
|
||||
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
|
||||
"""识别消息主题,返回主题列表"""
|
||||
|
||||
@@ -54,7 +54,7 @@ def is_mentioned_bot_in_message(message: MessageRecv) -> bool:
|
||||
|
||||
async def get_embedding(text):
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLM_request(model=global_config.embedding)
|
||||
llm = LLM_request(model=global_config.embedding,request_type = 'embedding')
|
||||
# return llm.get_embedding_sync(text)
|
||||
return await llm.get_embedding(text)
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ class ImageManager:
|
||||
self._ensure_description_collection()
|
||||
self._ensure_image_dir()
|
||||
self._initialized = True
|
||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000)
|
||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=1000,request_type = 'image')
|
||||
|
||||
def _ensure_image_dir(self):
|
||||
"""确保图像存储目录存在"""
|
||||
|
||||
@@ -156,8 +156,8 @@ class Memory_graph:
|
||||
class Hippocampus:
|
||||
def __init__(self, memory_graph: Memory_graph):
|
||||
self.memory_graph = memory_graph
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5)
|
||||
self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5)
|
||||
self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5,request_type = 'topic')
|
||||
self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5,request_type = 'topic')
|
||||
|
||||
def get_all_node_names(self) -> list:
|
||||
"""获取记忆图中所有节点的名字列表
|
||||
|
||||
@@ -49,6 +49,12 @@ class LLM_request:
|
||||
# 获取数据库实例
|
||||
self._init_database()
|
||||
|
||||
# 从 kwargs 中提取 request_type,如果没有提供则默认为 "default"
|
||||
self.request_type = kwargs.pop("request_type", "default")
|
||||
|
||||
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _init_database():
|
||||
"""初始化数据库集合"""
|
||||
@@ -67,7 +73,7 @@ class LLM_request:
|
||||
completion_tokens: int,
|
||||
total_tokens: int,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat",
|
||||
request_type: str = None,
|
||||
endpoint: str = "/chat/completions",
|
||||
):
|
||||
"""记录模型使用情况到数据库
|
||||
@@ -76,9 +82,13 @@ class LLM_request:
|
||||
completion_tokens: 输出token数
|
||||
total_tokens: 总token数
|
||||
user_id: 用户ID,默认为system
|
||||
request_type: 请求类型(chat/embedding/image等)
|
||||
request_type: 请求类型(chat/embedding/image/topic/schedule)
|
||||
endpoint: API端点
|
||||
"""
|
||||
# 如果 request_type 为 None,则使用实例变量中的值
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
try:
|
||||
usage_data = {
|
||||
"model_name": self.model_name,
|
||||
@@ -128,7 +138,7 @@ class LLM_request:
|
||||
retry_policy: dict = None,
|
||||
response_handler: callable = None,
|
||||
user_id: str = "system",
|
||||
request_type: str = "chat",
|
||||
request_type: str = None,
|
||||
):
|
||||
"""统一请求执行入口
|
||||
Args:
|
||||
@@ -142,6 +152,10 @@ class LLM_request:
|
||||
user_id: 用户ID
|
||||
request_type: 请求类型
|
||||
"""
|
||||
|
||||
if request_type is None:
|
||||
request_type = self.request_type
|
||||
|
||||
# 合并重试策略
|
||||
default_retry = {
|
||||
"max_retries": 3,
|
||||
@@ -441,7 +455,7 @@ class LLM_request:
|
||||
return payload
|
||||
|
||||
def _default_response_handler(
|
||||
self, result: dict, user_id: str = "system", request_type: str = "chat", endpoint: str = "/chat/completions"
|
||||
self, result: dict, user_id: str = "system", request_type: str = None, endpoint: str = "/chat/completions"
|
||||
) -> Tuple:
|
||||
"""默认响应解析"""
|
||||
if "choices" in result and result["choices"]:
|
||||
@@ -465,7 +479,7 @@ class LLM_request:
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
user_id=user_id,
|
||||
request_type=request_type,
|
||||
request_type = request_type if request_type is not None else self.request_type,
|
||||
endpoint=endpoint,
|
||||
)
|
||||
|
||||
@@ -538,6 +552,22 @@ class LLM_request:
|
||||
def embedding_handler(result):
|
||||
"""处理响应"""
|
||||
if "data" in result and len(result["data"]) > 0:
|
||||
# 提取 token 使用信息
|
||||
usage = result.get("usage", {})
|
||||
if usage:
|
||||
prompt_tokens = usage.get("prompt_tokens", 0)
|
||||
completion_tokens = usage.get("completion_tokens", 0)
|
||||
total_tokens = usage.get("total_tokens", 0)
|
||||
# 记录 token 使用情况
|
||||
self._record_usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
user_id="system", # 可以根据需要修改 user_id
|
||||
request_type="embedding", # 请求类型为 embedding
|
||||
endpoint="/embeddings" # API 端点
|
||||
)
|
||||
return result["data"][0].get("embedding", None)
|
||||
return result["data"][0].get("embedding", None)
|
||||
return None
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class ScheduleGenerator:
|
||||
def __init__(self):
|
||||
# 根据global_config.llm_normal这一字典配置指定模型
|
||||
# self.llm_scheduler = LLMModel(model = global_config.llm_normal,temperature=0.9)
|
||||
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9)
|
||||
self.llm_scheduler = LLM_request(model=global_config.llm_normal, temperature=0.9,request_type = 'scheduler')
|
||||
self.today_schedule_text = ""
|
||||
self.today_schedule = {}
|
||||
self.tomorrow_schedule_text = ""
|
||||
|
||||
Reference in New Issue
Block a user