feat:精简升级工作记忆模块

This commit is contained in:
SengokuCola
2025-06-14 11:41:34 +08:00
parent e6f93d7dbe
commit 189a68023f
6 changed files with 276 additions and 526 deletions

View File

@@ -31,18 +31,13 @@ def init_prompt():
以下是你已经总结的记忆摘要你可以调取这些记忆查看内容来帮助你聊天不要一次调取太多记忆最多调取3个左右记忆
{memory_str}
观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true否则输出 false
如果当前聊天记录的内容已经被总结千万不要总结新记忆输出false
如果已经总结的记忆包含了当前聊天记录的内容千万不要总结新记忆输出false
如果已经总结的记忆摘要,包含了当前聊天记录的内容千万不要总结新记忆输出false
如果有相近的记忆请合并记忆输出merge_memory格式为[["id1", "id2"], ["id3", "id4"],...]你可以进行多组合并但是每组合并只能有两个记忆id不要输出其他内容
观察聊天内容和已经总结的记忆,思考如果有相近的记忆请合并记忆输出merge_memory
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...]你可以进行多组合并但是每组合并只能有两个记忆id不要输出其他内容
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆以JSON格式输出格式如下
```json
{{
"selected_memory_ids": ["id1", "id2", ...],
"new_memory": "true" or "false",
"selected_memory_ids": ["id1", "id2", ...]
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
}}
```
@@ -81,120 +76,158 @@ class WorkingMemoryProcessor(BaseProcessor):
for observation in observations:
if isinstance(observation, WorkingMemoryObservation):
working_memory = observation.get_observe_info()
# working_memory_obs = observation
if isinstance(observation, ChattingObservation):
chat_info = observation.get_observe_info()
# chat_info_truncate = observation.talking_message_str_truncate
chat_obs = observation
# 检查是否有待压缩内容
if chat_obs.compressor_prompt:
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
await self.compress_chat_memory(working_memory, chat_obs)
if not working_memory:
logger.debug(f"{self.log_prefix} 没有找到工作记忆对象")
mind_info = MindInfo()
return [mind_info]
all_memory = working_memory.get_all_memories()
if not all_memory:
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
return []
memory_prompts = []
for memory in all_memory:
memory_id = memory.id
memory_brief = memory.brief
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
memory_prompts.append(memory_single_prompt)
memory_choose_str = "".join(memory_prompts)
# 使用提示模板进行处理
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
bot_name=global_config.bot.nickname,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_info,
memory_str=memory_choose_str,
)
# 调用LLM处理记忆
content = ""
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
print(f"prompt: {prompt}---------------------------------")
print(f"content: {content}---------------------------------")
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
return []
except Exception as e:
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
return []
# 解析LLM返回的JSON
try:
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
return []
selected_memory_ids = result.get("selected_memory_ids", [])
merge_memory = result.get("merge_memory", [])
except Exception as e:
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
logger.error(traceback.format_exc())
return []
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}")
# 根据selected_memory_ids调取记忆
memory_str = ""
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
# 遍历所有记忆
for memory in all_memory:
if memory.id in selected_ids:
# 选中的记忆显示详细内容
memory = await working_memory.retrieve_memory(memory.id)
if memory:
memory_str += f"{memory.summary}\n"
else:
# 未选中的记忆显示梗概
memory_str += f"{memory.brief}\n"
working_memory_info = WorkingMemoryInfo()
if memory_str:
working_memory_info.add_working_memory(memory_str)
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
else:
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
if merge_memory:
for merge_pairs in merge_memory:
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
if memory1 and memory2:
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
return [working_memory_info]
except Exception as e:
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
logger.error(traceback.format_exc())
return []
all_memory = working_memory.get_all_memories()
memory_prompts = []
for memory in all_memory:
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
memory_points = memory_summary.get("points", [])
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
memory_prompts.append(memory_single_prompt)
memory_choose_str = "".join(memory_prompts)
# 使用提示模板进行处理
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
bot_name=global_config.bot.nickname,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_info,
memory_str=memory_choose_str,
)
# print(f"prompt: {prompt}")
# 调用LLM处理记忆
content = ""
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
except Exception as e:
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
# 解析LLM返回的JSON
try:
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
return []
selected_memory_ids = result.get("selected_memory_ids", [])
new_memory = result.get("new_memory", "")
merge_memory = result.get("merge_memory", [])
except Exception as e:
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
logger.error(traceback.format_exc())
return []
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")
# 根据selected_memory_ids调取记忆
memory_str = ""
if selected_memory_ids:
for memory_id in selected_memory_ids:
memory = await working_memory.retrieve_memory(memory_id)
if memory:
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
memory_points = memory_summary.get("points", [])
for point in memory_points:
memory_str += f"{point}\n"
working_memory_info = WorkingMemoryInfo()
if memory_str:
working_memory_info.add_working_memory(memory_str)
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
else:
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
# 根据聊天内容添加新记忆
if new_memory:
# 使用异步方式添加新记忆,不阻塞主流程
logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
asyncio.create_task(self.add_memory_async(working_memory, chat_info))
if merge_memory:
for merge_pairs in merge_memory:
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
if memory1 and memory2:
memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n"
memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n"
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
return [working_memory_info]
async def add_memory_async(self, working_memory: WorkingMemory, content: str):
    """Add ``content`` to the working memory, logging instead of raising on failure.

    Args:
        working_memory: Working-memory object that receives the new entry.
        content: Text handed to ``working_memory.add_memory`` as ``content``.
    """
    try:
        await working_memory.add_memory(content=content, from_source="chat_text")
    except Exception as e:
        logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
        logger.error(traceback.format_exc())


async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
    """Turn the pending compression prompt of ``obs`` into a new working memory.

    The LLM is asked to answer ``obs.compressor_prompt``. Its reply is repaired
    and parsed as JSON, and the ``theme`` / ``content`` fields become the
    ``brief`` / ``summary`` of a memory stored with source ``"chat_compress"``.
    The pending-compression fields on ``obs`` are reset only after the memory
    has been stored; every failure path returns early and leaves them untouched.

    Args:
        working_memory: Working-memory object that receives the compressed entry.
        obs: Chat observation carrying ``compressor_prompt`` and the pending
            ``oldest_messages`` / ``oldest_messages_str`` fields.
    """
    try:
        summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
        if not summary_result:
            logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
            return

        # Diagnostics go through the logger rather than stdout so they obey
        # the configured log level and carry the processor prefix.
        logger.debug(f"{self.log_prefix} compressor_prompt: {obs.compressor_prompt}")
        logger.debug(f"{self.log_prefix} summary_result: {summary_result}")

        # Repair and parse the JSON reply.
        try:
            fixed_json = repair_json(summary_result)
            # Same guard as the other repair_json call sites in this file:
            # only decode when the repair step handed back a string.
            summary_data = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json

            if not isinstance(summary_data, dict):
                logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
                return

            theme = summary_data.get("theme", "")
            content = summary_data.get("content", "")

            if not theme or not content:
                logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
                return

            # Store the compressed chat as a new memory.
            await working_memory.add_memory(
                from_source="chat_compress",
                summary=content,
                brief=theme,
            )
            logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
        except Exception as e:
            logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
            logger.error(traceback.format_exc())
            return

        # Clear the pending-compression state now that the memory is stored.
        obs.compressor_prompt = ""
        obs.oldest_messages = []
        obs.oldest_messages_str = ""
    except Exception as e:
        logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
        logger.error(traceback.format_exc())
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
@@ -202,13 +235,13 @@ class WorkingMemoryProcessor(BaseProcessor):
Args:
working_memory: 工作记忆对象
memory_str: 记忆内容
memory_id1: 第一个记忆ID
memory_id2: 第二个记忆ID
"""
try:
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
# logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1} 和 {memory_id2}...")
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}")
logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('points')}")
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
except Exception as e:
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")

View File

@@ -7,12 +7,12 @@ import string
class MemoryItem:
"""记忆项类,用于存储单个记忆的所有相关信息"""
def __init__(self, data: Any, from_source: str = "", brief: str = ""):
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
"""
初始化记忆项
Args:
data: 记忆数据
summary: 记忆内容概括
from_source: 数据来源
brief: 记忆内容主题
"""
@@ -20,18 +20,12 @@ class MemoryItem:
timestamp = int(time.time())
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
self.id = f"{timestamp}_{random_str}"
self.data = data
self.data_type = type(data)
self.from_source = from_source
self.brief = brief
self.timestamp = time.time()
# 修改summary的结构说明用于存储可能的总结信息
# summary结构{
# "detailed": "记忆内容概括",
# "keypoints": ["关键概念1", "关键概念2"],
# "events": ["事件1", "事件2"]
# }
self.summary = None
# 记忆内容概括
self.summary = summary
# 记忆精简次数
self.compress_count = 0
@@ -50,10 +44,6 @@ class MemoryItem:
"""检查来源是否匹配"""
return self.from_source == source
def set_summary(self, summary: Dict[str, Any]) -> None:
    """Attach ``summary`` to this item, replacing any previous value.

    Args:
        summary: Summary mapping to store on the ``summary`` attribute.
    """
    # The mapping is stored by reference, not copied.
    self.summary = summary
def increase_strength(self, amount: float) -> None:
"""增加记忆强度"""
self.memory_strength = min(10.0, self.memory_strength + amount)
@@ -85,9 +75,9 @@ class MemoryItem:
current_time = time.time()
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
def to_tuple(self) -> Tuple[Any, str, float, str]:
def to_tuple(self) -> Tuple[str, str, float, str]:
"""转换为元组格式(为了兼容性)"""
return (self.data, self.from_source, self.timestamp, self.id)
return (self.summary, self.from_source, self.timestamp, self.id)
def is_memory_valid(self) -> bool:
"""检查记忆是否有效强度是否大于等于1"""

View File

@@ -26,8 +26,8 @@ class MemoryManager:
# 关联的聊天ID
self._chat_id = chat_id
# 主存储: 数据类型 -> 记忆项列表
self._memory: Dict[Type, List[MemoryItem]] = {}
# 记忆项列表
self._memories: List[MemoryItem] = []
# ID到记忆项的映射
self._id_map: Dict[str, MemoryItem] = {}
@@ -58,51 +58,12 @@ class MemoryManager:
Returns:
记忆项的ID
"""
data_type = memory_item.data_type
# 确保存在该类型的存储列表
if data_type not in self._memory:
self._memory[data_type] = []
# 添加到内存和ID映射
self._memory[data_type].append(memory_item)
self._memories.append(memory_item)
self._id_map[memory_item.id] = memory_item
return memory_item.id
async def push_with_summary(self, data: T, from_source: str = "") -> MemoryItem:
    """Wrap ``data`` in a MemoryItem, summarising it first when it is a string.

    Args:
        data: Payload to store.
        from_source: Label describing where the payload came from.

    Returns:
        The MemoryItem that was created and pushed.
    """
    # Non-string payloads are stored as-is, without a summary or brief.
    if not isinstance(data, str):
        plain_item = MemoryItem(data, from_source)
        self.push_item(plain_item)
        return plain_item

    # Strings are summarised before the item is built so the brief can be
    # passed to the constructor.
    digest = await self.summarize_memory_item(data)
    summarised_item = MemoryItem(data, from_source, brief=digest.get("brief", ""))
    summarised_item.set_summary(digest)
    self.push_item(summarised_item)
    return summarised_item
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
"""
通过ID获取记忆项
@@ -129,7 +90,6 @@ class MemoryManager:
def find_items(
self,
data_type: Optional[Type] = None,
source: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
@@ -142,7 +102,6 @@ class MemoryManager:
按条件查找记忆项
Args:
data_type: 要查找的数据类型
source: 数据来源
start_time: 开始时间戳
end_time: 结束时间戳
@@ -161,49 +120,41 @@ class MemoryManager:
results = []
# 确定要搜索的类型列表
types_to_search = [data_type] if data_type else list(self._memory.keys())
# 获取所有项目
items = self._memories
# 对每个类型进行搜索
for typ in types_to_search:
if typ not in self._memory:
# 如果需要最新优先,则反转遍历顺序
if newest_first:
items_to_check = list(reversed(items))
else:
items_to_check = items
# 遍历项目
for item in items_to_check:
# 检查来源是否匹配
if source is not None and not item.matches_source(source):
continue
# 获取该类型的所有项目
items = self._memory[typ]
# 检查时间范围
if start_time is not None and item.timestamp < start_time:
continue
if end_time is not None and item.timestamp > end_time:
continue
# 如果需要最新优先,则反转遍历顺序
if newest_first:
items_to_check = list(reversed(items))
else:
items_to_check = items
# 检查记忆强度
if min_strength > 0 and item.memory_strength < min_strength:
continue
# 遍历项目
for item in items_to_check:
# 检查来源是否匹配
if source is not None and not item.matches_source(source):
continue
# 所有条件都满足,添加到结果中
results.append(item)
# 检查时间范围
if start_time is not None and item.timestamp < start_time:
continue
if end_time is not None and item.timestamp > end_time:
continue
# 检查记忆强度
if min_strength > 0 and item.memory_strength < min_strength:
continue
# 所有条件都满足,添加到结果中
results.append(item)
# 如果达到限制数量,提前返回
if limit is not None and len(results) >= limit:
return results
# 如果达到限制数量,提前返回
if limit is not None and len(results) >= limit:
return results
return results
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
"""
使用LLM总结记忆项
@@ -211,11 +162,11 @@ class MemoryManager:
content: 需要总结的内容
Returns:
包含总结、概括、关键概念和事件的字典
包含brief和summary的字典
"""
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
1. 记忆内容主题精简20字以内让用户可以一眼看出记忆内容是什么
2. content一到三条包含关键的概念、事件每条都要包含解释或描述谁在什么时候干了什么
2. 记忆内容概括对内容进行概括保留重要信息200字以内
内容:
{content}
@@ -223,16 +174,13 @@ class MemoryManager:
请按以下JSON格式输出
{{
"brief": "记忆内容主题",
"points": [
"内容",
"内容"
]
"summary": "记忆内容概括"
}}
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
default_summary = {
"brief": "主题未知的记忆",
"points": ["未知的要点"],
"summary": "无法概括的记忆内容",
}
try:
@@ -264,132 +212,19 @@ class MemoryManager:
if "brief" not in json_result or not isinstance(json_result["brief"], str):
json_result["brief"] = "主题未知的记忆"
# 处理关键要点
if "points" not in json_result or not isinstance(json_result["points"], list):
json_result["points"] = ["未知的要点"]
else:
# 确保points中的每个项目都是字符串
json_result["points"] = [str(point) for point in json_result["points"] if point is not None]
if not json_result["points"]:
json_result["points"] = ["未知的要点"]
if "summary" not in json_result or not isinstance(json_result["summary"], str):
json_result["summary"] = "无法概括的记忆内容"
return json_result
except Exception as json_error:
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
# 返回默认结构
return default_summary
except Exception as e:
# 出错时返回简单的结构
logger.error(f"生成总结时出错: {str(e)}")
return default_summary
# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
# """
# 对记忆进行精简操作,根据要求修改要点、总结和概括
# Args:
# memory_id: 记忆ID
# requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
# Returns:
# 修改后的记忆总结字典
# """
# # 获取指定ID的记忆项
# logger.info(f"精简记忆: {memory_id}")
# memory_item = self.get_by_id(memory_id)
# if not memory_item:
# raise ValueError(f"未找到ID为{memory_id}的记忆项")
# # 增加精简次数
# memory_item.increase_compress_count()
# summary = memory_item.summary
# # 使用LLM根据要求对总结、概括和要点进行精简修改
# prompt = f"""
# 请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
# 要求:{requirements}
# 你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题
# 目前主题:{summary["brief"]}
# 目前关键要点:
# {chr(10).join([f"- {point}" for point in summary.get("points", [])])}
# 请生成修改后的主题和关键要点,遵循以下格式:
# ```json
# {{
# "brief": "修改后的主题20字以内",
# "points": [
# "修改后的要点",
# "修改后的要点"
# ]
# }}
# ```
# 请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
# """
# # 定义默认的精简结果
# default_refined = {
# "brief": summary["brief"],
# "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点
# }
# try:
# # 调用LLM修改总结、概括和要点
# response, _ = await self.llm_summarizer.generate_response_async(prompt)
# logger.debug(f"精简记忆响应: {response}")
# # 使用repair_json处理响应
# try:
# # 修复JSON格式
# fixed_json_string = repair_json(response)
# # 将修复后的字符串解析为Python对象
# if isinstance(fixed_json_string, str):
# try:
# refined_data = json.loads(fixed_json_string)
# except json.JSONDecodeError as decode_error:
# logger.error(f"JSON解析错误: {str(decode_error)}")
# refined_data = default_refined
# else:
# # 如果repair_json直接返回了字典对象直接使用
# refined_data = fixed_json_string
# # 确保是字典类型
# if not isinstance(refined_data, dict):
# logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
# refined_data = default_refined
# # 更新总结
# summary["brief"] = refined_data.get("brief", "主题未知的记忆")
# # 更新关键要点
# points = refined_data.get("points", [])
# if isinstance(points, list) and points:
# # 确保所有要点都是字符串
# summary["points"] = [str(point) for point in points if point is not None]
# else:
# # 如果points不是列表或为空使用默认值
# summary["points"] = ["主要要点已遗忘"]
# except Exception as e:
# logger.error(f"精简记忆出错: {str(e)}")
# traceback.print_exc()
# # 出错时使用简化的默认精简
# summary["brief"] = summary["brief"] + " (已简化)"
# summary["points"] = summary.get("points", ["未知的要点"])[:1]
# except Exception as e:
# logger.error(f"精简记忆调用LLM出错: {str(e)}")
# traceback.print_exc()
# # 更新原记忆项的总结
# memory_item.set_summary(summary)
# return memory_item
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
"""
使单个记忆衰减
@@ -431,32 +266,17 @@ class MemoryManager:
item = self._id_map[memory_id]
# 从内存中删除
data_type = item.data_type
if data_type in self._memory:
self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
self._memories = [i for i in self._memories if i.id != memory_id]
# 从ID映射中删除
del self._id_map[memory_id]
return True
def clear(self, data_type: Optional[Type] = None) -> None:
    """Drop stored memories, either all of them or those of one data type.

    Args:
        data_type: Type whose bucket should be removed. ``None`` removes
            everything. A type with no bucket is a no-op.
    """
    if data_type is None:
        # Wipe both the per-type storage and the ID lookup table.
        self._memory.clear()
        self._id_map.clear()
        return

    if data_type not in self._memory:
        return

    # Unregister each item of this bucket from the ID table, then drop the
    # bucket itself.
    for entry in self._memory[data_type]:
        self._id_map.pop(entry.id, None)
    del self._memory[data_type]
def clear(self) -> None:
    """Remove every stored memory item and reset the ID lookup table."""
    # Both containers are emptied in place rather than rebound.
    for container in (self._memories, self._id_map):
        container.clear()
async def merge_memories(
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
@@ -471,7 +291,7 @@ class MemoryManager:
delete_originals: 是否删除原始记忆默认为True
Returns:
包含合并后的记忆信息的字典
合并后的记忆
"""
# 获取两个记忆项
memory_item1 = self.get_by_id(memory_id1)
@@ -480,58 +300,33 @@ class MemoryManager:
if not memory_item1 or not memory_item2:
raise ValueError("无法找到指定的记忆项")
# 获取记忆的摘要信息(如果有)
summary1 = memory_item1.summary
summary2 = memory_item2.summary
# 构建合并提示
prompt = f"""
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
合并原因:{reason}
"""
# 如果有摘要信息,添加到提示中
if summary1:
prompt += f"记忆1主题{summary1['brief']}\n"
记忆1主题{memory_item1.brief}
记忆1内容{memory_item1.summary}
prompt += "记忆1关键要点\n" + "\n".join([f"- {point}" for point in summary1.get("points", [])]) + "\n\n"
记忆2主题{memory_item2.brief}
记忆2内容{memory_item2.summary}
if summary2:
prompt += f"记忆2主题{summary2['brief']}\n"
prompt += "记忆2关键要点\n" + "\n".join([f"- {point}" for point in summary2.get("points", [])]) + "\n\n"
prompt += """
请按以下JSON格式输出合并结果
```json
{
{{
"brief": "合并后的主题20字以内",
"points": [
"合并后的要点",
"合并后的要点"
]
}
```
"summary": "合并后的内容概括200字以内"
}}
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
# 默认合并结果
default_merged = {
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
"points": [],
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
}
# 合并points
if "points" in summary1:
default_merged["points"].extend(summary1["points"])
if "points" in summary2:
default_merged["points"].extend(summary2["points"])
# 确保列表不为空
if not default_merged["points"]:
default_merged["points"] = ["合并的要点"]
try:
# 调用LLM合并记忆
response, _ = await self.llm_summarizer.generate_response_async(prompt)
@@ -560,14 +355,8 @@ class MemoryManager:
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
merged_data["brief"] = default_merged["brief"]
# 处理关键要点
if "points" not in merged_data or not isinstance(merged_data["points"], list):
merged_data["points"] = default_merged["points"]
else:
# 确保points中的每个项目都是字符串
merged_data["points"] = [str(point) for point in merged_data["points"] if point is not None]
if not merged_data["points"]:
merged_data["points"] = ["合并的要点"]
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
merged_data["summary"] = default_merged["summary"]
except Exception as e:
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
@@ -586,15 +375,8 @@ class MemoryManager:
else memory_item2.from_source
)
# 创建新的记忆项使用空字符串作为data
merged_memory = MemoryItem(data="", from_source=merged_source, brief=merged_data["brief"])
# 设置合并后的摘要
summary = {
"brief": merged_data["brief"],
"points": merged_data["points"],
}
merged_memory.set_summary(summary)
# 创建新的记忆项
merged_memory = MemoryItem(summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"])
# 记忆强度取两者最大值
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)

View File

@@ -54,18 +54,25 @@ class WorkingMemory:
except Exception as e:
print(f"自动衰减记忆时出错: {str(e)}")
async def add_memory(self, content: Any, from_source: str = ""):
async def add_memory(self, summary: Any, from_source: str = "",brief: str = ""):
"""
添加一段记忆到指定聊天
Args:
content: 记忆内容
summary: 记忆内容
from_source: 数据来源
Returns:
包含记忆信息的字典
记忆项
"""
memory = await self.memory_manager.push_with_summary(content, from_source)
# 如果是字符串类型,生成总结
memory = MemoryItem(summary, from_source, brief)
# 添加到管理器
self.memory_manager.push_item(memory)
# 如果超过最大记忆数量,删除最早的记忆
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
self.remove_earliest_memory()

View File

@@ -8,37 +8,48 @@ from src.chat.utils.chat_message_builder import (
num_new_messages_since,
get_person_id_list,
)
from src.chat.utils.prompt_builder import global_prompt_manager
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from typing import Optional
import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.message_receive.message import MessageRecv
from src.chat.heart_flow.observation.observation import Observation
from src.common.logger import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt
logger = get_logger("observation")
# 定义提示模板
Prompt(
"""这是qq群聊的聊天记录请总结以下聊天记录的主题
{chat_logs}
用一句话概括,包括人物、事件和主要信息,不要分点。""",
概括这段聊天记录的主题和主要内容
主题简短的概括包括时间人物和事件不要超过10个字
内容具体的信息内容包括人物、事件和信息不要超过100个字不要分点。
请用json格式返回格式如下
{{
"theme": "主题",
"content": "内容"
}}
""",
"chat_summary_group_prompt", # Template for group chat
)
Prompt(
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
{chat_logs}
请用一句话概括,包括事件,时间,和主要信息,不要分点。""",
请用一句话概括,包括事件,时间,和主要信息,不要分点。
主题简短的介绍不要超过10个字
内容:包括人物、事件和主要信息,不要分点。
请用json格式返回格式如下
{{
"theme": "主题",
"content": "内容"
}}""",
"chat_summary_private_prompt", # Template for private chat
)
# --- End Prompt Template Definition ---
# 聊天观察
class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
@@ -47,7 +58,6 @@ class ChattingObservation(Observation):
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
# --- Other attributes initialized in __init__ ---
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
@@ -55,13 +65,10 @@ class ChattingObservation(Observation):
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.focus_chat.observation_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.mid_memories = []
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.compressor_prompt = ""
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
@@ -79,41 +86,11 @@ class ChattingObservation(Observation):
"talking_message_str_truncate": self.talking_message_str_truncate,
"name": self.name,
"nick_name": self.nick_name,
"mid_memory_info": self.mid_memory_info,
"person_list": self.person_list,
"oldest_messages_str": self.oldest_messages_str,
"compressor_prompt": self.compressor_prompt,
"last_observe_time": self.last_observe_time,
}
# 进行一次观察 返回观察结果observe_info
def get_observe_info(self, ids=None):
mid_memory_str = ""
if ids:
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memories:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
for msg in mid_memory_by_id["messages"]:
msg_str += f"{msg['detailed_plain_text']}"
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
mid_memory_str += f"{msg_str}\n"
except Exception as e:
logger.error(f"获取mid_memory_id失败: {e}")
traceback.print_exc()
return self.talking_message_str
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memories:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
return self.talking_message_str
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
"""
@@ -128,7 +105,6 @@ class ChattingObservation(Observation):
for message in reverse_talking_message:
if message["processed_plain_text"] == text:
find_msg = message
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
break
else:
raw_message = message.get("raw_message")
@@ -137,11 +113,11 @@ class ChattingObservation(Observation):
else:
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
msg_list.append({"message": message, "similarity": similarity})
# logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg:
if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
if msg_list[0]["similarity"] >= 0.9: # 只返回相似度大于等于0.5的消息
if msg_list[0]["similarity"] >= 0.9:
find_msg = msg_list[0]["message"]
else:
logger.debug("没有找到锚定消息,相似度低")
@@ -150,9 +126,6 @@ class ChattingObservation(Observation):
logger.debug("没有找到锚定消息,没有消息捕获")
return None
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
# 创建所需的user_info字段
user_info = {
"platform": find_msg.get("user_platform", ""),
"user_id": find_msg.get("user_id", ""),
@@ -160,7 +133,6 @@ class ChattingObservation(Observation):
"user_cardname": find_msg.get("user_cardname", ""),
}
# 创建所需的group_info字段如果是群聊的话
group_info = {}
if find_msg.get("chat_info_group_id"):
group_info = {
@@ -194,9 +166,7 @@ class ChattingObservation(Observation):
"detailed_plain_text": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"),
}
# print(f"message_dict: {message_dict}")
find_rec_msg = MessageRecv(message_dict)
# logger.debug(f"锚定消息处理后find_rec_msg: {find_rec_msg}")
return find_rec_msg
async def observe(self):
@@ -209,8 +179,6 @@ class ChattingObservation(Observation):
limit_mode="latest",
)
# print(f"new_messages_list: {new_messages_list}")
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
@@ -220,60 +188,47 @@ class ChattingObservation(Observation):
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
self.talking_message = self.talking_message[messages_to_remove_count:]
# print(f"压缩中oldest_messages: {oldest_messages}")
# 构建压缩提示
oldest_messages_str = build_readable_messages(
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
messages=oldest_messages,
timestamp_mode="normal_no_YMD",
read_mark=0,
show_actions=True
)
# --- Build prompt using template ---
prompt = None # Initialize prompt as None
try:
# 构建 Prompt - 根据 is_group_chat 选择模板
if self.is_group_chat:
prompt_template_name = "chat_summary_group_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name, chat_logs=oldest_messages_str
# 根据聊天类型选择提示模板
if self.is_group_chat:
prompt_template_name = "chat_summary_group_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
chat_logs=oldest_messages_str
)
else:
prompt_template_name = "chat_summary_private_prompt"
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
else:
# For private chat, add chat_target to the prompt variables
prompt_template_name = "chat_summary_private_prompt"
# Determine the target name for the prompt
chat_target_name = "对方" # Default fallback
if self.chat_target_info:
# Prioritize person_name, then nickname
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
chat_target=chat_target_name,
chat_logs=oldest_messages_str,
)
# Format the private chat prompt
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
# Assuming the private prompt template uses {chat_target}
chat_target=chat_target_name,
chat_logs=oldest_messages_str,
)
except Exception as e:
logger.error(f"构建总结 Prompt 失败 for chat {self.chat_id}: {e}")
# prompt remains None
self.compressor_prompt = prompt
if prompt: # Check if prompt was built successfully
self.compressor_prompt = prompt
self.oldest_messages = oldest_messages
self.oldest_messages_str = oldest_messages_str
# 构建中
# print(f"构建中self.talking_message: {self.talking_message}")
# 构建当前消息
self.talking_message_str = build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
show_actions=True,
)
# print(f"构建中self.talking_message_str: {self.talking_message_str}")
self.talking_message_str_truncate = build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal_no_YMD",
@@ -281,15 +236,12 @@ class ChattingObservation(Observation):
truncate=True,
show_actions=True,
)
# print(f"构建中self.talking_message_str_truncate: {self.talking_message_str_truncate}")
self.person_list = await get_person_id_list(self.talking_message)
# print(f"构建中self.person_list: {self.person_list}")
logger.debug(
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
)
# logger.debug(
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
# )
async def has_new_messages_since(self, timestamp: float) -> bool:
"""检查指定时间戳之后是否有新消息"""

View File

@@ -31,18 +31,4 @@ class WorkingMemoryObservation:
return self.retrieved_working_memory
async def observe(self):
pass
def to_dict(self) -> dict:
    """Build a dictionary view of this observation.

    Returns:
        A dict with the keys ``observe_info``, ``observe_id``,
        ``last_observe_time``, ``working_memory`` and
        ``retrieved_working_memory``, in that order.
    """

    def _export(value):
        # Use the object's own to_dict() when it has one, else its str() form.
        return value.to_dict() if hasattr(value, "to_dict") else str(value)

    snapshot = {
        "observe_info": self.observe_info,
        "observe_id": self.observe_id,
        "last_observe_time": self.last_observe_time,
    }
    snapshot["working_memory"] = _export(self.working_memory)
    snapshot["retrieved_working_memory"] = [_export(entry) for entry in self.retrieved_working_memory]
    return snapshot
pass