v0.2.1 adds support for the official API
Image recognition still has to use SiliconFlow
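Taken together, the hunks below (spanning the env template, config.toml, the config classes, and the LLM response generator) add an api_using switch with the values "siliconflow" and "deepseek", two new environment variables (DEEP_SEEK_KEY, DEEP_SEEK_BASE_URL), and per-backend model IDs. Here is a minimal standalone sketch of the resulting backend selection; the API_USING environment variable in it is illustrative only, since the real value is read from config.toml's [response] section:

import os
from openai import OpenAI  # openai>=1.0 client, as the OpenAI(...) calls in the diff suggest

# Sketch, not project code: pick the backend the way __init__ does below.
# API_USING as an env var is an assumption for the sketch; the project reads
# api_using from config.toml instead.
api_using = os.getenv("API_USING", "siliconflow")

if api_using == "deepseek":
    client = OpenAI(
        api_key=os.getenv("DEEP_SEEK_KEY"),
        base_url=os.getenv("DEEP_SEEK_BASE_URL", "https://api.deepseek.com/v1"),
    )
else:
    client = OpenAI(
        api_key=os.getenv("SILICONFLOW_KEY"),
        base_url=os.getenv("SILICONFLOW_BASE_URL", "https://api.siliconflow.cn/v1/"),
    )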
@@ -17,6 +17,8 @@ CHAT_ANY_WHERE_KEY=
 SILICONFLOW_KEY=
 CHAT_ANY_WHERE_BASE_URL=https://api.chatanywhere.tech/v1
 SILICONFLOW_BASE_URL=https://api.siliconflow.cn/v1/
+DEEP_SEEK_KEY=
+DEEP_SEEK_BASE_URL=https://api.deepseek.com/v1



@@ -17,6 +17,7 @@ check_interval = 120
 register_interval = 10

 [response]
+api_using = "siliconflow"
 model_r1_probability = 0.8
 model_v3_probability = 0.1
 model_r1_distill_probability = 0.1

@@ -40,6 +40,7 @@ class BotConfig:
     EMOJI_CHECK_INTERVAL: int = 120  # Emoji check interval (minutes)
     EMOJI_REGISTER_INTERVAL: int = 10  # Emoji registration interval (minutes)

+    API_USING: str = "siliconflow"  # Which API to use
     MODEL_R1_PROBABILITY: float = 0.8  # R1 model probability
     MODEL_V3_PROBABILITY: float = 0.1  # V3 model probability
     MODEL_R1_DISTILL_PROBABILITY: float = 0.1  # R1 distilled-model probability

@@ -76,7 +77,8 @@ class BotConfig:
             config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
             config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
             config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
+            config.API_USING = response_config.get("api_using", config.API_USING)

         # Message configuration
         if "message" in toml_dict:
             msg_config = toml_dict["message"]

@@ -108,7 +110,11 @@ class LLMConfig:
     # Base configuration
     SILICONFLOW_API_KEY: str = None
     SILICONFLOW_BASE_URL: str = None
+    DEEP_SEEK_API_KEY: str = None
+    DEEP_SEEK_BASE_URL: str = None

 llm_config = LLMConfig()
 llm_config.SILICONFLOW_API_KEY = os.getenv('SILICONFLOW_KEY')
 llm_config.SILICONFLOW_BASE_URL = os.getenv('SILICONFLOW_BASE_URL')
+llm_config.DEEP_SEEK_API_KEY = os.getenv('DEEP_SEEK_KEY')
+llm_config.DEEP_SEEK_BASE_URL = os.getenv('DEEP_SEEK_BASE_URL')

@@ -26,11 +26,17 @@ load_dotenv(os.path.join(root_dir, '.env'))
 class LLMResponseGenerator:
     def __init__(self, config: BotConfig):
         self.config = config
-        self.client = OpenAI(
-            api_key=llm_config.SILICONFLOW_API_KEY,
-            base_url=llm_config.SILICONFLOW_BASE_URL
-        )
+        if self.config.API_USING == "siliconflow":
+            self.client = OpenAI(
+                api_key=llm_config.SILICONFLOW_API_KEY,
+                base_url=llm_config.SILICONFLOW_BASE_URL
+            )
+        elif self.config.API_USING == "deepseek":
+            self.client = OpenAI(
+                api_key=llm_config.DEEP_SEEK_API_KEY,
+                base_url=llm_config.DEEP_SEEK_BASE_URL
+            )

         self.db = Database.get_instance()

         # Model type currently in use

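One caveat in the __init__ branch above: if API_USING is set to anything other than "siliconflow" or "deepseek", neither branch runs, self.client is never assigned, and the first completion call fails later with an AttributeError. A sketch of a stricter variant, assuming nothing beyond the config attributes visible in the diff; make_client is a hypothetical helper, not part of the commit:

from openai import OpenAI

def make_client(api_using, llm_config):
    # Hypothetical helper: same selection as __init__ above, but it fails
    # loudly on a typo in api_using instead of leaving the client unset.
    backends = {
        "siliconflow": (llm_config.SILICONFLOW_API_KEY, llm_config.SILICONFLOW_BASE_URL),
        "deepseek": (llm_config.DEEP_SEEK_API_KEY, llm_config.DEEP_SEEK_BASE_URL),
    }
    if api_using not in backends:
        raise ValueError(f"unknown api_using value: {api_using!r}")
    api_key, base_url = backends[api_using]
    return OpenAI(api_key=api_key, base_url=base_url)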
@@ -140,19 +146,33 @@ class LLMResponseGenerator:

     async def _generate_r1_response(self, message: Message) -> Optional[str]:
         """Generate a reply using the DeepSeek-R1 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-R1",
-            {"temperature": 0.7, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-reasoner",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-R1",
+                {"temperature": 0.7, "max_tokens": 1024}
+            )

     async def _generate_v3_response(self, message: Message) -> Optional[str]:
         """Generate a reply using the DeepSeek-V3 model"""
-        return await self._generate_base_response(
-            message,
-            "Pro/deepseek-ai/DeepSeek-V3",
-            {"temperature": 0.8, "max_tokens": 1024}
-        )
+        if self.config.API_USING == "deepseek":
+            return await self._generate_base_response(
+                message,
+                "deepseek-chat",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )
+        else:
+            return await self._generate_base_response(
+                message,
+                "Pro/deepseek-ai/DeepSeek-V3",
+                {"temperature": 0.8, "max_tokens": 1024}
+            )

     async def _generate_r1_distill_response(self, message: Message) -> Optional[str]:
         """Generate a reply using the DeepSeek-R1-Distill-Qwen-32B model"""

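The _generate_r1_response and _generate_v3_response methods above repeat the same API_USING branch, and the hunk below repeats it a third time to choose the model for create_completion. The mapping all three implement could live in one lookup table, with the SiliconFlow IDs as the fallback to mirror the else branches; MODEL_IDS and resolve_model are illustrative names, not part of the commit:

# Model IDs taken verbatim from the diff; the table itself is illustrative.
MODEL_IDS = {
    ("deepseek", "r1"): "deepseek-reasoner",
    ("deepseek", "v3"): "deepseek-chat",
    ("siliconflow", "r1"): "Pro/deepseek-ai/DeepSeek-R1",
    ("siliconflow", "v3"): "Pro/deepseek-ai/DeepSeek-V3",
}

def resolve_model(api_using, kind):
    # Any unrecognized api_using falls back to the SiliconFlow IDs,
    # matching the else branches in the methods above.
    return MODEL_IDS.get((api_using, kind), MODEL_IDS[("siliconflow", kind)])

assert resolve_model("deepseek", "v3") == "deepseek-chat"
assert resolve_model("anything-else", "r1") == "Pro/deepseek-ai/DeepSeek-R1"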
@@ -192,9 +212,13 @@ class LLMResponseGenerator:
         messages = [{"role": "user", "content": prompt}]

         loop = asyncio.get_event_loop()
+        if self.config.API_USING == "deepseek":
+            model = "deepseek-chat"
+        else:
+            model = "Pro/deepseek-ai/DeepSeek-V3"
 create_completion = partial(
             self.client.chat.completions.create,
-            model="Pro/deepseek-ai/DeepSeek-V3",
+            model=model,
             messages=messages,
             stream=False,
             max_tokens=30,