feat:精简升级工作记忆模块

This commit is contained in:
SengokuCola
2025-06-14 11:41:34 +08:00
parent e6f93d7dbe
commit 189a68023f
6 changed files with 276 additions and 526 deletions

View File

@@ -8,37 +8,48 @@ from src.chat.utils.chat_message_builder import (
num_new_messages_since,
get_person_id_list,
)
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from typing import Optional
import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.message_receive.message import MessageRecv
from src.chat.heart_flow.observation.observation import Observation
from src.common.logger import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt
logger = get_logger("observation")
# 定义提示模板
Prompt(
"""这是qq群聊的聊天记录请总结以下聊天记录的主题
{chat_logs}
用一句话概括,包括人物、事件和主要信息,不要分点。""",
概括这段聊天记录的主题和主要内容
主题简短的概括包括时间人物和事件不要超过10个字
内容具体的信息内容包括人物、事件和信息不要超过100个字不要分点。
请用json格式返回格式如下
{{
"theme": "主题",
"content": "内容"
}}
""",
"chat_summary_group_prompt", # Template for group chat
)
Prompt(
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
{chat_logs}
请用一句话概括,包括事件,时间,和主要信息,不要分点。""",
请用一句话概括,包括事件,时间,和主要信息,不要分点。
主题简短的介绍不要超过10个字
内容:包括人物、事件和主要信息,不要分点。
请用json格式返回格式如下
{{
"theme": "主题",
"content": "内容"
}}""",
"chat_summary_private_prompt", # Template for private chat
)
# --- End Prompt Template Definition ---
# 聊天观察
class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
@@ -47,7 +58,6 @@ class ChattingObservation(Observation):
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
# --- Other attributes initialized in __init__ ---
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
@@ -55,13 +65,10 @@ class ChattingObservation(Observation):
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.focus_chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.compressor_prompt = ""
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
@@ -79,41 +86,11 @@ class ChattingObservation(Observation):
"talking_message_str_truncate": self.talking_message_str_truncate,
"name": self.name,
"nick_name": self.nick_name,
"mid_memory_info": self.mid_memory_info,
"person_list": self.person_list,
"oldest_messages_str": self.oldest_messages_str,
"compressor_prompt": self.compressor_prompt,
"last_observe_time": self.last_observe_time,
}
# 进行一次观察 返回观察结果observe_info
def get_observe_info(self, ids=None):
mid_memory_str = ""
if ids:
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
for msg in mid_memory_by_id["messages"]:
msg_str += f"{msg['detailed_plain_text']}"
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
mid_memory_str += f"{msg_str}\n"
except Exception as e:
logger.error(f"获取mid_memory_id失败: {e}")
traceback.print_exc()
return self.talking_message_str
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
return self.talking_message_str
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
"""
@@ -128,7 +105,6 @@ class ChattingObservation(Observation):
for message in reverse_talking_message:
if message["processed_plain_text"] == text:
find_msg = message
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
break
else:
raw_message = message.get("raw_message")
@@ -137,11 +113,11 @@ class ChattingObservation(Observation):
else:
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
msg_list.append({"message": message, "similarity": similarity})
# logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg:
if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
if msg_list[0]["similarity"] >= 0.9: # 只返回相似度大于等于0.5的消息
if msg_list[0]["similarity"] >= 0.9:
find_msg = msg_list[0]["message"]
else:
logger.debug("没有找到锚定消息,相似度低")
@@ -150,9 +126,6 @@ class ChattingObservation(Observation):
logger.debug("没有找到锚定消息,没有消息捕获")
return None
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
# 创建所需的user_info字段
user_info = {
"platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""),
@@ -160,7 +133,6 @@ class ChattingObservation(Observation):
"user_cardname": find_msg.get("user_cardname", ""),
}
# 创建所需的group_info字段如果是群聊的话
group_info = {}
if find_msg.get("chat_info_group_id"):
group_info = {
@@ -194,9 +166,7 @@ class ChattingObservation(Observation):
"detailed_plain_text": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"),
}
# print(f"message_dict: {message_dict}")
find_rec_msg = MessageRecv(message_dict)
# logger.debug(f"锚定消息处理后find_rec_msg: {find_rec_msg}")
return find_rec_msg
async def observe(self):
@@ -209,8 +179,6 @@ class ChattingObservation(Observation):
limit_mode="latest",
)
# print(f"new_messages_list: {new_messages_list}")
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
@@ -220,60 +188,47 @@ class ChattingObservation(Observation):
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
self.talking_message = self.talking_message[messages_to_remove_count:]
# print(f"压缩中oldest_messages: {oldest_messages}")
# 构建压缩提示
oldest_messages_str = build_readable_messages(
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
messages=oldest_messages,
timestamp_mode="normal_no_YMD",
read_mark=0,
show_actions=True
)
# --- Build prompt using template ---
prompt = None # Initialize prompt as None
try:
# 构建 Prompt - 根据 is_group_chat 选择模板
if self.is_group_chat:
prompt_template_name = "chat_summary_group_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name, chat_logs=oldest_messages_str
# 根据聊天类型选择提示模板
if self.is_group_chat:
prompt_template_name = "chat_summary_group_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
chat_logs=oldest_messages_str
)
else:
prompt_template_name = "chat_summary_private_prompt"
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
else:
# For private chat, add chat_target to the prompt variables
prompt_template_name = "chat_summary_private_prompt"
# Determine the target name for the prompt
chat_target_name = "对方" # Default fallback
if self.chat_target_info:
# Prioritize person_name, then nickname
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
chat_target=chat_target_name,
chat_logs=oldest_messages_str,
)
# Format the private chat prompt
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
# Assuming the private prompt template uses {chat_target}
chat_target=chat_target_name,
chat_logs=oldest_messages_str,
)
except Exception as e:
logger.error(f"构建总结 Prompt 失败 for chat {self.chat_id}: {e}")
# prompt remains None
self.compressor_prompt = prompt
if prompt: # Check if prompt was built successfully
self.compressor_prompt = prompt
self.oldest_messages = oldest_messages
self.oldest_messages_str = oldest_messages_str
# 构建中
# print(f"构建中self.talking_message: {self.talking_message}")
# 构建当前消息
self.talking_message_str = build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
show_actions=True,
)
# print(f"构建中self.talking_message_str: {self.talking_message_str}")
self.talking_message_str_truncate = build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal_no_YMD",
@@ -281,15 +236,12 @@ class ChattingObservation(Observation):
truncate=True,
show_actions=True,
)
# print(f"构建中self.talking_message_str_truncate: {self.talking_message_str_truncate}")
self.person_list = await get_person_id_list(self.talking_message)
# print(f"构建中self.person_list: {self.person_list}")
logger.debug(
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
)
# logger.debug(
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
# )
async def has_new_messages_since(self, timestamp: float) -> bool:
"""检查指定时间戳之后是否有新消息"""

View File

@@ -31,18 +31,4 @@ class WorkingMemoryObservation:
return self.retrieved_working_memory
async def observe(self):
pass
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {
"observe_info": self.observe_info,
"observe_id": self.observe_id,
"last_observe_time": self.last_observe_time,
"working_memory": self.working_memory.to_dict()
if hasattr(self.working_memory, "to_dict")
else str(self.working_memory),
"retrieved_working_memory": [
item.to_dict() if hasattr(item, "to_dict") else str(item) for item in self.retrieved_working_memory
],
}
pass