v0.2.1 添加了对官方api的支持

图像识别还得用硅基
This commit is contained in:
SengokuCola
2025-02-28 10:31:19 +08:00
parent dc3c781401
commit f0bb3149ac
4 changed files with 50 additions and 17 deletions

View File

@@ -17,6 +17,8 @@ CHAT_ANY_WHERE_KEY=
SILICONFLOW_KEY=
CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
DEEP_SEEK_KEY=
DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1

View File

@@ -17,6 +17,7 @@ check_interval = 120
register_interval = 10
[response]
api_using = "siliconflow"
model_r1_probability = 0.8
model_v3_probability = 0.1
model_r1_distill_probability = 0.1

View File

@@ -40,6 +40,7 @@ class BotConfig:
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
API_USING: str = "siliconflow" # 使用的API
MODEL_R1_PROBABILITY: float = 0.8 # R1模型概率
MODEL_V3_PROBABILITY: float = 0.1 # V3模型概率
MODEL_R1_DISTILL_PROBABILITY: float = 0.1 # R1蒸馏模型概率
@@ -76,6 +77,7 @@ class BotConfig:
config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
config.API_USING = response_config.get("api_using", config.API_USING)
# 消息配置
if "message" in toml_dict:
@@ -108,7 +110,11 @@ class LLMConfig:
# 基础配置
SILICONFLOW_API_KEY: str = None
SILICONFLOW_BASE_URL: str = None
DEEP_SEEK_API_KEY: str = None
DEEP_SEEK_BASE_URL: str = None
llm_config = LLMConfig()
llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')

View File

@@ -26,10 +26,16 @@ load_dotenv(os.path.join(root_dir, '.env'))
class LLMResponseGenerator:
def __init__(self, config: BotConfig):
    """Initialize the response generator with the backend selected in config.

    Args:
        config: Bot configuration; ``config.API_USING`` selects which
            OpenAI-compatible backend to use (``"siliconflow"`` or
            ``"deepseek"``), with credentials read from ``llm_config``.

    Raises:
        ValueError: If ``config.API_USING`` names an unknown backend.
            The original code silently left ``self.client`` unset in that
            case, deferring the failure to a confusing ``AttributeError``
            at first use; failing fast here points at the real cause.
    """
    self.config = config
    if self.config.API_USING == "siliconflow":
        self.client = OpenAI(
            api_key=llm_config.SILICONFLOW_API_KEY,
            base_url=llm_config.SILICONFLOW_BASE_URL,
        )
    elif self.config.API_USING == "deepseek":
        self.client = OpenAI(
            api_key=llm_config.DEEP_SEEK_API_KEY,
            base_url=llm_config.DEEP_SEEK_BASE_URL,
        )
    else:
        # Unsupported backend: surface the misconfiguration immediately
        # instead of leaving self.client undefined.
        raise ValueError(
            f"Unsupported API_USING value: {self.config.API_USING!r} "
            "(expected 'siliconflow' or 'deepseek')"
        )
    self.db = Database.get_instance()
@@ -140,6 +146,13 @@ class LLMResponseGenerator:
async def _generate_r1_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-R1 模型生成回复"""
if self.config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-reasoner",
{"temperature": 0.7, "max_tokens": 1024}
)
else:
return await self._generate_base_response(
message,
"Pro/deepseek-ai/DeepSeek-R1",
@@ -148,6 +161,13 @@ class LLMResponseGenerator:
async def _generate_v3_response(self, message: Message) -> Optional[str]:
"""使用 DeepSeek-V3 模型生成回复"""
if self.config.API_USING == "deepseek":
return await self._generate_base_response(
message,
"deepseek-chat",
{"temperature": 0.8, "max_tokens": 1024}
)
else:
return await self._generate_base_response(
message,
"Pro/deepseek-ai/DeepSeek-V3",
@@ -192,9 +212,13 @@ class LLMResponseGenerator:
messages = [{"role": "user", "content": prompt}]
loop = asyncio.get_event_loop()
if self.config.API_USING == "deepseek":
model = "deepseek-chat"
else:
model = "Pro/deepseek-ai/DeepSeek-V3"
create_completion = partial(
self.client.chat.completions.create,
model="Pro/deepseek-ai/DeepSeek-V3",
model=model,
messages=messages,
stream=False,
max_tokens=30,