v0.4.1
修复了数据库命名问题 修复了嵌入模型未定义问题
This commit is contained in:
26
.env
26
.env
@@ -1,26 +1,2 @@
|
|||||||
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
# 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod
|
||||||
ENVIRONMENT=prod
|
ENVIRONMENT=.dev
|
||||||
# HOST=127.0.0.1
|
|
||||||
# PORT=8080
|
|
||||||
|
|
||||||
# COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# # 插件配置
|
|
||||||
# PLUGINS=["src2.plugins.chat"]
|
|
||||||
|
|
||||||
# # 默认配置
|
|
||||||
# MONGODB_HOST=127.0.0.1
|
|
||||||
# MONGODB_PORT=27017
|
|
||||||
# DATABASE_NAME=MegBot
|
|
||||||
|
|
||||||
# MONGODB_USERNAME = "" # 默认空值
|
|
||||||
# MONGODB_PASSWORD = "" # 默认空值
|
|
||||||
# MONGODB_AUTH_SOURCE = "" # 默认空值
|
|
||||||
|
|
||||||
# #key and url
|
|
||||||
# CHAT_ANY_WHERE_KEY=
|
|
||||||
# SILICONFLOW_KEY=
|
|
||||||
# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
|
||||||
# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
|
||||||
# DEEP_SEEK_KEY=
|
|
||||||
# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
|
||||||
@@ -1,8 +1,6 @@
|
|||||||
HOST=127.0.0.1
|
HOST=127.0.0.1
|
||||||
PORT=8080
|
PORT=8080
|
||||||
|
|
||||||
COMMAND_START=["/"]
|
|
||||||
|
|
||||||
# 插件配置
|
# 插件配置
|
||||||
PLUGINS=["src2.plugins.chat"]
|
PLUGINS=["src2.plugins.chat"]
|
||||||
|
|
||||||
@@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值
|
|||||||
MONGODB_AUTH_SOURCE = "" # 默认空值
|
MONGODB_AUTH_SOURCE = "" # 默认空值
|
||||||
|
|
||||||
#key and url
|
#key and url
|
||||||
|
|
||||||
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
|
||||||
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
|
||||||
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
#定义你要用的api的base_url
|
||||||
DEEP_SEEK_KEY=
|
DEEP_SEEK_KEY=
|
||||||
CHAT_ANY_WHERE_KEY=
|
CHAT_ANY_WHERE_KEY=
|
||||||
SILICONFLOW_KEY=
|
SILICONFLOW_KEY=
|
||||||
11
bot.py
11
bot.py
@@ -15,25 +15,22 @@ for i, char in enumerate(text):
|
|||||||
print(rainbow_text)
|
print(rainbow_text)
|
||||||
'''彩蛋'''
|
'''彩蛋'''
|
||||||
|
|
||||||
# 首先加载基础环境变量
|
# 首先加载基础环境变量.env
|
||||||
if os.path.exists(".env"):
|
if os.path.exists(".env"):
|
||||||
load_dotenv(".env")
|
load_dotenv(".env")
|
||||||
logger.success("成功加载基础环境变量配置")
|
logger.success("成功加载基础环境变量配置")
|
||||||
else:
|
else:
|
||||||
logger.error("基础环境变量配置文件 .env 不存在")
|
logger.error("基础环境变量配置文件 .env 不存在")
|
||||||
exit(1)
|
exit(1)
|
||||||
# 根据 ENVIRONMENT 加载对应的环境配置
|
|
||||||
env = os.getenv("ENVIRONMENT")
|
|
||||||
env_file = f".env.{env}"
|
|
||||||
|
|
||||||
if env_file == ".env.dev" and os.path.exists(env_file):
|
if os.path.exists(".env.dev"):
|
||||||
logger.success("加载开发环境变量配置")
|
logger.success("加载开发环境变量配置")
|
||||||
load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
elif os.path.exists(".env.prod"):
|
elif os.path.exists(".env.prod"):
|
||||||
logger.success("加载环境变量配置")
|
logger.success("加载环境变量配置")
|
||||||
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量
|
||||||
else:
|
else:
|
||||||
logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
logger.error(f".env对应的环境配置文件不存在,请修改.env文件中的ENVIRONMENT变量为 prod.")
|
||||||
exit(1)
|
exit(1)
|
||||||
|
|
||||||
# 获取所有环境变量
|
# 获取所有环境变量
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ driver = get_driver()
|
|||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source= config.mongodb_auth_source
|
auth_source= config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||||
|
|
||||||
|
|||||||
@@ -117,6 +117,9 @@ class BotConfig:
|
|||||||
if "vlm" in model_config:
|
if "vlm" in model_config:
|
||||||
config.vlm = model_config["vlm"]
|
config.vlm = model_config["vlm"]
|
||||||
|
|
||||||
|
if "embedding" in model_config:
|
||||||
|
config.embedding = model_config["embedding"]
|
||||||
|
|
||||||
# 消息配置
|
# 消息配置
|
||||||
if "message" in toml_dict:
|
if "message" in toml_dict:
|
||||||
msg_config = toml_dict["message"]
|
msg_config = toml_dict["message"]
|
||||||
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
|||||||
if not os.path.exists(bot_config_path):
|
if not os.path.exists(bot_config_path):
|
||||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||||
logger.info("使用默认配置文件")
|
logger.info("使用bot配置文件")
|
||||||
else:
|
else:
|
||||||
logger.info("已找到开发环境配置文件")
|
logger.info("已找到开发bot配置文件")
|
||||||
|
|
||||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LLMConfig:
|
|
||||||
"""机器人配置类"""
|
|
||||||
# 基础配置
|
|
||||||
SILICONFLOW_API_KEY: str = None
|
|
||||||
SILICONFLOW_BASE_URL: str = None
|
|
||||||
DEEP_SEEK_API_KEY: str = None
|
|
||||||
DEEP_SEEK_BASE_URL: str = None
|
|
||||||
|
|
||||||
llm_config = LLMConfig()
|
|
||||||
config = get_driver().config
|
|
||||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
|
||||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
|
||||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
|
||||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
|
||||||
|
|
||||||
|
|
||||||
if not global_config.enable_advance_output:
|
if not global_config.enable_advance_output:
|
||||||
# logger.remove()
|
# logger.remove()
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ from typing import Dict
|
|||||||
from collections import Counter
|
from collections import Counter
|
||||||
import math
|
import math
|
||||||
from nonebot import get_driver
|
from nonebot import get_driver
|
||||||
|
from ..models.utils_model import LLM_request
|
||||||
|
|
||||||
driver = get_driver()
|
driver = get_driver()
|
||||||
config = driver.config
|
config = driver.config
|
||||||
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def get_embedding(text):
|
def get_embedding(text):
|
||||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
"""获取文本的embedding向量"""
|
||||||
payload = {
|
llm = LLM_request(model=global_config.embedding)
|
||||||
"model": "BAAI/bge-m3",
|
return llm.get_embedding_sync(text)
|
||||||
"input": text,
|
|
||||||
"encoding_format": "float"
|
|
||||||
}
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
response = requests.request("POST", url, json=payload, headers=headers)
|
|
||||||
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"API请求失败: {response.status_code}")
|
|
||||||
print(f"错误信息: {response.text}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
return response.json()['data'][0]['embedding']
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
|
|||||||
|
|
||||||
# 直接配置数据库连接信息
|
# 直接配置数据库连接信息
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
class KnowledgeLibrary:
|
class KnowledgeLibrary:
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ class LLMModel:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|||||||
@@ -259,12 +259,12 @@ config = driver.config
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= config.MONGODB_PORT,
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
#创建记忆图
|
#创建记忆图
|
||||||
memory_graph = Memory_graph()
|
memory_graph = Memory_graph()
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ driver = get_driver()
|
|||||||
config = driver.config
|
config = driver.config
|
||||||
|
|
||||||
class LLM_request:
|
class LLM_request:
|
||||||
def __init__(self, model = global_config.llm_normal,**kwargs):
|
def __init__(self, model ,**kwargs):
|
||||||
# 将大写的配置键转换为小写并从config中获取实际值
|
# 将大写的配置键转换为小写并从config中获取实际值
|
||||||
try:
|
try:
|
||||||
self.api_key = getattr(config, model["key"])
|
self.api_key = getattr(config, model["key"])
|
||||||
@@ -61,7 +61,7 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
@@ -126,7 +126,7 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
@@ -191,9 +191,119 @@ class LLM_request:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < max_retries - 1: # 如果还有重试机会
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
time.sleep(wait_time)
|
time.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
return f"请求失败: {str(e)}", ""
|
return f"请求失败: {str(e)}", ""
|
||||||
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
return "达到最大重试次数,请求仍然失败", ""
|
||||||
|
|
||||||
|
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""同步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
max_retries = 2
|
||||||
|
base_wait_time = 6
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
||||||
|
|
||||||
|
if response.status_code == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
time.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
print(f"embedding请求失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
|
"""异步方法:获取文本的embedding向量
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 需要获取embedding的文本
|
||||||
|
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list: embedding向量,如果失败则返回None
|
||||||
|
"""
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
data = {
|
||||||
|
"model": model,
|
||||||
|
"input": text,
|
||||||
|
"encoding_format": "float"
|
||||||
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
||||||
|
|
||||||
|
max_retries = 3
|
||||||
|
base_wait_time = 15
|
||||||
|
|
||||||
|
for retry in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.post(api_url, headers=headers, json=data) as response:
|
||||||
|
if response.status == 429:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
continue
|
||||||
|
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
result = await response.json()
|
||||||
|
if 'data' in result and len(result['data']) > 0:
|
||||||
|
return result['data'][0]['embedding']
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if retry < max_retries - 1:
|
||||||
|
wait_time = base_wait_time * (2 ** retry)
|
||||||
|
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
|
await asyncio.sleep(wait_time)
|
||||||
|
else:
|
||||||
|
print(f"embedding请求失败: {str(e)}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
print("达到最大重试次数,embedding请求仍然失败")
|
||||||
|
return None
|
||||||
|
|||||||
@@ -11,12 +11,12 @@ config = driver.config
|
|||||||
|
|
||||||
|
|
||||||
Database.initialize(
|
Database.initialize(
|
||||||
host= config.mongodb_host,
|
host= config.MONGODB_HOST,
|
||||||
port= int(config.mongodb_port),
|
port= int(config.MONGODB_PORT),
|
||||||
db_name= config.database_name,
|
db_name= config.DATABASE_NAME,
|
||||||
username= config.mongodb_username,
|
username= config.MONGODB_USERNAME,
|
||||||
password= config.mongodb_password,
|
password= config.MONGODB_PASSWORD,
|
||||||
auth_source=config.mongodb_auth_source
|
auth_source=config.MONGODB_AUTH_SOURCE
|
||||||
)
|
)
|
||||||
|
|
||||||
class ScheduleGenerator:
|
class ScheduleGenerator:
|
||||||
|
|||||||
Reference in New Issue
Block a user