Renamed model config fields and LLM client attributes
@@ -43,8 +43,12 @@ class BotConfig:
     llm_reasoning_minor: Dict[str, str] = field(default_factory=lambda: {})
     llm_normal: Dict[str, str] = field(default_factory=lambda: {})
     llm_normal_minor: Dict[str, str] = field(default_factory=lambda: {})
+    llm_topic_judge: Dict[str, str] = field(default_factory=lambda: {})
+    llm_summary_by_topic: Dict[str, str] = field(default_factory=lambda: {})
+    llm_emotion_judge: Dict[str, str] = field(default_factory=lambda: {})
     embedding: Dict[str, str] = field(default_factory=lambda: {})
     vlm: Dict[str, str] = field(default_factory=lambda: {})
+    moderation: Dict[str, str] = field(default_factory=lambda: {})

     MODEL_R1_PROBABILITY: float = 0.8  # probability of picking the R1 model
     MODEL_V3_PROBABILITY: float = 0.1  # probability of picking the V3 model
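
Each of these fields is a plain string-to-string mapping filled in from the model section of the bot's TOML config (see the loading code below). A hypothetical entry, with key names invented purely for illustration:

    # Hypothetical shape of one model mapping (these key names are assumptions, not from this commit)
    llm_topic_judge = {
        "name": "deepseek-ai/DeepSeek-V3",         # model identifier sent to the API
        "base_url": "https://api.example.com/v1",  # serving endpoint
    }
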
@@ -112,8 +116,6 @@ class BotConfig:
             config.MODEL_R1_PROBABILITY = response_config.get("model_r1_probability", config.MODEL_R1_PROBABILITY)
             config.MODEL_V3_PROBABILITY = response_config.get("model_v3_probability", config.MODEL_V3_PROBABILITY)
             config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
-            config.API_USING = response_config.get("api_using", config.API_USING)
-            config.API_PAID = response_config.get("api_paid", config.API_PAID)
             config.max_response_length = response_config.get("max_response_length", config.max_response_length)

         # load the model configuration
@@ -131,6 +133,15 @@ class BotConfig:

             if "llm_normal_minor" in model_config:
                 config.llm_normal_minor = model_config["llm_normal_minor"]

+            if "llm_topic_judge" in model_config:
+                config.llm_topic_judge = model_config["llm_topic_judge"]
+
+            if "llm_summary_by_topic" in model_config:
+                config.llm_summary_by_topic = model_config["llm_summary_by_topic"]
+
+            if "llm_emotion_judge" in model_config:
+                config.llm_emotion_judge = model_config["llm_emotion_judge"]
+
             if "vlm" in model_config:
                 config.vlm = model_config["vlm"]
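
Every branch in this block follows the same "if key in model_config" pattern, so it could equally be written as a loop; a sketch, assuming all of these fields share the same Dict[str, str] shape:

    # Equivalent loading loop (sketch): one tuple entry per model field instead of one branch each
    for key in ("llm_normal_minor", "llm_topic_judge", "llm_summary_by_topic",
                "llm_emotion_judge", "vlm", "embedding", "moderation"):
        if key in model_config:
            setattr(config, key, model_config[key])
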
@@ -138,8 +149,8 @@ class BotConfig:
             if "embedding" in model_config:
                 config.embedding = model_config["embedding"]

-            if "rerank" in model_config:
-                config.rerank = model_config["rerank"]
+            if "moderation" in model_config:
+                config.moderation = model_config["moderation"]

         # message configuration
         if "message" in toml_dict:
@@ -41,8 +41,8 @@ class EmojiManager:
     def __init__(self):
         self.db = Database.get_instance()
         self._scan_task = None
-        self.llm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
-        self.lm = LLM_request(model=global_config.llm_normal_minor, max_tokens=1000)
+        self.vlm = LLM_request(model=global_config.vlm, temperature=0.3, max_tokens=1000)
+        self.llm_emotion_judge = LLM_request(model=global_config.llm_normal_minor, max_tokens=60, temperature=0.8)  # higher temperature, fewer tokens (temperature could later be tuned by emotion)

     def _ensure_emoji_dir(self):
         """Ensure the emoji storage directory exists."""
@@ -69,7 +69,17 @@ class EmojiManager:
             raise RuntimeError("EmojiManager not initialized")

     def _ensure_emoji_collection(self):
-        """Ensure the emoji collection exists and create its indexes."""
+        """Ensure the emoji collection exists and create its indexes.
+
+        This method makes sure the emoji collection exists in the MongoDB database and creates the necessary indexes.
+
+        The indexes speed up database queries:
+        - a 2dsphere index on the embedding field: accelerates vector similarity search, helping find similar emoji quickly
+        - a regular index on the tags field: speeds up searching emoji by tag
+        - a unique index on the filename field: guarantees filenames are unique and speeds up lookups by filename
+
+        Without indexes, every query has to scan the whole collection; with them, query efficiency improves dramatically.
+        """
         if 'emoji' not in self.db.db.list_collection_names():
             self.db.db.create_collection('emoji')
             self.db.db.emoji.create_index([('embedding', '2dsphere')])
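
To make the docstring concrete, these are the lookups those indexes serve (a pymongo sketch; the filename and tag values are made up):

    # served by the unique index on filename
    doc = self.db.db.emoji.find_one({'filename': 'cat_smile.png'})
    # served by the regular index on tags
    happy_emoji = list(self.db.db.emoji.find({'tags': 'happy'}))
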
@@ -93,6 +103,11 @@ class EmojiManager:
             text: the input text
         Returns:
             Optional[str]: path to the emoji file, or None if no match is found
+
+
+        Could the emoji-usage logic be made customizable through directives in the config file?
+        I think that would be feasible.
+
         """
         try:
             self._ensure_db()
@@ -152,7 +167,8 @@ class EmojiManager:
                     {'$inc': {'usage_count': 1}}
                 )
                 logger.success(f"找到匹配的表情包: {selected_emoji.get('discription', '无描述')} (相似度: {similarity:.4f})")
-                return selected_emoji['path'],"[表情包: %s]" % selected_emoji.get('discription', '无描述')
+                # Tweak the text description slightly; otherwise it easily causes hallucinations, since the description already contains "表情包"
+                return selected_emoji['path'], "[ %s ]" % selected_emoji.get('discription', '无描述')

             except Exception as search_error:
                 logger.error(f"搜索表情包失败: {str(search_error)}")
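
Concretely, for a hypothetical stored description the second element of the returned tuple changes like this:

    "[表情包: %s]" % "开心的猫猫表情包"  # old -> '[表情包: 开心的猫猫表情包]', the word 表情包 appears twice
    "[ %s ]" % "开心的猫猫表情包"        # new -> '[ 开心的猫猫表情包 ]'
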
@@ -169,7 +185,7 @@ class EmojiManager:
         try:
             prompt = '这是一个表情包,使用中文简洁的描述一下表情包的内容和表情包所表达的情感'

-            content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
+            content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
             logger.debug(f"输出描述: {content}")
             return content

@@ -181,7 +197,7 @@ class EmojiManager:
         try:
             prompt = f'这是一个表情包,请回答这个表情包是否满足\"{global_config.EMOJI_CHECK_PROMPT}\"的要求,是则回答是,否则回答否,不要出现任何其他内容'

-            content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
+            content, _ = await self.vlm.generate_response_for_image(prompt, image_base64)
             logger.debug(f"输出描述: {content}")
             return content

@@ -193,7 +209,7 @@ class EmojiManager:
         try:
             prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'

-            content, _ = await self.lm.generate_response_async(prompt)
+            content, _ = await self.llm_emotion_judge.generate_response_async(prompt)
             logger.info(f"输出描述: {content}")
             return content

@@ -11,7 +11,7 @@ config = driver.config

 class TopicIdentifier:
     def __init__(self):
-        self.llm_client = LLM_request(model=global_config.llm_topic_extract)
+        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge)

     async def identify_topic_llm(self, text: str) -> Optional[List[str]]:
         """Identify the topics of a message and return them as a list."""
@@ -23,7 +23,7 @@ class TopicIdentifier:
        消息内容:{text}"""

         # send the request via the LLM_request class
-        topic, _ = await self.llm_client.generate_response(prompt)
+        topic, _ = await self.llm_topic_judge.generate_response(prompt)

         if not topic:
             print(f"\033[1;31m[错误]\033[0m LLM API 返回为空")
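
A usage sketch with the renamed attribute (the sample message and the model's reply here are invented):

    # inside an async function
    identifier = TopicIdentifier()
    topics = await identifier.identify_topic_llm("周末一起去吃火锅吗?顺便聊聊新出的游戏")
    # e.g. -> ["火锅", "游戏"], depending on what the llm_topic_judge model returns
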
@@ -132,8 +132,8 @@ class Memory_graph:
 class Hippocampus:
     def __init__(self,memory_graph:Memory_graph):
         self.memory_graph = memory_graph
-        self.llm_model_get_topic = LLM_request(model = global_config.llm_normal_minor,temperature=0.5)
-        self.llm_model_summary = LLM_request(model = global_config.llm_normal,temperature=0.5)
+        self.llm_topic_judge = LLM_request(model=global_config.llm_topic_judge, temperature=0.5)
+        self.llm_summary_by_topic = LLM_request(model=global_config.llm_summary_by_topic, temperature=0.5)

     def get_all_node_names(self) -> list:
         """Get a list of the names of all nodes in the memory graph
@@ -179,7 +179,7 @@ class Hippocampus:

         # get topics
         topic_num = self.calculate_topic_num(input_text, compress_rate)
-        topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(input_text, topic_num))
+        topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(input_text, topic_num))
         # revised topic-handling logic
         # keywords to filter out
         filter_keywords = ['表情包', '图片', '回复', '聊天记录']
@@ -196,7 +196,7 @@ class Hippocampus:
         for topic in filtered_topics:
             topic_what_prompt = self.topic_what(input_text, topic)
             # create an async task
-            task = self.llm_model_summary.generate_response_async(topic_what_prompt)
+            task = self.llm_summary_by_topic.generate_response_async(topic_what_prompt)
             tasks.append((topic.strip(), task))

         # wait for all tasks to complete
@@ -506,7 +506,7 @@ class Hippocampus:
         Returns:
             list: the list of identified topics
         """
-        topics_response = await self.llm_model_get_topic.generate_response(self.find_topic_llm(text, 5))
+        topics_response = await self.llm_topic_judge.generate_response(self.find_topic_llm(text, 5))
         # print(f"话题: {topics_response[0]}")
         topics = [topic.strip() for topic in topics_response[0].replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if topic.strip()]
         # print(f"话题: {topics}")
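
The split line above first normalizes fullwidth commas, enumeration marks, and spaces into ASCII commas, then splits and strips; a quick worked example with an invented model reply:

    resp = "火锅,游戏、天气 朋友"
    topics = [t.strip() for t in resp.replace(",", ",").replace("、", ",").replace(" ", ",").split(",") if t.strip()]
    # -> ['火锅', '游戏', '天气', '朋友']
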
@@ -197,13 +197,12 @@ class LLM_request:
         )
         return content, reasoning_content

-    async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
+    async def generate_response_async(self, prompt: str, **kwargs) -> Union[str, Tuple[str, str]]:
         """Asynchronously generate the model's response to the given prompt."""
         # build the request body
         data = {
             "model": self.model_name,
             "messages": [{"role": "user", "content": prompt}],
-            "temperature": 0.5,
             "max_tokens": global_config.max_response_length,
             **self.params
         }
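
With the hardcoded "temperature": 0.5 removed, sampling settings now come from whatever was passed to the constructor and stored in self.params, as the **self.params expansion suggests (the new **kwargs parameter is accepted but not yet consumed in the body shown). A sketch of the resulting behavior:

    # Sketch: per-instance sampling settings now travel via self.params
    judge = LLM_request(model=global_config.llm_emotion_judge, temperature=0.8, max_tokens=60)
    content, reasoning = await judge.generate_response_async("这条消息表达了什么情绪?")
    # the request body then carries temperature=0.8 from self.params instead of a hardcoded 0.5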