v0.4.1
修复了数据库命名问题 修复了嵌入模型未定义问题
This commit is contained in:
@@ -17,12 +17,12 @@ driver = get_driver()
|
||||
config = driver.config
|
||||
|
||||
Database.initialize(
|
||||
host= config.mongodb_host,
|
||||
port= int(config.mongodb_port),
|
||||
db_name= config.database_name,
|
||||
username= config.mongodb_username,
|
||||
password= config.mongodb_password,
|
||||
auth_source= config.mongodb_auth_source
|
||||
host= config.MONGODB_HOST,
|
||||
port= int(config.MONGODB_PORT),
|
||||
db_name= config.DATABASE_NAME,
|
||||
username= config.MONGODB_USERNAME,
|
||||
password= config.MONGODB_PASSWORD,
|
||||
auth_source= config.MONGODB_AUTH_SOURCE
|
||||
)
|
||||
print("\033[1;32m[初始化数据库完成]\033[0m")
|
||||
|
||||
|
||||
@@ -116,6 +116,9 @@ class BotConfig:
|
||||
|
||||
if "vlm" in model_config:
|
||||
config.vlm = model_config["vlm"]
|
||||
|
||||
if "embedding" in model_config:
|
||||
config.embedding = model_config["embedding"]
|
||||
|
||||
# 消息配置
|
||||
if "message" in toml_dict:
|
||||
@@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml")
|
||||
if not os.path.exists(bot_config_path):
|
||||
# 如果开发环境配置文件不存在,则使用默认配置文件
|
||||
bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml")
|
||||
logger.info("使用默认配置文件")
|
||||
logger.info("使用bot配置文件")
|
||||
else:
|
||||
logger.info("已找到开发环境配置文件")
|
||||
logger.info("已找到开发bot配置文件")
|
||||
|
||||
global_config = BotConfig.load_config(config_path=bot_config_path)
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
class LLMConfig:
|
||||
"""机器人配置类"""
|
||||
# 基础配置
|
||||
SILICONFLOW_API_KEY: str = None
|
||||
SILICONFLOW_BASE_URL: str = None
|
||||
DEEP_SEEK_API_KEY: str = None
|
||||
DEEP_SEEK_BASE_URL: str = None
|
||||
|
||||
llm_config = LLMConfig()
|
||||
config = get_driver().config
|
||||
llm_config.SILICONFLOW_API_KEY = config.siliconflow_key
|
||||
llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url
|
||||
llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key
|
||||
llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url
|
||||
|
||||
|
||||
if not global_config.enable_advance_output:
|
||||
# logger.remove()
|
||||
pass
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Dict
|
||||
from collections import Counter
|
||||
import math
|
||||
from nonebot import get_driver
|
||||
from ..models.utils_model import LLM_request
|
||||
|
||||
driver = get_driver()
|
||||
config = driver.config
|
||||
@@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||
return False
|
||||
|
||||
def get_embedding(text):
|
||||
url = "https://api.siliconflow.cn/v1/embeddings"
|
||||
payload = {
|
||||
"model": "BAAI/bge-m3",
|
||||
"input": text,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {config.siliconflow_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.request("POST", url, json=payload, headers=headers)
|
||||
|
||||
if response.status_code != 200:
|
||||
print(f"API请求失败: {response.status_code}")
|
||||
print(f"错误信息: {response.text}")
|
||||
return None
|
||||
|
||||
return response.json()['data'][0]['embedding']
|
||||
"""获取文本的embedding向量"""
|
||||
llm = LLM_request(model=global_config.embedding)
|
||||
return llm.get_embedding_sync(text)
|
||||
|
||||
def cosine_similarity(v1, v2):
|
||||
dot_product = np.dot(v1, v2)
|
||||
|
||||
Reference in New Issue
Block a user