v0.4.1
修复了数据库命名问题 修复了嵌入模型未定义问题
This commit is contained in:
26
.env
26
.env
@@ -1,26 +1,2 @@
|
||||
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
||||
ENVIRONMENT=prod
|
||||
# HOST=127.0.0.1
|
||||
# PORT=8080
|
||||
|
||||
# COMMAND_START=["/"]
|
||||
|
||||
# # 插件配置
|
||||
# PLUGINS=["src2.plugins.chat"]
|
||||
|
||||
# # 默认配置
|
||||
# MONGODB_HOST=127.0.0.1
|
||||
# MONGODB_PORT=27017
|
||||
# DATABASE_NAME=MegBot
|
||||
|
||||
# MONGODB_USERNAME = "" # 默认空值
|
||||
# MONGODB_PASSWORD = "" # 默认空值
|
||||
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||
|
||||
# #key and url
|
||||
# CHAT_ANY_WHERE_KEY=
|
||||
# SILICONFLOW_KEY=
|
||||
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||
# DEEP_SEEK_KEY=
|
||||
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||
ENVIRONMENT=.dev
|
||||
@@ -1,8 +1,6 @@
|
||||
HOST=127.0.0.1
|
||||
PORT=8080
|
||||
|
||||
COMMAND_START=["/"]
|
||||
|
||||
# 插件配置
|
||||
PLUGINS=["src2.plugins.chat"]
|
||||
|
||||
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值
|
||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||
|
||||
#key and url
|
||||
|
||||
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||
|
||||
#定义你要用的api的base_url
|
||||
DEEP_SEEK_KEY=
|
||||
CHAT_ANY_WHERE_KEY=
|
||||
SILICONFLOW_KEY=
|
||||
11
bot.py
11
bot.py
@@ -15,25 +15,22 @@ for i, char in enumerate(text):
|
||||
print(rainbow_text)
|
||||
'''彩蛋'''
|
||||
|
||||
# 首先加载基础环境变量
|
||||
# 首先加载基础环境变量.env
|
||||
if os.path.exists(".env"):
|
||||
load_dotenv(".env")
|
||||
logger.success("成功加载基础环境变量配置")
|
||||
else:
|
||||
logger.error("基础环境变量配置文件 .env 不存在")
|
||||
exit(1)
|
||||
# 根据 ENVIRONMENT 加载对应的环境配置
|
||||
env = os.getenv("ENVIRONMENT")
|
||||
env_file = f".env.{env}"
|
||||
|
||||
if env_file == ".env.dev" and os.path.exists(env_file):
|
||||
if os.path.exists(".env.dev"):
|
||||
logger.success("加载开发环境变量配置")
|
||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
||||
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
|
||||
elif os.path.exists(".env.prod"):
|
||||
logger.success("加载环境变量配置")
|
||||
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||
else:
|
||||
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||
logger.error(f".env对应的环境配置文件不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||
exit(1)
|
||||
|
||||
# 获取所有环境变量
|
||||
|
||||
@@ -17,12 +17,12 @@ driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
Database.initialize(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source= config.mongodb_auth_source
|
||||
host= config.MONGODB_HOST,
|
||||
port= int(config.MONGODB_PORT),
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source= config.MONGODB_AUTH_SOURCE
|
||||
)
|
||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||
|
||||
|
||||
@@ -116,6 +116,9 @@ class BotConfig:
|
||||
|
||||
if "vlm" in model_config:
|
||||
config.vlm = model_config["vlm"]
|
||||
|
||||
if "embedding" in model_config:
|
||||
config.embedding = model_config["embedding"]
|
||||
|
||||
# 消息配置
|
||||
if "message" in toml_dict:
|
||||
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
||||
if not os.path.exists(bot_config_path):
|
||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||
logger.info("使用默认配置文件")
|
||||
logger.info("使用bot配置文件")
|
||||
else:
|
||||
logger.info("已找到开发环境配置文件")
|
||||
logger.info("已找到开发bot配置文件")
|
||||
|
||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""机器人配置类"""
|
||||
# 基础配置
|
||||
SILICONFLOW_API_KEY: str = None
|
||||
SILICONFLOW_BASE_URL: str = None
|
||||
DEEP_SEEK_API_KEY: str = None
|
||||
DEEP_SEEK_BASE_URL: str = None
|
||||
|
||||
llm_config = LLMConfig()
|
||||
config = get_driver().config
|
||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
||||
|
||||
|
||||
if not global_config.enable_advance_output:
|
||||
# logger.remove()
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Dict
|
||||
from collections import Counter
|
||||
import math
|
||||
from nonebot import get_driver
|
||||
from ..models.utils_model import LLM_request
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||
return False
|
||||
|
||||
def get_embedding(text):
|
||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"API请求失败: {response.status_code}")
|
||||
print(f"错误信息: {response.text}")
|
||||
return None
|
||||
|
||||
return response.json()['data'][0]['embedding']
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLM_request(model=global_config.embedding)
|
||||
return llm.get_embedding_sync(text)
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
dot_product = np.dot(v1, v2)
|
||||
|
||||
@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
|
||||
|
||||
# 直接配置数据库连接信息
|
||||
Database.initialize(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
host= config.MONGODB_HOST,
|
||||
port= int(config.MONGODB_PORT),
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source=config.MONGODB_AUTH_SOURCE
|
||||
)
|
||||
|
||||
class KnowledgeLibrary:
|
||||
|
||||
@@ -66,7 +66,7 @@ class LLMModel:
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
@@ -259,12 +259,12 @@ config = driver.config
|
||||
start_time = time.time()
|
||||
|
||||
Database.initialize(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
host= config.MONGODB_HOST,
|
||||
port= config.MONGODB_PORT,
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source=config.MONGODB_AUTH_SOURCE
|
||||
)
|
||||
#创建记忆图
|
||||
memory_graph = Memory_graph()
|
||||
|
||||
@@ -9,7 +9,7 @@ driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
class LLM_request:
|
||||
def __init__(self, model = global_config.llm_normal,**kwargs):
|
||||
def __init__(self, model ,**kwargs):
|
||||
# 将大写的配置键转换为小写并从config中获取实际值
|
||||
try:
|
||||
self.api_key = getattr(config, model["key"])
|
||||
@@ -61,7 +61,7 @@ class LLM_request:
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
@@ -126,7 +126,7 @@ class LLM_request:
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
@@ -191,9 +191,119 @@ class LLM_request:
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1: # 如果还有重试机会
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
return f"请求失败: {str(e)}", ""
|
||||
|
||||
return "达到最大重试次数,请求仍然失败", ""
|
||||
|
||||
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||
"""同步方法:获取文本的embedding向量
|
||||
|
||||
Args:
|
||||
text: 需要获取embedding的文本
|
||||
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||
|
||||
Returns:
|
||||
list: embedding向量,如果失败则返回None
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||
|
||||
max_retries = 2
|
||||
base_wait_time = 6
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||
|
||||
if response.status_code == 429:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
time.sleep(wait_time)
|
||||
continue
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
if 'data' in result and len(result['data']) > 0:
|
||||
return result['data'][0]['embedding']
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
time.sleep(wait_time)
|
||||
else:
|
||||
print(f"embedding请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
print("达到最大重试次数,embedding请求仍然失败")
|
||||
return None
|
||||
|
||||
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||
"""异步方法:获取文本的embedding向量
|
||||
|
||||
Args:
|
||||
text: 需要获取embedding的文本
|
||||
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||
|
||||
Returns:
|
||||
list: embedding向量,如果失败则返回None
|
||||
"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
|
||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||
|
||||
max_retries = 3
|
||||
base_wait_time = 15
|
||||
|
||||
for retry in range(max_retries):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(api_url, headers=headers, json=data) as response:
|
||||
if response.status == 429:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
|
||||
response.raise_for_status()
|
||||
|
||||
result = await response.json()
|
||||
if 'data' in result and len(result['data']) > 0:
|
||||
return result['data'][0]['embedding']
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
if retry < max_retries - 1:
|
||||
wait_time = base_wait_time * (2 ** retry)
|
||||
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||
await asyncio.sleep(wait_time)
|
||||
else:
|
||||
print(f"embedding请求失败: {str(e)}")
|
||||
return None
|
||||
|
||||
print("达到最大重试次数,embedding请求仍然失败")
|
||||
return None
|
||||
|
||||
@@ -11,12 +11,12 @@ config = driver.config
|
||||
|
||||
|
||||
Database.initialize(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source=config.mongodb_auth_source
|
||||
host= config.MONGODB_HOST,
|
||||
port= int(config.MONGODB_PORT),
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source=config.MONGODB_AUTH_SOURCE
|
||||
)
|
||||
|
||||
class ScheduleGenerator:
|
||||
|
||||
Reference in New Issue
Block a user