From f0bb3149ac3401924050f7b1fc6d910621964c95 Mon Sep 17 00:00:00 2001
From: SengokuCola <1026294844@qq.com>
Date: Fri, 28 Feb 2025 10:31:19 +0800
Subject: [PATCH] v0.2.1 Added support for the official API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Image recognition still has to use SiliconFlow.
---
 env.example                       |  2 ++
 src/plugins/chat/bot_config_toml  |  1 +
 src/plugins/chat/config.py        |  8 ++++-
 src/plugins/chat/llm_generator.py | 56 ++++++++++++++++++++++---------
 4 files changed, 50 insertions(+), 17 deletions(-)

diff --git a/env.example b/env.example
index 7680540df..0ae1560a3 100644
--- a/env.example
+++ b/env.example
@@ -17,6 +17,8 @@ CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
 CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
+DEEP_SEEK_KEY=
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
diff --git a/src/plugins/chat/bot_config_toml b/src/plugins/chat/bot_config_toml
index 5b9010035..afda4230a 100644
--- a/src/plugins/chat/bot_config_toml
+++ b/src/plugins/chat/bot_config_toml
@@ -17,6 +17,7 @@ check_interval = 120
 register_interval = 10
 
 [response]
+api_using = "siliconflow"
 model_r1_probability = 0.8
 model_v3_probability = 0.1
 model_r1_distill_probability = 0.1
diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py
index d44be7bd0..b7260bf01 100644
--- a/src/plugins/chat/config.py
+++ b/src/plugins/chat/config.py
@@ -40,6 +40,7 @@ class BotConfig:
     EMOJI_CHECK_INTERVAL: int = 120  # emoji sticker check interval (minutes)
     EMOJI_REGISTER_INTERVAL: int = 10  # emoji sticker registration interval (minutes)
 
+    API_USING: str = "siliconflow"  # which API provider to use
     MODEL_R1_PROBABILITY: float = 0.8  # probability of using the R1 model
     MODEL_V3_PROBABILITY: float = 0.1  # probability of using the V3 model
     MODEL_R1_DISTILL_PROBABILITY: float = 0.1  # probability of using the R1 distilled model
@@ -76,7 +77,8 @@
             config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
             config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
             config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
-            
+            config.API_USING = response_config.get("api_using", config.API_USING)
+            
         # message configuration
         if "message" in toml_dict:
             msg_config = toml_dict["message"]
@@ -108,7 +110,11 @@ class LLMConfig:
     # basic configuration
     SILICONFLOW_API_KEY: str = None
     SILICONFLOW_BASE_URL: str = None
+    DEEP_SEEK_API_KEY: str = None
+    DEEP_SEEK_BASE_URL: str = None
 
 llm_config = LLMConfig()
 llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
 llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
+llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
+llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
diff --git a/src/plugins/chat/llm_generator.py b/src/plugins/chat/llm_generator.py
index 176996c93..17cabe664 100644
--- a/src/plugins/chat/llm_generator.py
+++ b/src/plugins/chat/llm_generator.py
@@ -26,11 +26,17 @@ load_dotenv(os.path.join(root_dir, '.env'))
 class LLMResponseGenerator:
     def __init__(self, config: BotConfig):
         self.config = config
-        self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
-        )
-        
+        if self.config.API_USING == "siliconflow":
+            self.client = OpenAI(
+                api_key=llm_config.SILICONFLOW_API_KEY,
+                base_url=llm_config.SILICONFLOW_BASE_URL
+            )
+        elif self.config.API_USING == "deepseek":
+            self.client = OpenAI(
+                api_key=llm_config.DEEP_SEEK_API_KEY,
+                base_url=llm_config.DEEP_SEEK_BASE_URL
+            )
+        
         self.db = Database.get_instance()
 
         # currently selected model type
@@ -140,19 +146,33 @@ class LLMResponseGenerator:
 
     async def _generate_r1_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-R1",
-            {"temperature": 0.7, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-reasoner",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-R1",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
 
     async def _generate_v3_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-V3 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-V3",
-            {"temperature": 0.8, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-chat",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-V3",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
 
     async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1-Distill-Qwen-32B model"""
@@ -192,9 +212,13 @@ class LLMResponseGenerator:
         messages = [{"role": "user", "content": prompt}]
 
         loop = asyncio.get_event_loop()
+        if self.config.API_USING == "deepseek":
+            model = "deepseek-chat"
+        else:
+            model = "Pro/deepseek-ai/DeepSeek-V3"
         create_completion = partial(
             self.client.chat.completions.create,
-            model="Pro/deepseek-ai/DeepSeek-V3",
+            model=model,
            messages=messages,
             stream=False,
             max_tokens=30,
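
Not part of the patch above: a minimal standalone sketch of the provider-to-model mapping this change introduces. To use the official DeepSeek API after applying it, set DEEP_SEEK_KEY and DEEP_SEEK_BASE_URL in .env and api_using = "deepseek" in bot_config_toml. The MODELS table and pick_model() helper below are illustrative names only and do not exist in the repository; the real selection happens in LLMResponseGenerator.__init__ and the _generate_*_response methods shown in the diff.

# Illustrative sketch only, not repository code: maps the api_using setting
# to the model identifiers used by the patched _generate_*_response methods.
MODELS = {
    "siliconflow": {
        "r1": "Pro/deepseek-ai/DeepSeek-R1",
        "v3": "Pro/deepseek-ai/DeepSeek-V3",
    },
    "deepseek": {
        "r1": "deepseek-reasoner",  # official DeepSeek API, reasoning model
        "v3": "deepseek-chat",      # official DeepSeek API, chat model
    },
}

def pick_model(api_using: str, kind: str) -> str:
    """Return the model identifier for the configured provider ('r1' or 'v3')."""
    return MODELS[api_using][kind]

if __name__ == "__main__":
    # With api_using = "deepseek", R1 requests go to deepseek-reasoner.
    assert pick_model("deepseek", "r1") == "deepseek-reasoner"
    # With the default api_using = "siliconflow", V3 requests keep the Pro/ path.
    assert pick_model("siliconflow", "v3") == "Pro/deepseek-ai/DeepSeek-V3"

One caveat visible in the diff itself: with an api_using value other than "siliconflow" or "deepseek", the patched __init__ never assigns self.client; the sketch above would likewise raise a KeyError for an unknown provider.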