修复了数据库命名问题
修复了嵌入模型未定义问题
This commit is contained in:
SengokuCola
2025-03-03 23:50:45 +08:00
parent e014617cdb
commit dd4fb315df
11 changed files with 154 additions and 103 deletions

View File

@@ -17,12 +17,12 @@ driver = get_driver()
config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source= config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source= config.MONGODB_AUTH_SOURCE
)
print("\033[1;32m[初始化数据库完成]\033[0m")

View File

@@ -116,6 +116,9 @@ class BotConfig:
if "vlm" in model_config:
config.vlm = model_config["vlm"]
if "embedding" in model_config:
config.embedding = model_config["embedding"]
# 消息配置
if "message" in toml_dict:
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
if not os.path.exists(bot_config_path):
# 如果开发环境配置文件不存在,则使用默认配置文件
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
logger.info("使用默认配置文件")
logger.info("使用bot配置文件")
else:
logger.info("已找到开发环境配置文件")
logger.info("已找到开发bot配置文件")
global_config = BotConfig.load_config(config_path=bot_config_path)
@dataclass
class LLMConfig:
"""机器人配置类"""
# 基础配置
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
config = get_driver().config
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
if not global_config.enable_advance_output:
# logger.remove()
pass

View File

@@ -10,6 +10,7 @@ from typing import Dict
from collections import Counter
import math
from nonebot import get_driver
from ..models.utils_model import LLM_request
driver = get_driver()
config = driver.config
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return False
def get_embedding(text):
url = "https://api.siliconflow.cn/v1/embeddings"
payload = {
"model": "BAAI/bge-m3",
"input": text,
"encoding_format": "float"
}
headers = {
"Authorization": f"Bearer {config.siliconflow_key}",
"Content-Type": "application/json"
}
response = requests.request("POST", url, json=payload, headers=headers)
if response.status_code != 200:
print(f"API请求失败: {response.status_code}")
print(f"错误信息: {response.text}")
return None
return response.json()['data'][0]['embedding']
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
return llm.get_embedding_sync(text)
def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2)

View File

@@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config
# 直接配置数据库连接信息
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class KnowledgeLibrary:

View File

@@ -66,7 +66,7 @@ class LLMModel:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""

View File

@@ -259,12 +259,12 @@ config = driver.config
start_time = time.time()
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= config.MONGODB_PORT,
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
#创建记忆图
memory_graph = Memory_graph()

View File

@@ -9,7 +9,7 @@ driver = get_driver()
config = driver.config
class LLM_request:
def __init__(self, model = global_config.llm_normal,**kwargs):
def __init__(self, model ,**kwargs):
# 将大写的配置键转换为小写并从config中获取实际值
try:
self.api_key = getattr(config, model["key"])
@@ -61,7 +61,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -126,7 +126,7 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
@@ -191,9 +191,119 @@ class LLM_request:
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
return f"请求失败: {str(e)}", ""
return "达到最大重试次数,请求仍然失败", ""
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""同步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status()
result = response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
time.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
"""异步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
model: 使用的模型名称,默认为"BAAI/bge-m3"
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": model,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
print(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status()
result = await response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
print(f"embedding请求失败: {str(e)}")
return None
print("达到最大重试次数embedding请求仍然失败")
return None

View File

@@ -11,12 +11,12 @@ config = driver.config
Database.initialize(
host= config.mongodb_host,
port= int(config.mongodb_port),
db_name= config.database_name,
username= config.mongodb_username,
password= config.mongodb_password,
auth_source=config.mongodb_auth_source
host= config.MONGODB_HOST,
port= int(config.MONGODB_PORT),
db_name= config.DATABASE_NAME,
username= config.MONGODB_USERNAME,
password= config.MONGODB_PASSWORD,
auth_source=config.MONGODB_AUTH_SOURCE
)
class ScheduleGenerator: