Merge remote-tracking branch 'upstream/debug' into feature
This commit is contained in:
@@ -235,6 +235,7 @@ class EmojiManager:
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
|
||||
return "skip"
|
||||
|
||||
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
|
||||
return "skip" # 默认标签
|
||||
@@ -249,12 +250,20 @@ class EmojiManager:
|
||||
files_to_process = [f for f in os.listdir(emoji_dir) if f.endswith('.jpg')]
|
||||
|
||||
for filename in files_to_process:
|
||||
image_path = os.path.join(emoji_dir, filename)
|
||||
|
||||
# 检查文件大小
|
||||
file_size = os.path.getsize(image_path)
|
||||
if file_size > 5 * 1024 * 1024: # 5MB
|
||||
print(f"\033[1;33m[警告]\033[0m 表情包文件过大 ({file_size/1024/1024:.2f}MB),删除: {filename}")
|
||||
os.remove(image_path)
|
||||
continue
|
||||
|
||||
# 检查是否已经注册过
|
||||
existing_emoji = self.db.db['emoji'].find_one({'filename': filename})
|
||||
if existing_emoji:
|
||||
continue
|
||||
|
||||
image_path = os.path.join(emoji_dir, filename)
|
||||
# 读取图片数据
|
||||
with open(image_path, 'rb') as f:
|
||||
image_data = f.read()
|
||||
|
||||
@@ -72,7 +72,7 @@ class Message:
|
||||
#将详细翻译为详细可读文本
|
||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
|
||||
try:
|
||||
name = f"[({self.user_id}){self.user_nickname}]{self.user_cardname}"
|
||||
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
|
||||
except:
|
||||
name = self.user_nickname or f"用户{self.user_id}"
|
||||
content = self.processed_plain_text
|
||||
|
||||
@@ -145,7 +145,8 @@ class MessageManager:
|
||||
|
||||
async def process_group_messages(self, group_id: int):
|
||||
"""处理群消息"""
|
||||
print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
||||
# if int(time.time() / 3) == time.time() / 3:
|
||||
# print(f"\033[1;34m[调试]\033[0m 开始处理群{group_id}的消息")
|
||||
container = self.get_container(group_id)
|
||||
if container.has_messages():
|
||||
#最早的对象,可能是思考消息,也可能是发送消息
|
||||
|
||||
@@ -6,32 +6,26 @@ import os
|
||||
from ...common.database import Database
|
||||
import zlib # 用于 CRC32
|
||||
import base64
|
||||
from .config import global_config
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
def storage_image(image_data: bytes,type: str, max_size: int = 200) -> bytes:
|
||||
if type == 'image':
|
||||
return storage_compress_image(image_data, max_size)
|
||||
elif type == 'emoji':
|
||||
return storage_emoji(image_data)
|
||||
else:
|
||||
raise ValueError(f"不支持的图片类型: {type}")
|
||||
|
||||
|
||||
def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
def storage_compress_image(base64_data: str, max_size: int = 200) -> str:
|
||||
"""
|
||||
压缩图片到指定大小(单位:KB)并在数据库中记录图片信息
|
||||
压缩base64格式的图片到指定大小(单位:KB)并在数据库中记录图片信息
|
||||
Args:
|
||||
image_data: 图片字节数据
|
||||
group_id: 群组ID
|
||||
user_id: 用户ID
|
||||
base64_data: base64编码的图片数据
|
||||
max_size: 最大文件大小(KB)
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 使用 CRC32 计算哈希值
|
||||
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
|
||||
|
||||
@@ -41,11 +35,11 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
|
||||
# 连接数据库
|
||||
db = Database(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
host=config.mongodb_host,
|
||||
port=int(config.mongodb_port),
|
||||
db_name=config.database_name,
|
||||
username=config.mongodb_username,
|
||||
password=config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
)
|
||||
|
||||
@@ -55,14 +49,14 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
|
||||
if existing_image:
|
||||
print(f"\033[1;33m[提示]\033[0m 发现重复图片,使用已存在的文件: {existing_image['path']}")
|
||||
return image_data
|
||||
return base64_data
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 如果是动图,直接返回原图
|
||||
if getattr(img, 'is_animated', False):
|
||||
return image_data
|
||||
return base64_data
|
||||
|
||||
# 计算当前大小(KB)
|
||||
current_size = len(image_data) / 1024
|
||||
@@ -127,14 +121,16 @@ def storage_compress_image(image_data: bytes, max_size: int = 200) -> bytes:
|
||||
|
||||
except Exception as db_error:
|
||||
print(f"\033[1;31m[错误]\033[0m 数据库操作失败: {str(db_error)}")
|
||||
|
||||
return compressed_data
|
||||
|
||||
# 将压缩后的数据转换为base64
|
||||
compressed_base64 = base64.b64encode(compressed_data).decode('utf-8')
|
||||
return compressed_base64
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return image_data
|
||||
return base64_data
|
||||
|
||||
def storage_emoji(image_data: bytes) -> bytes:
|
||||
"""
|
||||
@@ -215,4 +211,48 @@ def storage_image(image_data: bytes) -> bytes:
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 保存图片失败: {str(e)}")
|
||||
return image_data
|
||||
return image_data
|
||||
|
||||
def compress_base64_image_by_scale(base64_data: str, scale: float = 0.5) -> str:
|
||||
"""按比例压缩base64格式的图片
|
||||
Args:
|
||||
base64_data: base64编码的图片数据
|
||||
scale: 压缩比例(0-1之间的浮点数)
|
||||
Returns:
|
||||
str: 压缩后的base64图片数据
|
||||
"""
|
||||
try:
|
||||
# 将base64转换为字节数据
|
||||
image_data = base64.b64decode(base64_data)
|
||||
|
||||
# 将字节数据转换为图片对象
|
||||
img = Image.open(io.BytesIO(image_data))
|
||||
|
||||
# 如果是动图,直接返回原图
|
||||
if getattr(img, 'is_animated', False):
|
||||
return base64_data
|
||||
|
||||
# 计算新的尺寸
|
||||
new_width = int(img.width * scale)
|
||||
new_height = int(img.height * scale)
|
||||
|
||||
# 缩放图片
|
||||
img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为RGB模式(去除透明通道)
|
||||
if img.mode in ('RGBA', 'P'):
|
||||
img = img.convert('RGB')
|
||||
|
||||
# 保存压缩后的图片
|
||||
output = io.BytesIO()
|
||||
img.save(output, format='JPEG', quality=85, optimize=True)
|
||||
compressed_data = output.getvalue()
|
||||
|
||||
# 转换回base64
|
||||
return base64.b64encode(compressed_data).decode('utf-8')
|
||||
|
||||
except Exception as e:
|
||||
print(f"\033[1;31m[错误]\033[0m 压缩图片失败: {str(e)}")
|
||||
import traceback
|
||||
print(traceback.format_exc())
|
||||
return base64_data
|
||||
@@ -11,7 +11,6 @@ class WillingManager:
|
||||
while True:
|
||||
await asyncio.sleep(3)
|
||||
for group_id in self.group_reply_willing:
|
||||
# 每分钟衰减10%的回复意愿
|
||||
self.group_reply_willing[group_id] = max(0, self.group_reply_willing[group_id] * 0.6)
|
||||
|
||||
def get_willing(self, group_id: int) -> float:
|
||||
@@ -26,13 +25,7 @@ class WillingManager:
|
||||
"""改变指定群组的回复意愿并返回回复概率"""
|
||||
current_willing = self.group_reply_willing.get(group_id, 0)
|
||||
|
||||
print(f"初始意愿: {current_willing}")
|
||||
|
||||
# if topic and current_willing < 1:
|
||||
# current_willing += 0.2
|
||||
# elif topic:
|
||||
# current_willing += 0.05
|
||||
|
||||
# print(f"初始意愿: {current_willing}")
|
||||
if is_mentioned_bot and current_willing < 1.0:
|
||||
current_willing += 0.9
|
||||
print(f"被提及, 当前意愿: {current_willing}")
|
||||
@@ -44,9 +37,9 @@ class WillingManager:
|
||||
current_willing *= 0.15
|
||||
print(f"表情包, 当前意愿: {current_willing}")
|
||||
|
||||
if interested_rate > 0.6:
|
||||
if interested_rate > 0.65:
|
||||
print(f"兴趣度: {interested_rate}, 当前意愿: {current_willing}")
|
||||
current_willing += interested_rate-0.45
|
||||
current_willing += interested_rate-0.5
|
||||
|
||||
self.group_reply_willing[group_id] = min(current_willing, 3.0)
|
||||
|
||||
@@ -57,11 +50,10 @@ class WillingManager:
|
||||
|
||||
if group_id in config.talk_frequency_down_groups:
|
||||
reply_probability = reply_probability / 3.5
|
||||
|
||||
# if is_mentioned_bot and user_id == int(1026294844):
|
||||
# reply_probability = 1
|
||||
|
||||
reply_probability = min(reply_probability, 1)
|
||||
if reply_probability < 0.1:
|
||||
reply_probability = 0.1
|
||||
return reply_probability
|
||||
|
||||
def change_reply_willing_sent(self, group_id: int):
|
||||
|
||||
@@ -3,26 +3,28 @@ import sys
|
||||
import numpy as np
|
||||
import requests
|
||||
import time
|
||||
from nonebot import get_driver
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 添加项目根目录到 Python 路径
|
||||
root_path = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
sys.path.append(root_path)
|
||||
|
||||
from src.common.database import Database
|
||||
from src.plugins.chat.config import llm_config
|
||||
# 加载根目录下的env.edv文件
|
||||
env_path = os.path.join(root_path, ".env.dev")
|
||||
if not os.path.exists(env_path):
|
||||
raise FileNotFoundError(f"配置文件不存在: {env_path}")
|
||||
load_dotenv(env_path)
|
||||
|
||||
# 直接配置数据库连接信息
|
||||
from src.common.database import Database
|
||||
|
||||
# 从环境变量获取配置
|
||||
Database.initialize(
|
||||
host= config.MONGODB_HOST,
|
||||
port= int(config.MONGODB_PORT),
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source=config.MONGODB_AUTH_SOURCE
|
||||
host=os.getenv("MONGODB_HOST", "localhost"),
|
||||
port=int(os.getenv("MONGODB_PORT", "27017")),
|
||||
db_name=os.getenv("DATABASE_NAME", "maimai"),
|
||||
username=os.getenv("MONGODB_USERNAME"),
|
||||
password=os.getenv("MONGODB_PASSWORD"),
|
||||
auth_source=os.getenv("MONGODB_AUTH_SOURCE", "admin")
|
||||
)
|
||||
|
||||
class KnowledgeLibrary:
|
||||
@@ -30,6 +32,9 @@ class KnowledgeLibrary:
|
||||
self.db = Database.get_instance()
|
||||
self.raw_info_dir = "data/raw_info"
|
||||
self._ensure_dirs()
|
||||
self.api_key = os.getenv("SILICONFLOW_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("SILICONFLOW_API_KEY 环境变量未设置")
|
||||
|
||||
def _ensure_dirs(self):
|
||||
"""确保必要的目录存在"""
|
||||
@@ -44,7 +49,7 @@ class KnowledgeLibrary:
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {llm_config.SILICONFLOW_API_KEY}",
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
@@ -74,7 +79,7 @@ class KnowledgeLibrary:
|
||||
content = f.read()
|
||||
|
||||
# 按1024字符分段
|
||||
segments = [content[i:i+300] for i in range(0, len(content), 300)]
|
||||
segments = [content[i:i+400] for i in range(0, len(content), 400)]
|
||||
|
||||
# 处理每个分段
|
||||
for segment in segments:
|
||||
|
||||
@@ -2,14 +2,17 @@ import aiohttp
|
||||
import asyncio
|
||||
import requests
|
||||
import time
|
||||
import re
|
||||
from typing import Tuple, Union
|
||||
from nonebot import get_driver
|
||||
from loguru import logger
|
||||
from ..chat.config import global_config
|
||||
from ..chat.utils_image import compress_base64_image_by_scale
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
|
||||
class LLM_request:
|
||||
def __init__(self, model, **kwargs):
|
||||
# 将大写的配置键转换为小写并从config中获取实际值
|
||||
@@ -28,21 +31,21 @@ class LLM_request:
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
**self.params
|
||||
}
|
||||
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -52,16 +55,16 @@ class LLM_request:
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
@@ -70,7 +73,7 @@ class LLM_request:
|
||||
else:
|
||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||
|
||||
@@ -80,39 +83,45 @@ class LLM_request:
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
||||
def build_request_data(img_base64: str):
|
||||
return {
|
||||
"model": self.model_name,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": prompt
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": f"data:image/jpeg;base64,{img_base64}"
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
**self.params
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
**self.params
|
||||
}
|
||||
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
current_image_base64 = image_base64
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
data = build_request_data(current_image_base64)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
@@ -120,16 +129,28 @@ class LLM_request:
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
elif response.status == 413:
|
||||
logger.warning("图片太大(413),尝试压缩...")
|
||||
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
|
||||
continue
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
|
||||
result = await response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
message = result["choices"][0]["message"]
|
||||
content = message.get("content", "")
|
||||
think_match = None
|
||||
reasoning_content = message.get("reasoning_content", "")
|
||||
if not reasoning_content:
|
||||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||
if think_match:
|
||||
reasoning_content = think_match.group(1).strip()
|
||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
@@ -138,7 +159,7 @@ class LLM_request:
|
||||
else:
|
||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||
|
||||
@@ -148,7 +169,7 @@ class LLM_request:
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
# 构建请求体
|
||||
data = {
|
||||
"model": self.model_name,
|
||||
@@ -171,33 +192,40 @@ class LLM_request:
|
||||
],
|
||||
**self.params
|
||||
}
|
||||
|
||||
|
||||
# 发送请求到完整的chat/completions端点
|
||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
||||
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||
|
||||
|
||||
max_retries = 2
|
||||
base_wait_time = 6
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||
|
||||
|
||||
if response.status_code == 429:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status() # 检查其他响应状态
|
||||
|
||||
|
||||
result = response.json()
|
||||
if "choices" in result and len(result["choices"]) > 0:
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
reasoning_content = result["choices"][0]["message"].get("reasoning_content", "")
|
||||
message = result["choices"][0]["message"]
|
||||
content = message.get("content", "")
|
||||
think_match = None
|
||||
reasoning_content = message.get("reasoning_content", "")
|
||||
if not reasoning_content:
|
||||
think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
|
||||
if think_match:
|
||||
reasoning_content = think_match.group(1).strip()
|
||||
content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
|
||||
return content, reasoning_content
|
||||
return "没有返回结果", ""
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
@@ -206,7 +234,7 @@ class LLM_request:
|
||||
else:
|
||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,请求仍然失败")
|
||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||
|
||||
@@ -224,36 +252,36 @@ class LLM_request:
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||
|
||||
|
||||
max_retries = 2
|
||||
base_wait_time = 6
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||
|
||||
|
||||
if response.status_code == 429:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = response.json()
|
||||
if 'data' in result and len(result['data']) > 0:
|
||||
return result['data'][0]['embedding']
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
@@ -262,7 +290,7 @@ class LLM_request:
|
||||
else:
|
||||
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,embedding请求仍然失败")
|
||||
return None
|
||||
|
||||
@@ -280,19 +308,19 @@ class LLM_request:
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||
logger.info(f"发送请求到URL: {api_url}") # 记录请求的URL
|
||||
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
@@ -302,14 +330,14 @@ class LLM_request:
|
||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
result = await response.json()
|
||||
if 'data' in result and len(result['data']) > 0:
|
||||
return result['data'][0]['embedding']
|
||||
return None
|
||||
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
@@ -318,6 +346,6 @@ class LLM_request:
|
||||
else:
|
||||
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
|
||||
logger.error("达到最大重试次数,embedding请求仍然失败")
|
||||
return None
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user