v0.2.1 adds support for the official DeepSeek API

Image recognition still has to go through SiliconFlow
commit f0bb3149ac
parent dc3c781401
Author: SengokuCola
Date:   2025-02-28 10:31:19 +08:00

4 changed files with 50 additions and 17 deletions

View File

@@ -17,6 +17,8 @@ CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
 CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
+DEEP_SEEK_KEY=
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
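For reference, a minimal sketch of how these keys are picked up at runtime. The project loads them with python-dotenv (see the load_dotenv call in the last diff below); the default fallback for the base URL here is illustrative, not part of the commit:

import os
from dotenv import load_dotenv  # python-dotenv, already used by the project

load_dotenv()  # without arguments, reads .env from the current directory

# os.getenv returns None for unset keys; the second argument supplies a default.
deepseek_key = os.getenv("DEEP_SEEK_KEY")
deepseek_base = os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1")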

View File

@@ -17,6 +17,7 @@ check_interval = 120
 register_interval = 10

 [response]
+api_using = "siliconflow"
 model_r1_probability = 0.8
 model_v3_probability = 0.1
 model_r1_distill_probability = 0.1
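A hedged sketch of parsing this key (tomllib is the Python 3.11+ standard library; the project's actual loader may use another TOML parser, and the file name is a placeholder):

import tomllib  # standard library since Python 3.11

with open("bot_config.toml", "rb") as f:  # hypothetical file name
    toml_dict = tomllib.load(f)

response_config = toml_dict.get("response", {})
api_using = response_config.get("api_using", "siliconflow")  # default matches the config above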

View File

@@ -40,6 +40,7 @@ class BotConfig:
     EMOJI_CHECK_INTERVAL: int = 120  # Sticker check interval (minutes)
     EMOJI_REGISTER_INTERVAL: int = 10  # Sticker registration interval (minutes)
+    API_USING: str = "siliconflow"  # Which API backend to use
     MODEL_R1_PROBABILITY: float = 0.8  # R1 model probability
     MODEL_V3_PROBABILITY: float = 0.1  # V3 model probability
     MODEL_R1_DISTILL_PROBABILITY: float = 0.1  # R1 distill model probability
@@ -76,7 +77,8 @@ class BotConfig:
         config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
         config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
         config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
+        config.API_USING = response_config.get("api_using", config.API_USING)
         # Message settings
         if "message" in toml_dict:
             msg_config = toml_dict["message"]
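Note that response_config.get("api_using", config.API_USING) falls back to the dataclass default, so an existing config file without the new key keeps using "siliconflow". A tiny illustration of that behavior:

response_config = {"model_r1_probability": 0.8}  # legacy config without api_using
api_using = response_config.get("api_using", "siliconflow")
assert api_using == "siliconflow"  # old configs keep the previous backend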
@@ -108,7 +110,11 @@ class LLMConfig:
     # Base settings
     SILICONFLOW_API_KEY: str = None
     SILICONFLOW_BASE_URL: str = None
+    DEEP_SEEK_API_KEY: str = None
+    DEEP_SEEK_BASE_URL: str = None

 llm_config = LLMConfig()
 llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
 llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
+llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
+llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
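os.getenv silently yields None when a variable is unset, so a missing DEEP_SEEK_KEY would only surface at the first API call. A hedged sketch of an early check (this validate_llm_config helper is hypothetical, not in the commit):

def validate_llm_config(config: BotConfig, llm_config: LLMConfig) -> None:
    # Fail fast when the selected backend has no credentials.
    if config.API_USING == "deepseek" and not llm_config.DEEP_SEEK_API_KEY:
        raise RuntimeError("API_USING is 'deepseek' but DEEP_SEEK_KEY is unset")
    if config.API_USING == "siliconflow" and not llm_config.SILICONFLOW_API_KEY:
        raise RuntimeError("API_USING is 'siliconflow' but SILICONFLOW_KEY is unset")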

View File

@@ -26,11 +26,17 @@ load_dotenv(os.path.join(root_dir, '.env'))
 class LLMResponseGenerator:
     def __init__(self, config: BotConfig):
         self.config = config
-        self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
-        )
+        if self.config.API_USING == "siliconflow":
+            self.client = OpenAI(
+                api_key=llm_config.SILICONFLOW_API_KEY,
+                base_url=llm_config.SILICONFLOW_BASE_URL
+            )
+        elif self.config.API_USING == "deepseek":
+            self.client = OpenAI(
+                api_key=llm_config.DEEP_SEEK_API_KEY,
+                base_url=llm_config.DEEP_SEEK_BASE_URL
+            )

         self.db = Database.get_instance()
         # Model type currently in use
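As written, self.client is never assigned when API_USING is neither "siliconflow" nor "deepseek", which would surface later as an AttributeError. A hedged alternative sketch using a lookup table (the _CREDENTIALS dict and make_client helper are hypothetical; the OpenAI constructor arguments come from the diff):

from openai import OpenAI

# Hypothetical table mapping the api_using value to its credentials.
_CREDENTIALS = {
    "siliconflow": (llm_config.SILICONFLOW_API_KEY, llm_config.SILICONFLOW_BASE_URL),
    "deepseek": (llm_config.DEEP_SEEK_API_KEY, llm_config.DEEP_SEEK_BASE_URL),
}

def make_client(api_using: str) -> OpenAI:
    try:
        api_key, base_url = _CREDENTIALS[api_using]
    except KeyError:
        raise ValueError(f"unknown api_using value: {api_using!r}")
    return OpenAI(api_key=api_key, base_url=base_url)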
@@ -140,19 +146,33 @@ class LLMResponseGenerator:
     async def _generate_r1_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-R1",
-            {"temperature": 0.7, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-reasoner",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-R1",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )

     async def _generate_v3_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-V3 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-V3",
-            {"temperature": 0.8, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-chat",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-V3",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )

     async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1-Distill-Qwen-32B model"""
@@ -192,9 +212,13 @@ class LLMResponseGenerator:
         messages = [{"role": "user", "content": prompt}]

         loop = asyncio.get_event_loop()
+        if self.config.API_USING == "deepseek":
+            model = "deepseek-chat"
+        else:
+            model = "Pro/deepseek-ai/DeepSeek-V3"
         create_completion = partial(
             self.client.chat.completions.create,
-            model="Pro/deepseek-ai/DeepSeek-V3",
+            model=model,
             messages=messages,
             stream=False,
             max_tokens=30,
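For context, binding the arguments with partial and holding the event loop is the usual way to run the synchronous OpenAI client without blocking asyncio. The await itself is not shown in this hunk, so the following line is an assumption about how the prepared call is consumed:

# Run the blocking SDK call on the default thread-pool executor
# so the asyncio event loop stays responsive.
response = await loop.run_in_executor(None, create_completion)
reply = response.choices[0].message.content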