feat:精简升级工作记忆模块
This commit is contained in:
@@ -8,37 +8,48 @@ from src.chat.utils.chat_message_builder import (
|
||||
num_new_messages_since,
|
||||
get_person_id_list,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
|
||||
from typing import Optional
|
||||
import difflib
|
||||
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.utils.prompt_builder import Prompt
|
||||
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
# 定义提示模板
|
||||
Prompt(
|
||||
"""这是qq群聊的聊天记录,请总结以下聊天记录的主题:
|
||||
{chat_logs}
|
||||
请用一句话概括,包括人物、事件和主要信息,不要分点。""",
|
||||
请概括这段聊天记录的主题和主要内容
|
||||
主题:简短的概括,包括时间,人物和事件,不要超过10个字
|
||||
内容:具体的信息内容,包括人物、事件和信息,不要超过100个字,不要分点。
|
||||
|
||||
请用json格式返回,格式如下:
|
||||
{{
|
||||
"theme": "主题",
|
||||
"content": "内容"
|
||||
}}
|
||||
""",
|
||||
"chat_summary_group_prompt", # Template for group chat
|
||||
)
|
||||
|
||||
Prompt(
|
||||
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
|
||||
{chat_logs}
|
||||
请用一句话概括,包括事件,时间,和主要信息,不要分点。""",
|
||||
请用一句话概括,包括事件,时间,和主要信息,不要分点。
|
||||
主题:简短的介绍,不要超过10个字
|
||||
内容:包括人物、事件和主要信息,不要分点。
|
||||
|
||||
请用json格式返回,格式如下:
|
||||
{{
|
||||
"theme": "主题",
|
||||
"content": "内容"
|
||||
}}""",
|
||||
"chat_summary_private_prompt", # Template for private chat
|
||||
)
|
||||
# --- End Prompt Template Definition ---
|
||||
|
||||
|
||||
# 聊天观察
|
||||
class ChattingObservation(Observation):
|
||||
def __init__(self, chat_id):
|
||||
super().__init__(chat_id)
|
||||
@@ -47,7 +58,6 @@ class ChattingObservation(Observation):
|
||||
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
# --- Other attributes initialized in __init__ ---
|
||||
self.talking_message = []
|
||||
self.talking_message_str = ""
|
||||
self.talking_message_str_truncate = ""
|
||||
@@ -55,13 +65,10 @@ class ChattingObservation(Observation):
|
||||
self.nick_name = global_config.bot.alias_names
|
||||
self.max_now_obs_len = global_config.focus_chat.observation_context_size
|
||||
self.overlap_len = global_config.focus_chat.compressed_length
|
||||
self.mid_memories = []
|
||||
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
|
||||
self.mid_memory_info = ""
|
||||
self.person_list = []
|
||||
self.compressor_prompt = ""
|
||||
self.oldest_messages = []
|
||||
self.oldest_messages_str = ""
|
||||
self.compressor_prompt = ""
|
||||
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||
@@ -79,41 +86,11 @@ class ChattingObservation(Observation):
|
||||
"talking_message_str_truncate": self.talking_message_str_truncate,
|
||||
"name": self.name,
|
||||
"nick_name": self.nick_name,
|
||||
"mid_memory_info": self.mid_memory_info,
|
||||
"person_list": self.person_list,
|
||||
"oldest_messages_str": self.oldest_messages_str,
|
||||
"compressor_prompt": self.compressor_prompt,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
}
|
||||
|
||||
# 进行一次观察 返回观察结果observe_info
|
||||
def get_observe_info(self, ids=None):
|
||||
mid_memory_str = ""
|
||||
if ids:
|
||||
for id in ids:
|
||||
print(f"id:{id}")
|
||||
try:
|
||||
for mid_memory in self.mid_memories:
|
||||
if mid_memory["id"] == id:
|
||||
mid_memory_by_id = mid_memory
|
||||
msg_str = ""
|
||||
for msg in mid_memory_by_id["messages"]:
|
||||
msg_str += f"{msg['detailed_plain_text']}"
|
||||
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
|
||||
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
|
||||
mid_memory_str += f"{msg_str}\n"
|
||||
except Exception as e:
|
||||
logger.error(f"获取mid_memory_id失败: {e}")
|
||||
traceback.print_exc()
|
||||
return self.talking_message_str
|
||||
|
||||
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
|
||||
|
||||
else:
|
||||
mid_memory_str = "之前的聊天内容:\n"
|
||||
for mid_memory in self.mid_memories:
|
||||
mid_memory_str += f"{mid_memory['theme']}\n"
|
||||
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
|
||||
return self.talking_message_str
|
||||
|
||||
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
|
||||
"""
|
||||
@@ -128,7 +105,6 @@ class ChattingObservation(Observation):
|
||||
for message in reverse_talking_message:
|
||||
if message["processed_plain_text"] == text:
|
||||
find_msg = message
|
||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
||||
break
|
||||
else:
|
||||
raw_message = message.get("raw_message")
|
||||
@@ -137,11 +113,11 @@ class ChattingObservation(Observation):
|
||||
else:
|
||||
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
|
||||
msg_list.append({"message": message, "similarity": similarity})
|
||||
# logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}")
|
||||
|
||||
if not find_msg:
|
||||
if msg_list:
|
||||
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
|
||||
if msg_list[0]["similarity"] >= 0.9: # 只返回相似度大于等于0.5的消息
|
||||
if msg_list[0]["similarity"] >= 0.9:
|
||||
find_msg = msg_list[0]["message"]
|
||||
else:
|
||||
logger.debug("没有找到锚定消息,相似度低")
|
||||
@@ -150,9 +126,6 @@ class ChattingObservation(Observation):
|
||||
logger.debug("没有找到锚定消息,没有消息捕获")
|
||||
return None
|
||||
|
||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
||||
|
||||
# 创建所需的user_info字段
|
||||
user_info = {
|
||||
"platform": find_msg.get("user_platform", ""),
|
||||
"user_id": find_msg.get("user_id", ""),
|
||||
@@ -160,7 +133,6 @@ class ChattingObservation(Observation):
|
||||
"user_cardname": find_msg.get("user_cardname", ""),
|
||||
}
|
||||
|
||||
# 创建所需的group_info字段,如果是群聊的话
|
||||
group_info = {}
|
||||
if find_msg.get("chat_info_group_id"):
|
||||
group_info = {
|
||||
@@ -194,9 +166,7 @@ class ChattingObservation(Observation):
|
||||
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||
}
|
||||
# print(f"message_dict: {message_dict}")
|
||||
find_rec_msg = MessageRecv(message_dict)
|
||||
# logger.debug(f"锚定消息处理后:find_rec_msg: {find_rec_msg}")
|
||||
return find_rec_msg
|
||||
|
||||
async def observe(self):
|
||||
@@ -209,8 +179,6 @@ class ChattingObservation(Observation):
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
# print(f"new_messages_list: {new_messages_list}")
|
||||
|
||||
last_obs_time_mark = self.last_observe_time
|
||||
if new_messages_list:
|
||||
self.last_observe_time = new_messages_list[-1]["time"]
|
||||
@@ -220,60 +188,47 @@ class ChattingObservation(Observation):
|
||||
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
|
||||
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
|
||||
oldest_messages = self.talking_message[:messages_to_remove_count]
|
||||
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
|
||||
self.talking_message = self.talking_message[messages_to_remove_count:]
|
||||
|
||||
# print(f"压缩中:oldest_messages: {oldest_messages}")
|
||||
# 构建压缩提示
|
||||
oldest_messages_str = build_readable_messages(
|
||||
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
|
||||
messages=oldest_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=0,
|
||||
show_actions=True
|
||||
)
|
||||
|
||||
# --- Build prompt using template ---
|
||||
prompt = None # Initialize prompt as None
|
||||
try:
|
||||
# 构建 Prompt - 根据 is_group_chat 选择模板
|
||||
if self.is_group_chat:
|
||||
prompt_template_name = "chat_summary_group_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt_template_name, chat_logs=oldest_messages_str
|
||||
# 根据聊天类型选择提示模板
|
||||
if self.is_group_chat:
|
||||
prompt_template_name = "chat_summary_group_prompt"
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt_template_name,
|
||||
chat_logs=oldest_messages_str
|
||||
)
|
||||
else:
|
||||
prompt_template_name = "chat_summary_private_prompt"
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = (
|
||||
self.chat_target_info.get("person_name")
|
||||
or self.chat_target_info.get("user_nickname")
|
||||
or chat_target_name
|
||||
)
|
||||
else:
|
||||
# For private chat, add chat_target to the prompt variables
|
||||
prompt_template_name = "chat_summary_private_prompt"
|
||||
# Determine the target name for the prompt
|
||||
chat_target_name = "对方" # Default fallback
|
||||
if self.chat_target_info:
|
||||
# Prioritize person_name, then nickname
|
||||
chat_target_name = (
|
||||
self.chat_target_info.get("person_name")
|
||||
or self.chat_target_info.get("user_nickname")
|
||||
or chat_target_name
|
||||
)
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt_template_name,
|
||||
chat_target=chat_target_name,
|
||||
chat_logs=oldest_messages_str,
|
||||
)
|
||||
|
||||
# Format the private chat prompt
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt_template_name,
|
||||
# Assuming the private prompt template uses {chat_target}
|
||||
chat_target=chat_target_name,
|
||||
chat_logs=oldest_messages_str,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"构建总结 Prompt 失败 for chat {self.chat_id}: {e}")
|
||||
# prompt remains None
|
||||
self.compressor_prompt = prompt
|
||||
|
||||
if prompt: # Check if prompt was built successfully
|
||||
self.compressor_prompt = prompt
|
||||
self.oldest_messages = oldest_messages
|
||||
self.oldest_messages_str = oldest_messages_str
|
||||
|
||||
# 构建中
|
||||
# print(f"构建中:self.talking_message: {self.talking_message}")
|
||||
# 构建当前消息
|
||||
self.talking_message_str = build_readable_messages(
|
||||
messages=self.talking_message,
|
||||
timestamp_mode="lite",
|
||||
read_mark=last_obs_time_mark,
|
||||
show_actions=True,
|
||||
)
|
||||
# print(f"构建中:self.talking_message_str: {self.talking_message_str}")
|
||||
self.talking_message_str_truncate = build_readable_messages(
|
||||
messages=self.talking_message,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
@@ -281,15 +236,12 @@ class ChattingObservation(Observation):
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
# print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
|
||||
|
||||
self.person_list = await get_person_id_list(self.talking_message)
|
||||
|
||||
# print(f"构建中:self.person_list: {self.person_list}")
|
||||
|
||||
logger.debug(
|
||||
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
|
||||
# )
|
||||
|
||||
async def has_new_messages_since(self, timestamp: float) -> bool:
|
||||
"""检查指定时间戳之后是否有新消息"""
|
||||
|
||||
@@ -31,18 +31,4 @@ class WorkingMemoryObservation:
|
||||
return self.retrieved_working_memory
|
||||
|
||||
async def observe(self):
|
||||
pass
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
"observe_info": self.observe_info,
|
||||
"observe_id": self.observe_id,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
"working_memory": self.working_memory.to_dict()
|
||||
if hasattr(self.working_memory, "to_dict")
|
||||
else str(self.working_memory),
|
||||
"retrieved_working_memory": [
|
||||
item.to_dict() if hasattr(item, "to_dict") else str(item) for item in self.retrieved_working_memory
|
||||
],
|
||||
}
|
||||
pass
|
||||
Reference in New Issue
Block a user