diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index f174fcebf..f0c7d52cb 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -55,6 +55,8 @@ ban_user_id = [] #禁止回复消息的QQ号 #base_url = "DEEP_SEEK_BASE_URL" #key = "DEEP_SEEK_KEY" +#下面的模型若使用硅基流动则不需要更改,使用ds官方则改成.env.prod自定义的宏,使用自定义模型则选择定位相似的模型自己填写 + [model.llm_reasoning] #R1 name = "Pro/deepseek-ai/DeepSeek-R1" base_url = "SILICONFLOW_BASE_URL" @@ -84,3 +86,12 @@ key = "SILICONFLOW_KEY" name = "BAAI/bge-m3" base_url = "SILICONFLOW_BASE_URL" key = "SILICONFLOW_KEY" + +# 主题提取,jieba和snownlp不用api,llm需要api +[topic] +topic='llm' # 只支持jieba,snownlp,llm三种选项 + +[topic.llm_topic] +name = "Pro/deepseek-ai/DeepSeek-V3" +base_url = "SILICONFLOW_BASE_URL" +key = "SILICONFLOW_KEY" diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index d119b4ab9..85dbf2223 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -104,13 +104,16 @@ class ChatBot: current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time)) - topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text) - topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text) - topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text) - print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}") - print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}") - print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}") - topic = topic3 + identifier=topic_identifier.identify_topic() + if global_config.topic_extract=='llm': + topic=await identifier(message.processed_plain_text) + else: + topic=identifier(message.detailed_plain_text) + + # topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text) + # topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text) + # topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text) + print(f"\033[1;32m[主题识别]\033[0m 
使用{global_config.topic_extract}主题: {topic}") all_num = 0 interested_num = 0 diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 7a3c85633..24cf12925 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -41,6 +41,8 @@ class BotConfig: llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {}) embedding: Dict[str, str] = field(default_factory=lambda: {}) vlm: Dict[str, str] = field(default_factory=lambda: {}) + topic_extract: str = 'snownlp' # 只支持jieba,snownlp,llm + llm_topic_extract: Dict[str, str] = field(default_factory=lambda: {}) API_USING: str = "siliconflow" # 使用的API API_PAID: bool = False # 是否使用付费API @@ -132,6 +134,15 @@ class BotConfig: if "embedding" in model_config: config.embedding = model_config["embedding"] + if 'topic' in toml_dict: topic_config=toml_dict['topic'] if 'topic' in topic_config: config.topic_extract=topic_config.get('topic',config.topic_extract) print(f"载入自定义主题提取为{config.topic_extract}") if config.topic_extract=='llm' and 'llm_topic' in topic_config: config.llm_topic_extract=topic_config['llm_topic'] print(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}") + # 消息配置 if "message" in toml_dict: msg_config = toml_dict["message"] diff --git a/src/plugins/chat/topic_identifier.py b/src/plugins/chat/topic_identifier.py index 5c51e0bde..07749e837 100644 --- a/src/plugins/chat/topic_identifier.py +++ b/src/plugins/chat/topic_identifier.py @@ -12,7 +12,19 @@ config = driver.config class TopicIdentifier: def __init__(self): - self.llm_client = LLM_request(model=global_config.llm_normal) + self.llm_client = LLM_request(model=global_config.llm_topic_extract) + self.select=global_config.topic_extract + + def identify_topic(self): + if self.select=='jieba': + return self.identify_topic_jieba + elif self.select=='snownlp': + return self.identify_topic_snownlp + elif self.select=='llm': + return self.identify_topic_llm + else: + return self.identify_topic_snownlp + async def identify_topic_llm(self, text: 
str) -> Optional[List[str]]: """识别消息主题,返回主题列表"""