fix: 调整topic提取的设置
This commit is contained in:
@@ -104,13 +104,16 @@ class ChatBot:
|
||||
|
||||
current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(message.time))
|
||||
|
||||
topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||
topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||
topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
|
||||
print(f"\033[1;32m[主题识别]\033[0m 使用jieba主题: {topic1}")
|
||||
print(f"\033[1;32m[主题识别]\033[0m 使用llm主题: {topic2}")
|
||||
print(f"\033[1;32m[主题识别]\033[0m 使用snownlp主题: {topic3}")
|
||||
topic = topic3
|
||||
identifier=topic_identifier.identify_topic()
|
||||
if global_config.topic_extract=='llm':
|
||||
topic=await identifier(message.processed_plain_text)
|
||||
else:
|
||||
topic=identifier(message.detailed_plain_text)
|
||||
|
||||
# topic1 = topic_identifier.identify_topic_jieba(message.processed_plain_text)
|
||||
# topic2 = await topic_identifier.identify_topic_llm(message.processed_plain_text)
|
||||
# topic3 = topic_identifier.identify_topic_snownlp(message.processed_plain_text)
|
||||
print(f"\033[1;32m[主题识别]\033[0m 使用{global_config.topic_extract}主题: {topic}")
|
||||
|
||||
all_num = 0
|
||||
interested_num = 0
|
||||
|
||||
@@ -41,6 +41,8 @@ class BotConfig:
|
||||
llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
|
||||
embedding: Dict[str, str] = field(default_factory=lambda: {})
|
||||
vlm: Dict[str, str] = field(default_factory=lambda: {})
|
||||
topic_extract: str = 'snownlp' # 只支持jieba,snownlp,llm
|
||||
llm_topic_extract=llm_normal_minor
|
||||
|
||||
API_USING: str = "siliconflow" # 使用的API
|
||||
API_PAID: bool = False # 是否使用付费API
|
||||
@@ -132,6 +134,15 @@ class BotConfig:
|
||||
if "embedding" in model_config:
|
||||
config.embedding = model_config["embedding"]
|
||||
|
||||
if 'topic' in toml_dict:
|
||||
topic_config=toml_dict['topic']
|
||||
if 'topic_extract' in topic_config:
|
||||
config.topic_extract=topic_config.get('topic_extract',config.topic_extract)
|
||||
print(f"载入自定义主题提取为{config.topic_extract}")
|
||||
if config.topic_extract=='llm' and 'llm_topic' in topic_config:
|
||||
config.llm_topic_extract=topic_config['llm_topic']
|
||||
print(f"载入自定义主题提取模型为{config.llm_topic_extract['name']}")
|
||||
|
||||
# 消息配置
|
||||
if "message" in toml_dict:
|
||||
msg_config = toml_dict["message"]
|
||||
|
||||
@@ -12,7 +12,19 @@ config = driver.config
|
||||
|
||||
class TopicIdentifier:
|
||||
def __init__(self):
|
||||
self.llm_client = LLM_request(model=global_config.llm_normal)
|
||||
self.llm_client = LLM_request(model=global_config.llm_topic_extract)
|
||||
self.select=global_config.topic_extract
|
||||
|
||||
def identify_topic(self):
|
||||
if self.select=='jieba':
|
||||
return self.identify_topic_jieba
|
||||
elif self.select=='snownlp':
|
||||
return self.identify_topic_snownlp
|
||||
elif self.select=='llm':
|
||||
return self.identify_topic_llm
|
||||
else:
|
||||
return self.identify_topic_snownlp
|
||||
|
||||
|
||||
async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
|
||||
"""识别消息主题,返回主题列表"""
|
||||
|
||||
Reference in New Issue
Block a user