v0.2.1 Added support for the official DeepSeek API
Image recognition still has to use SiliconFlow
@@ -17,6 +17,8 @@ CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
 CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
+DEEP_SEEK_KEY=
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1
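
For context, these keys are read from the environment at startup. A minimal sketch of how the two new DeepSeek variables would be consumed, assuming python-dotenv as the later hunk's load_dotenv call suggests:

```python
# Minimal sketch: read the new DeepSeek settings from .env.
# Variable names match the diff above; the path handling is simplified.
import os
from dotenv import load_dotenv

load_dotenv()  # picks up .env from the working directory

DEEP_SEEK_KEY = os.getenv('DEEP_SEEK_KEY')
DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL', 'https://api.deepseek.com/v1')

if not DEEP_SEEK_KEY:
    raise RuntimeError("DEEP_SEEK_KEY is not set; add it to .env")
```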
@@ -17,6 +17,7 @@ check_interval = 120
 register_interval = 10
 
 [response]
+api_using = "siliconflow"
 model_r1_probability = 0.8
 model_v3_probability = 0.1
 model_r1_distill_probability = 0.1
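
A minimal sketch of reading the new [response] table, assuming the standard-library tomllib (Python 3.11+) is available; the repo's actual loader and file name may differ:

```python
# Sketch only: parse the [response] table shown above with tomllib.
# "config.toml" is a placeholder path, not necessarily the repo's file name.
import tomllib

with open("config.toml", "rb") as f:
    toml_dict = tomllib.load(f)

response = toml_dict.get("response", {})
print(response.get("api_using", "siliconflow"))    # -> "siliconflow"
print(response.get("model_r1_probability", 0.8))   # -> 0.8
```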
@@ -40,6 +40,7 @@ class BotConfig:
     EMOJI_CHECK_INTERVAL: int = 120  # Emoji-pack check interval (minutes)
     EMOJI_REGISTER_INTERVAL: int = 10  # Emoji-pack registration interval (minutes)
 
+    API_USING: str = "siliconflow"  # Which API backend to use
     MODEL_R1_PROBABILITY: float = 0.8  # Probability of using the R1 model
     MODEL_V3_PROBABILITY: float = 0.1  # Probability of using the V3 model
     MODEL_R1_DISTILL_PROBABILITY: float = 0.1  # Probability of using the R1 distill model
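
The three probability fields read like per-reply weights for choosing a model (0.8 + 0.1 + 0.1 = 1.0 here). A hedged sketch of that selection; the helper name is hypothetical, not the repo's code:

```python
# Hypothetical helper: pick a model type using the configured weights.
import random

def pick_model_type(config) -> str:
    return random.choices(
        ["r1", "v3", "r1_distill"],
        weights=[
            config.MODEL_R1_PROBABILITY,          # 0.8
            config.MODEL_V3_PROBABILITY,          # 0.1
            config.MODEL_R1_DISTILL_PROBABILITY,  # 0.1
        ],
    )[0]
```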
@@ -76,7 +77,8 @@ class BotConfig:
             config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
             config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
             config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
 
+            config.API_USING = response_config.get("api_using", config.API_USING)
 
         # Message configuration
         if "message" in toml_dict:
             msg_config = toml_dict["message"]
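
Worth noting: dict.get(key, default) keeps the dataclass defaults whenever a key is absent from the TOML, so api_using only changes if it is actually set. A tiny illustration with made-up values:

```python
# Illustration of the fallback pattern used above.
response_config = {"api_using": "deepseek"}  # table with only one key set
response_config.get("api_using", "siliconflow")   # -> "deepseek"
response_config.get("model_r1_probability", 0.8)  # -> 0.8, default kept
```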
@@ -108,7 +110,11 @@ class LLMConfig:
     # Base configuration
     SILICONFLOW_API_KEY: str = None
     SILICONFLOW_BASE_URL: str = None
+    DEEP_SEEK_API_KEY: str = None
+    DEEP_SEEK_BASE_URL: str = None
 
 llm_config = LLMConfig()
 llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
 llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
+llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
+llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')
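
Since all four fields default to None, a missing .env entry only surfaces later as an authentication error. A hedged sketch of an early startup check; the function name is hypothetical:

```python
# Hypothetical startup check: fail fast if the selected backend is unconfigured.
def validate_llm_config(llm_config, api_using: str) -> None:
    required = {
        "siliconflow": (llm_config.SILICONFLOW_API_KEY, llm_config.SILICONFLOW_BASE_URL),
        "deepseek": (llm_config.DEEP_SEEK_API_KEY, llm_config.DEEP_SEEK_BASE_URL),
    }[api_using]
    if not all(required):
        raise ValueError(f"Missing API key or base URL for backend '{api_using}'")
```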
@@ -26,11 +26,17 @@ load_dotenv(os.path.join(root_dir, '.env'))
 class LLMResponseGenerator:
     def __init__(self, config: BotConfig):
         self.config = config
-        self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
-        )
+        if self.config.API_USING == "siliconflow":
+            self.client = OpenAI(
+                api_key=llm_config.SILICONFLOW_API_KEY,
+                base_url=llm_config.SILICONFLOW_BASE_URL
+            )
+        elif self.config.API_USING == "deepseek":
+            self.client = OpenAI(
+                api_key=llm_config.DEEP_SEEK_API_KEY,
+                base_url=llm_config.DEEP_SEEK_BASE_URL
+            )
 
         self.db = Database.get_instance()
 
         # Model type currently in use
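
One thing the branch above leaves open: if api_using is neither "siliconflow" nor "deepseek", self.client is never assigned and the first use raises AttributeError. A sketch of a table-driven alternative with an explicit fallback, reusing the llm_config names from the diff; not the repo's actual code:

```python
# Sketch: table-driven client construction with a default backend.
from openai import OpenAI

def make_client(api_using: str) -> OpenAI:
    credentials = {
        "siliconflow": (llm_config.SILICONFLOW_API_KEY, llm_config.SILICONFLOW_BASE_URL),
        "deepseek": (llm_config.DEEP_SEEK_API_KEY, llm_config.DEEP_SEEK_BASE_URL),
    }
    # Fall back to siliconflow instead of leaving the client undefined.
    api_key, base_url = credentials.get(api_using, credentials["siliconflow"])
    return OpenAI(api_key=api_key, base_url=base_url)
```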
@@ -140,19 +146,33 @@ class LLMResponseGenerator:
 
     async def _generate_r1_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-R1",
-            {"temperature": 0.7, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-reasoner",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-R1",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
 
     async def _generate_v3_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-V3 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-V3",
-            {"temperature": 0.8, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-chat",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-V3",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
 
     async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
         """Generate a reply with the DeepSeek-R1-Distill-Qwen-32B model"""
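
The two methods now differ only in which model name each backend gets, so the repeated if/else pairs could collapse into one lookup table. A hedged refactoring sketch, not part of this commit:

```python
# Sketch: map (backend, model_type) to a model name once, instead of
# repeating the if/else inside every _generate_*_response method.
MODEL_NAMES = {
    ("deepseek", "r1"): "deepseek-reasoner",
    ("deepseek", "v3"): "deepseek-chat",
    ("siliconflow", "r1"): "Pro/deepseek-ai/DeepSeek-R1",
    ("siliconflow", "v3"): "Pro/deepseek-ai/DeepSeek-V3",
}

def resolve_model(api_using: str, model_type: str) -> str:
    # Unknown backends fall back to the SiliconFlow-hosted name.
    return MODEL_NAMES.get((api_using, model_type),
                           MODEL_NAMES[("siliconflow", model_type)])
```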
@@ -192,9 +212,13 @@ class LLMResponseGenerator:
         messages = [{"role": "user", "content": prompt}]
 
         loop = asyncio.get_event_loop()
+        if self.config.API_USING == "deepseek":
+            model = "deepseek-chat"
+        else:
+            model = "Pro/deepseek-ai/DeepSeek-V3"
         create_completion = partial(
             self.client.chat.completions.create,
-            model="Pro/deepseek-ai/DeepSeek-V3",
+            model=model,
             messages=messages,
             stream=False,
             max_tokens=30,
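
The partial(...) plus get_event_loop() pattern suggests the blocking SDK call is handed to an executor. A hedged sketch of how such a call is typically awaited; the diff cuts off before this point, so the surrounding names are guesses:

```python
# Assumed usage of the partial built above (not shown in the diff).
response = await loop.run_in_executor(None, create_completion)
text = response.choices[0].message.content  # standard OpenAI SDK response shape
```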