diff --git a/.env b/.env index 382b70fa0..cd0d3f3e4 100644 --- a/.env +++ b/.env @@ -1,26 +1,2 @@ # 您不应该修改默认值,这个文件被仓库索引,请修改.env.prod -ENVIRONMENT=prod -# HOST=127.0.0.1 -# PORT=8080 - -# COMMAND_START=["/"] - -# # 插件配置 -# PLUGINS=["src2.plugins.chat"] - -# # 默认配置 -# MONGODB_HOST=127.0.0.1 -# MONGODB_PORT=27017 -# DATABASE_NAME=MegBot - -# MONGODB_USERNAME = "" # 默认空值 -# MONGODB_PASSWORD = "" # 默认空值 -# MONGODB_AUTH_SOURCE = "" # 默认空值 - -# #key and url -# CHAT_ANY_WHERE_KEY= -# SILICONFLOW_KEY= -# CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 -# SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ -# DEEP_SEEK_KEY= -# DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +ENVIRONMENT=.dev \ No newline at end of file diff --git a/.env.prod b/.env.prod index f00cd5169..d70bba206 100644 --- a/.env.prod +++ b/.env.prod @@ -1,8 +1,6 @@ HOST=127.0.0.1 PORT=8080 -COMMAND_START=["/"] - # 插件配置 PLUGINS=["src2.plugins.chat"] @@ -16,11 +14,11 @@ MONGODB_PASSWORD = "" # 默认空值 MONGODB_AUTH_SOURCE = "" # 默认空值 #key and url - CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/ DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1 +#定义你要用的api的base_url DEEP_SEEK_KEY= CHAT_ANY_WHERE_KEY= SILICONFLOW_KEY= \ No newline at end of file diff --git a/bot.py b/bot.py index 906ffc37d..d24c82ac6 100644 --- a/bot.py +++ b/bot.py @@ -15,25 +15,22 @@ for i, char in enumerate(text): print(rainbow_text) '''彩蛋''' -# 首先加载基础环境变量 +# 首先加载基础环境变量.env if os.path.exists(".env"): load_dotenv(".env") logger.success("成功加载基础环境变量配置") else: logger.error("基础环境变量配置文件 .env 不存在") exit(1) -# 根据 ENVIRONMENT 加载对应的环境配置 -env = os.getenv("ENVIRONMENT") -env_file = f".env.{env}" -if env_file == ".env.dev" and os.path.exists(env_file): +if os.path.exists(".env.dev"): logger.success("加载开发环境变量配置") - load_dotenv(env_file, override=True) # override=True 允许覆盖已存在的环境变量 + load_dotenv(".env.dev", override=True) # override=True 允许覆盖已存在的环境变量 elif os.path.exists(".env.prod"): logger.success("加载环境变量配置") load_dotenv(".env.prod", override=True) # override=True 允许覆盖已存在的环境变量 else: - logger.error(f"{env}对应的环境配置文件{env_file}不存在,请修改.env文件中的ENVIRONMENT变量为 prod.") + logger.error(f".env对应的环境配置文件不存在,请修改.env文件中的ENVIRONMENT变量为 prod.") exit(1) # 获取所有环境变量 diff --git a/src/plugins/chat/__init__.py b/src/plugins/chat/__init__.py index 4b1f8d77f..5b13d44c9 100644 --- a/src/plugins/chat/__init__.py +++ b/src/plugins/chat/__init__.py @@ -17,12 +17,12 @@ driver = get_driver() config = driver.config Database.initialize( - host= config.mongodb_host, - port= int(config.mongodb_port), - db_name= config.database_name, - username= config.mongodb_username, - password= config.mongodb_password, - auth_source= config.mongodb_auth_source + host= config.MONGODB_HOST, + port= int(config.MONGODB_PORT), + db_name= config.DATABASE_NAME, + username= config.MONGODB_USERNAME, + password= config.MONGODB_PASSWORD, + auth_source= config.MONGODB_AUTH_SOURCE ) print("\033[1;32m[初始化数据库完成]\033[0m") diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index ae2f1fd5d..777463647 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -116,6 +116,9 @@ class BotConfig: if "vlm" in model_config: config.vlm = model_config["vlm"] + + if "embedding" in model_config: + config.embedding = model_config["embedding"] # 消息配置 if "message" in toml_dict: @@ -152,31 +155,13 @@ bot_config_path = os.path.join(bot_config_floder_path, "bot_config_dev.toml") if not os.path.exists(bot_config_path): # 如果开发环境配置文件不存在,则使用默认配置文件 bot_config_path = os.path.join(bot_config_floder_path, "bot_config.toml") - logger.info("使用默认配置文件") + logger.info("使用bot配置文件") else: - logger.info("已找到开发环境配置文件") + logger.info("已找到开发bot配置文件") global_config = BotConfig.load_config(config_path=bot_config_path) - -@dataclass -class LLMConfig: - """机器人配置类""" - # 基础配置 - SILICONFLOW_API_KEY: str = None - SILICONFLOW_BASE_URL: str = None - DEEP_SEEK_API_KEY: str = None - DEEP_SEEK_BASE_URL: str = None - -llm_config = LLMConfig() -config = get_driver().config -llm_config.SILICONFLOW_API_KEY = config.siliconflow_key -llm_config.SILICONFLOW_BASE_URL = config.siliconflow_base_url -llm_config.DEEP_SEEK_API_KEY = config.deep_seek_key -llm_config.DEEP_SEEK_BASE_URL = config.deep_seek_base_url - - if not global_config.enable_advance_output: # logger.remove() pass diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 0e5db347c..63151592d 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -10,6 +10,7 @@ from typing import Dict from collections import Counter import math from nonebot import get_driver +from ..models.utils_model import LLM_request driver = get_driver() config = driver.config @@ -64,25 +65,9 @@ def is_mentioned_bot_in_txt(message: str) -> bool: return False def get_embedding(text): - url = "https://api.siliconflow.cn/v1/embeddings" - payload = { - "model": "BAAI/bge-m3", - "input": text, - "encoding_format": "float" - } - headers = { - "Authorization": f"Bearer {config.siliconflow_key}", - "Content-Type": "application/json" - } - - response = requests.request("POST", url, json=payload, headers=headers) - - if response.status_code != 200: - print(f"API请求失败: {response.status_code}") - print(f"错误信息: {response.text}") - return None - - return response.json()['data'][0]['embedding'] + """获取文本的embedding向量""" + llm = LLM_request(model=global_config.embedding) + return llm.get_embedding_sync(text) def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) diff --git a/src/plugins/knowledege/knowledge_library.py b/src/plugins/knowledege/knowledge_library.py index d8c2e1482..f8e91039b 100644 --- a/src/plugins/knowledege/knowledge_library.py +++ b/src/plugins/knowledege/knowledge_library.py @@ -17,12 +17,12 @@ from src.plugins.chat.config import llm_config # 直接配置数据库连接信息 Database.initialize( - host= config.mongodb_host, - port= int(config.mongodb_port), - db_name= config.database_name, - username= config.mongodb_username, - password= config.mongodb_password, - auth_source=config.mongodb_auth_source + host= config.MONGODB_HOST, + port= int(config.MONGODB_PORT), + db_name= config.DATABASE_NAME, + username= config.MONGODB_USERNAME, + password= config.MONGODB_PASSWORD, + auth_source=config.MONGODB_AUTH_SOURCE ) class KnowledgeLibrary: diff --git a/src/plugins/memory_system/llm_module_memory_make.py b/src/plugins/memory_system/llm_module_memory_make.py index f59354570..89fe45cf0 100644 --- a/src/plugins/memory_system/llm_module_memory_make.py +++ b/src/plugins/memory_system/llm_module_memory_make.py @@ -66,7 +66,7 @@ class LLMModel: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2 ** retry) - print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") await asyncio.sleep(wait_time) else: return f"请求失败: {str(e)}", "" diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index a051192a5..e0095dada 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -259,12 +259,12 @@ config = driver.config start_time = time.time() Database.initialize( - host= config.mongodb_host, - port= int(config.mongodb_port), - db_name= config.database_name, - username= config.mongodb_username, - password= config.mongodb_password, - auth_source=config.mongodb_auth_source + host= config.MONGODB_HOST, + port= config.MONGODB_PORT, + db_name= config.DATABASE_NAME, + username= config.MONGODB_USERNAME, + password= config.MONGODB_PASSWORD, + auth_source=config.MONGODB_AUTH_SOURCE ) #创建记忆图 memory_graph = Memory_graph() diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index f911d7495..54be3be34 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -9,7 +9,7 @@ driver = get_driver() config = driver.config class LLM_request: - def __init__(self, model = global_config.llm_normal,**kwargs): + def __init__(self, model ,**kwargs): # 将大写的配置键转换为小写并从config中获取实际值 try: self.api_key = getattr(config, model["key"]) @@ -61,7 +61,7 @@ class LLM_request: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2 ** retry) - print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + print(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") await asyncio.sleep(wait_time) else: return f"请求失败: {str(e)}", "" @@ -126,7 +126,7 @@ class LLM_request: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2 ** retry) - print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + print(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") await asyncio.sleep(wait_time) else: return f"请求失败: {str(e)}", "" @@ -191,9 +191,119 @@ class LLM_request: except Exception as e: if retry < max_retries - 1: # 如果还有重试机会 wait_time = base_wait_time * (2 ** retry) - print(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + print(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") time.sleep(wait_time) else: return f"请求失败: {str(e)}", "" return "达到最大重试次数,请求仍然失败", "" + + def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: + """同步方法:获取文本的embedding向量 + + Args: + text: 需要获取embedding的文本 + model: 使用的模型名称,默认为"BAAI/bge-m3" + + Returns: + list: embedding向量,如果失败则返回None + """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + data = { + "model": model, + "input": text, + "encoding_format": "float" + } + + api_url = f"{self.base_url.rstrip('/')}/embeddings" + + max_retries = 2 + base_wait_time = 6 + + for retry in range(max_retries): + try: + response = requests.post(api_url, headers=headers, json=data, timeout=30) + + if response.status_code == 429: + wait_time = base_wait_time * (2 ** retry) + print(f"遇到请求限制(429),等待{wait_time}秒后重试...") + time.sleep(wait_time) + continue + + response.raise_for_status() + + result = response.json() + if 'data' in result and len(result['data']) > 0: + return result['data'][0]['embedding'] + return None + + except Exception as e: + if retry < max_retries - 1: + wait_time = base_wait_time * (2 ** retry) + print(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + time.sleep(wait_time) + else: + print(f"embedding请求失败: {str(e)}") + return None + + print("达到最大重试次数,embedding请求仍然失败") + return None + + async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: + """异步方法:获取文本的embedding向量 + + Args: + text: 需要获取embedding的文本 + model: 使用的模型名称,默认为"BAAI/bge-m3" + + Returns: + list: embedding向量,如果失败则返回None + """ + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + data = { + "model": model, + "input": text, + "encoding_format": "float" + } + + api_url = f"{self.base_url.rstrip('/')}/embeddings" + + max_retries = 3 + base_wait_time = 15 + + for retry in range(max_retries): + try: + async with aiohttp.ClientSession() as session: + async with session.post(api_url, headers=headers, json=data) as response: + if response.status == 429: + wait_time = base_wait_time * (2 ** retry) + print(f"遇到请求限制(429),等待{wait_time}秒后重试...") + await asyncio.sleep(wait_time) + continue + + response.raise_for_status() + + result = await response.json() + if 'data' in result and len(result['data']) > 0: + return result['data'][0]['embedding'] + return None + + except Exception as e: + if retry < max_retries - 1: + wait_time = base_wait_time * (2 ** retry) + print(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}") + await asyncio.sleep(wait_time) + else: + print(f"embedding请求失败: {str(e)}") + return None + + print("达到最大重试次数,embedding请求仍然失败") + return None diff --git a/src/plugins/schedule/schedule_generator.py b/src/plugins/schedule/schedule_generator.py index fcee0d1be..b2af29f6b 100644 --- a/src/plugins/schedule/schedule_generator.py +++ b/src/plugins/schedule/schedule_generator.py @@ -11,12 +11,12 @@ config = driver.config Database.initialize( - host= config.mongodb_host, - port= int(config.mongodb_port), - db_name= config.database_name, - username= config.mongodb_username, - password= config.mongodb_password, - auth_source=config.mongodb_auth_source + host= config.MONGODB_HOST, + port= int(config.MONGODB_PORT), + db_name= config.DATABASE_NAME, + username= config.MONGODB_USERNAME, + password= config.MONGODB_PASSWORD, + auth_source=config.MONGODB_AUTH_SOURCE ) class ScheduleGenerator: