Files
Mofox-Core/src/heart_flow/sub_heartflow.py
2025-04-10 16:18:45 +08:00

534 lines
23 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from .observation import Observation
import asyncio
from src.plugins.moods.moods import MoodManager
from src.plugins.models.utils_model import LLM_request
from src.plugins.config.config import global_config
import re
import time
from src.plugins.schedule.schedule_generator import bot_schedule
from src.plugins.memory_system.Hippocampus import HippocampusManager
from src.common.logger import get_module_logger, LogConfig, SUB_HEARTFLOW_STYLE_CONFIG # noqa: E402
from src.plugins.chat.utils import get_embedding
from src.common.database import db
from typing import Union
from src.individuality.individuality import Individuality
import random
from src.plugins.chat.chat_stream import ChatStream
from src.plugins.person_info.relationship_manager import relationship_manager
from src.plugins.chat.utils import get_recent_group_speaker
subheartflow_config = LogConfig(
# 使用海马体专用样式
console_format=SUB_HEARTFLOW_STYLE_CONFIG["console_format"],
file_format=SUB_HEARTFLOW_STYLE_CONFIG["file_format"],
)
logger = get_module_logger("subheartflow", config=subheartflow_config)
class CurrentState:
def __init__(self):
self.willing = 0
self.current_state_info = ""
self.mood_manager = MoodManager()
self.mood = self.mood_manager.get_prompt()
def update_current_state_info(self):
self.current_state_info = self.mood_manager.get_current_mood()
class SubHeartflow:
def __init__(self, subheartflow_id):
self.subheartflow_id = subheartflow_id
self.current_mind = ""
self.past_mind = []
self.current_state: CurrentState = CurrentState()
self.llm_model = LLM_request(
model=global_config.llm_sub_heartflow, temperature=0.2, max_tokens=600, request_type="sub_heart_flow"
)
self.main_heartflow_info = ""
self.last_reply_time = time.time()
self.last_active_time = time.time() # 添加最后激活时间
if not self.current_mind:
self.current_mind = "你什么也没想"
self.is_active = False
self.observations: list[Observation] = []
self.running_knowledges = []
self.bot_name = global_config.BOT_NICKNAME
def add_observation(self, observation: Observation):
"""添加一个新的observation对象到列表中如果已存在相同id的observation则不添加"""
# 查找是否存在相同id的observation
for existing_obs in self.observations:
if existing_obs.observe_id == observation.observe_id:
# 如果找到相同id的observation直接返回
return
# 如果没有找到相同id的observation则添加新的
self.observations.append(observation)
def remove_observation(self, observation: Observation):
"""从列表中移除一个observation对象"""
if observation in self.observations:
self.observations.remove(observation)
def get_all_observations(self) -> list[Observation]:
"""获取所有observation对象"""
return self.observations
def clear_observations(self):
"""清空所有observation对象"""
self.observations.clear()
async def subheartflow_start_working(self):
while True:
current_time = time.time()
if (
current_time - self.last_reply_time > global_config.sub_heart_flow_freeze_time
): # 120秒无回复/不在场,冻结
self.is_active = False
await asyncio.sleep(global_config.sub_heart_flow_update_interval) # 每60秒检查一次
else:
self.is_active = True
self.last_active_time = current_time # 更新最后激活时间
self.current_state.update_current_state_info()
# await self.do_a_thinking()
# await self.judge_willing()
await asyncio.sleep(global_config.sub_heart_flow_update_interval)
# 检查是否超过10分钟没有激活
if (
current_time - self.last_active_time > global_config.sub_heart_flow_stop_time
): # 5分钟无回复/不在场,销毁
logger.info(f"子心流 {self.subheartflow_id} 已经5分钟没有激活正在销毁...")
break # 退出循环,销毁自己
async def do_observe(self):
observation = self.observations[0]
await observation.observe()
async def do_thinking_before_reply(self, message_txt:str, sender_name:str, chat_stream:ChatStream):
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
# mood_info = "你很生气,很愤怒"
observation = self.observations[0]
chat_observe_info = observation.observe_info
# print(f"chat_observe_info{chat_observe_info}")
# 开始构建prompt
prompt_personality = f"你的名字是{self.bot_name},你"
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# 关系
who_chat_in_group = [
(chat_stream.user_info.platform, chat_stream.user_info.user_id, chat_stream.user_info.user_nickname)
]
who_chat_in_group += get_recent_group_speaker(
chat_stream.stream_id,
(chat_stream.user_info.platform, chat_stream.user_info.user_id),
limit=global_config.MAX_CONTEXT_SIZE,
)
relation_prompt = ""
for person in who_chat_in_group:
relation_prompt += await relationship_manager.build_relationship_info(person)
relation_prompt_all = (
f"{relation_prompt}关系等级越大,关系越好,请分析聊天记录,"
f"根据你和说话者{sender_name}的关系和态度进行回复,明确你的立场和情感。"
)
# 调取记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
text=chat_observe_info, max_memory_num=2, max_memory_length=2, max_depth=3, fast_retrieval=False
)
if related_memory:
related_memory_info = ""
for memory in related_memory:
related_memory_info += memory[1]
else:
related_memory_info = ""
related_info, grouped_results = await self.get_prompt_info(chat_observe_info + message_txt, 0.4)
# print(related_info)
for _topic, results in grouped_results.items():
for result in results:
# print(result)
self.running_knowledges.append(result)
# print(f"相关记忆:{related_memory_info}")
schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
prompt = ""
# prompt += f"麦麦的总体想法是:{self.main_heartflow_info}\n\n"
prompt += f"{relation_prompt_all}\n"
prompt += f"{prompt_personality}\n"
# prompt += f"你刚刚在做的事情是:{schedule_info}\n"
# if related_memory_info:
# prompt += f"你想起来你之前见过的回忆:{related_memory_info}。\n以上是你的回忆不一定是目前聊天里的人说的也不一定是现在发生的事情请记住。\n"
# if related_info:
# prompt += f"你想起你知道:{related_info}\n"
prompt += f"刚刚你的想法是{current_thinking_info}。如果有新的内容,记得转换话题\n"
prompt += "-----------------------------------\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += f"你现在{mood_info}\n"
prompt += f"你注意到{sender_name}刚刚说:{message_txt}\n"
prompt += "现在你接下去继续思考,产生新的想法,不要分点输出,输出连贯的内心独白,不要太长,"
prompt += "思考时可以想想如何对群聊内容进行回复。回复的要求是:平淡一些,简短一些,说中文,不要刻意突出自身学科背景,尽量不要说你说过的话 ,注意只输出回复内容。"
prompt += "请注意不要输出多余内容(包括前后缀,冒号和引号,括号,表情等)"
prompt += f"记得结合上述的消息,生成符合内心想法的内心独白,文字不要浮夸,注意你就是{self.bot_name}{self.bot_name}指的就是你。"
try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
except Exception as e:
logger.error(f"回复前内心独白获取失败: {e}")
response = ""
self.update_current_mind(response)
self.current_mind = response
logger.info(f"prompt:\n{prompt}\n")
logger.info(f"麦麦的思考前脑内状态:{self.current_mind}")
return self.current_mind ,self.past_mind
async def do_thinking_after_reply(self, reply_content, chat_talking_prompt):
# print("麦麦回复之后脑袋转起来了")
# 开始构建prompt
prompt_personality = f"你的名字是{self.bot_name},你"
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
observation = self.observations[0]
chat_observe_info = observation.observe_info
message_new_info = chat_talking_prompt
reply_info = reply_content
# schedule_info = bot_schedule.get_current_num_task(num=1, time_info=False)
prompt = ""
# prompt += f"你现在正在做的事情是:{schedule_info}\n"
prompt += f"{prompt_personality}\n"
prompt += f"现在你正在上网和qq群里的网友们聊天群里正在聊的话题是{chat_observe_info}\n"
prompt += f"刚刚你的想法是{current_thinking_info}"
prompt += f"你现在看到了网友们发的新消息:{message_new_info}\n"
prompt += f"你刚刚回复了群友们:{reply_info}"
prompt += f"你现在{mood_info}"
prompt += "现在你接下去继续思考,产生新的想法,记得保留你刚刚的想法,不要分点输出,输出连贯的内心独白"
prompt += "不要太长,但是记得结合上述的消息,要记得你的人设,关注聊天和新内容,关注你回复的内容,不要思考太多:"
try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
except Exception as e:
logger.error(f"回复后内心独白获取失败: {e}")
response = ""
self.update_current_mind(response)
self.current_mind = response
logger.info(f"麦麦回复后的脑内状态:{self.current_mind}")
self.last_reply_time = time.time()
async def judge_willing(self):
# 开始构建prompt
prompt_personality = ""
# person
individuality = Individuality.get_instance()
personality_core = individuality.personality.personality_core
prompt_personality += personality_core
personality_sides = individuality.personality.personality_sides
random.shuffle(personality_sides)
prompt_personality += f",{personality_sides[0]}"
identity_detail = individuality.identity.identity_detail
random.shuffle(identity_detail)
prompt_personality += f",{identity_detail[0]}"
# print("麦麦闹情绪了1")
current_thinking_info = self.current_mind
mood_info = self.current_state.mood
# print("麦麦闹情绪了2")
prompt = ""
prompt += f"{prompt_personality}\n"
prompt += "现在你正在上网和qq群里的网友们聊天"
prompt += f"你现在的想法是{current_thinking_info}"
prompt += f"你现在{mood_info}"
prompt += "现在请你思考你想不想发言或者回复请你输出一个数字1-101表示非常不想10表示非常想。"
prompt += "请你用<>包裹你的回复意愿,输出<1>表示不想回复,输出<10>表示非常想回复。请你考虑,你完全可以不回复"
try:
response, reasoning_content = await self.llm_model.generate_response_async(prompt)
# 解析willing值
willing_match = re.search(r"<(\d+)>", response)
except Exception as e:
logger.error(f"意愿判断获取失败: {e}")
willing_match = None
if willing_match:
self.current_state.willing = int(willing_match.group(1))
else:
self.current_state.willing = 0
return self.current_state.willing
def update_current_mind(self, response):
self.past_mind.append(self.current_mind)
self.current_mind = response
async def get_prompt_info(self, message: str, threshold: float):
start_time = time.time()
related_info = ""
logger.debug(f"获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
# 1. 先从LLM获取主题类似于记忆系统的做法
topics = []
# try:
# # 先尝试使用记忆系统的方法获取主题
# hippocampus = HippocampusManager.get_instance()._hippocampus
# topic_num = min(5, max(1, int(len(message) * 0.1)))
# topics_response = await hippocampus.llm_topic_judge.generate_response(hippocampus.find_topic_llm(message, topic_num))
# # 提取关键词
# topics = re.findall(r"<([^>]+)>", topics_response[0])
# if not topics:
# topics = []
# else:
# topics = [
# topic.strip()
# for topic in ",".join(topics).replace("", ",").replace("、", ",").replace(" ", ",").split(",")
# if topic.strip()
# ]
# logger.info(f"从LLM提取的主题: {', '.join(topics)}")
# except Exception as e:
# logger.error(f"从LLM提取主题失败: {str(e)}")
# # 如果LLM提取失败使用jieba分词提取关键词作为备选
# words = jieba.cut(message)
# topics = [word for word in words if len(word) > 1][:5]
# logger.info(f"使用jieba提取的主题: {', '.join(topics)}")
# 如果无法提取到主题,直接使用整个消息
if not topics:
logger.debug("未能提取到任何主题,使用整个消息进行查询")
embedding = await get_embedding(message, request_type="info_retrieval")
if not embedding:
logger.error("获取消息嵌入向量失败")
return ""
related_info = self.get_info_from_db(embedding, limit=3, threshold=threshold)
logger.info(f"知识库检索完成,总耗时: {time.time() - start_time:.3f}")
return related_info, {}
# 2. 对每个主题进行知识库查询
logger.info(f"开始处理{len(topics)}个主题的知识库查询")
# 优化批量获取嵌入向量减少API调用
embeddings = {}
topics_batch = [topic for topic in topics if len(topic) > 0]
if message: # 确保消息非空
topics_batch.append(message)
# 批量获取嵌入向量
embed_start_time = time.time()
for text in topics_batch:
if not text or len(text.strip()) == 0:
continue
try:
embedding = await get_embedding(text, request_type="info_retrieval")
if embedding:
embeddings[text] = embedding
else:
logger.warning(f"获取'{text}'的嵌入向量失败")
except Exception as e:
logger.error(f"获取'{text}'的嵌入向量时发生错误: {str(e)}")
logger.info(f"批量获取嵌入向量完成,耗时: {time.time() - embed_start_time:.3f}")
if not embeddings:
logger.error("所有嵌入向量获取失败")
return ""
# 3. 对每个主题进行知识库查询
all_results = []
query_start_time = time.time()
# 首先添加原始消息的查询结果
if message in embeddings:
original_results = self.get_info_from_db(embeddings[message], limit=3, threshold=threshold, return_raw=True)
if original_results:
for result in original_results:
result["topic"] = "原始消息"
all_results.extend(original_results)
logger.info(f"原始消息查询到{len(original_results)}条结果")
# 然后添加每个主题的查询结果
for topic in topics:
if not topic or topic not in embeddings:
continue
try:
topic_results = self.get_info_from_db(embeddings[topic], limit=3, threshold=threshold, return_raw=True)
if topic_results:
# 添加主题标记
for result in topic_results:
result["topic"] = topic
all_results.extend(topic_results)
logger.info(f"主题'{topic}'查询到{len(topic_results)}条结果")
except Exception as e:
logger.error(f"查询主题'{topic}'时发生错误: {str(e)}")
logger.info(f"知识库查询完成,耗时: {time.time() - query_start_time:.3f}秒,共获取{len(all_results)}条结果")
# 4. 去重和过滤
process_start_time = time.time()
unique_contents = set()
filtered_results = []
for result in all_results:
content = result["content"]
if content not in unique_contents:
unique_contents.add(content)
filtered_results.append(result)
# 5. 按相似度排序
filtered_results.sort(key=lambda x: x["similarity"], reverse=True)
# 6. 限制总数量最多10条
filtered_results = filtered_results[:10]
logger.info(
f"结果处理完成,耗时: {time.time() - process_start_time:.3f}秒,过滤后剩余{len(filtered_results)}条结果"
)
# 7. 格式化输出
if filtered_results:
format_start_time = time.time()
grouped_results = {}
for result in filtered_results:
topic = result["topic"]
if topic not in grouped_results:
grouped_results[topic] = []
grouped_results[topic].append(result)
# 按主题组织输出
for topic, results in grouped_results.items():
related_info += f"【主题: {topic}\n"
for _i, result in enumerate(results, 1):
_similarity = result["similarity"]
content = result["content"].strip()
# 调试:为内容添加序号和相似度信息
# related_info += f"{i}. [{similarity:.2f}] {content}\n"
related_info += f"{content}\n"
related_info += "\n"
logger.info(f"格式化输出完成,耗时: {time.time() - format_start_time:.3f}")
logger.info(f"知识库检索总耗时: {time.time() - start_time:.3f}")
return related_info, grouped_results
def get_info_from_db(
self, query_embedding: list, limit: int = 1, threshold: float = 0.5, return_raw: bool = False
) -> Union[str, list]:
if not query_embedding:
return "" if not return_raw else []
# 使用余弦相似度计算
pipeline = [
{
"$addFields": {
"dotProduct": {
"$reduce": {
"input": {"$range": [0, {"$size": "$embedding"}]},
"initialValue": 0,
"in": {
"$add": [
"$$value",
{
"$multiply": [
{"$arrayElemAt": ["$embedding", "$$this"]},
{"$arrayElemAt": [query_embedding, "$$this"]},
]
},
]
},
}
},
"magnitude1": {
"$sqrt": {
"$reduce": {
"input": "$embedding",
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
"magnitude2": {
"$sqrt": {
"$reduce": {
"input": query_embedding,
"initialValue": 0,
"in": {"$add": ["$$value", {"$multiply": ["$$this", "$$this"]}]},
}
}
},
}
},
{"$addFields": {"similarity": {"$divide": ["$dotProduct", {"$multiply": ["$magnitude1", "$magnitude2"]}]}}},
{
"$match": {
"similarity": {"$gte": threshold} # 只保留相似度大于等于阈值的结果
}
},
{"$sort": {"similarity": -1}},
{"$limit": limit},
{"$project": {"content": 1, "similarity": 1}},
]
results = list(db.knowledges.aggregate(pipeline))
logger.debug(f"知识库查询结果数量: {len(results)}")
if not results:
return "" if not return_raw else []
if return_raw:
return results
else:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
# subheartflow = SubHeartflow()