feat:精简升级工作记忆模块
This commit is contained in:
@@ -31,18 +31,13 @@ def init_prompt():
|
|||||||
以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
|
以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
|
||||||
{memory_str}
|
{memory_str}
|
||||||
|
|
||||||
观察聊天内容和已经总结的记忆,思考是否有新内容需要总结成记忆,如果有,就输出 true,否则输出 false
|
观察聊天内容和已经总结的记忆,思考如果有相近的记忆,请合并记忆,输出merge_memory,
|
||||||
如果当前聊天记录的内容已经被总结,千万不要总结新记忆,输出false
|
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
|
||||||
如果已经总结的记忆包含了当前聊天记录的内容,千万不要总结新记忆,输出false
|
|
||||||
如果已经总结的记忆摘要,包含了当前聊天记录的内容,千万不要总结新记忆,输出false
|
|
||||||
|
|
||||||
如果有相近的记忆,请合并记忆,输出merge_memory,格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
|
|
||||||
|
|
||||||
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
|
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
|
||||||
```json
|
```json
|
||||||
{{
|
{{
|
||||||
"selected_memory_ids": ["id1", "id2", ...],
|
"selected_memory_ids": ["id1", "id2", ...]
|
||||||
"new_memory": "true" or "false",
|
|
||||||
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
|
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
|
||||||
}}
|
}}
|
||||||
```
|
```
|
||||||
@@ -81,27 +76,23 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
for observation in observations:
|
for observation in observations:
|
||||||
if isinstance(observation, WorkingMemoryObservation):
|
if isinstance(observation, WorkingMemoryObservation):
|
||||||
working_memory = observation.get_observe_info()
|
working_memory = observation.get_observe_info()
|
||||||
# working_memory_obs = observation
|
|
||||||
if isinstance(observation, ChattingObservation):
|
if isinstance(observation, ChattingObservation):
|
||||||
chat_info = observation.get_observe_info()
|
chat_info = observation.get_observe_info()
|
||||||
# chat_info_truncate = observation.talking_message_str_truncate
|
chat_obs = observation
|
||||||
|
# 检查是否有待压缩内容
|
||||||
if not working_memory:
|
if chat_obs.compressor_prompt:
|
||||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆对象")
|
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
|
||||||
mind_info = MindInfo()
|
await self.compress_chat_memory(working_memory, chat_obs)
|
||||||
return [mind_info]
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
return []
|
|
||||||
|
|
||||||
all_memory = working_memory.get_all_memories()
|
all_memory = working_memory.get_all_memories()
|
||||||
|
if not all_memory:
|
||||||
|
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
|
||||||
|
return []
|
||||||
|
|
||||||
memory_prompts = []
|
memory_prompts = []
|
||||||
for memory in all_memory:
|
for memory in all_memory:
|
||||||
memory_summary = memory.summary
|
|
||||||
memory_id = memory.id
|
memory_id = memory.id
|
||||||
memory_brief = memory_summary.get("brief")
|
memory_brief = memory.brief
|
||||||
memory_points = memory_summary.get("points", [])
|
|
||||||
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
||||||
memory_prompts.append(memory_single_prompt)
|
memory_prompts.append(memory_single_prompt)
|
||||||
|
|
||||||
@@ -115,17 +106,21 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
memory_str=memory_choose_str,
|
memory_str=memory_choose_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(f"prompt: {prompt}")
|
|
||||||
|
|
||||||
# 调用LLM处理记忆
|
# 调用LLM处理记忆
|
||||||
content = ""
|
content = ""
|
||||||
try:
|
try:
|
||||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
|
print(f"prompt: {prompt}---------------------------------")
|
||||||
|
print(f"content: {content}---------------------------------")
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||||
|
return []
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
# 解析LLM返回的JSON
|
# 解析LLM返回的JSON
|
||||||
try:
|
try:
|
||||||
@@ -137,27 +132,28 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
selected_memory_ids = result.get("selected_memory_ids", [])
|
selected_memory_ids = result.get("selected_memory_ids", [])
|
||||||
new_memory = result.get("new_memory", "")
|
|
||||||
merge_memory = result.get("merge_memory", [])
|
merge_memory = result.get("merge_memory", [])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
|
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")
|
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}")
|
||||||
|
|
||||||
# 根据selected_memory_ids,调取记忆
|
# 根据selected_memory_ids,调取记忆
|
||||||
memory_str = ""
|
memory_str = ""
|
||||||
if selected_memory_ids:
|
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
||||||
for memory_id in selected_memory_ids:
|
|
||||||
memory = await working_memory.retrieve_memory(memory_id)
|
# 遍历所有记忆
|
||||||
|
for memory in all_memory:
|
||||||
|
if memory.id in selected_ids:
|
||||||
|
# 选中的记忆显示详细内容
|
||||||
|
memory = await working_memory.retrieve_memory(memory.id)
|
||||||
if memory:
|
if memory:
|
||||||
memory_summary = memory.summary
|
memory_str += f"{memory.summary}\n"
|
||||||
memory_id = memory.id
|
else:
|
||||||
memory_brief = memory_summary.get("brief")
|
# 未选中的记忆显示梗概
|
||||||
memory_points = memory_summary.get("points", [])
|
memory_str += f"{memory.brief}\n"
|
||||||
for point in memory_points:
|
|
||||||
memory_str += f"{point}\n"
|
|
||||||
|
|
||||||
working_memory_info = WorkingMemoryInfo()
|
working_memory_info = WorkingMemoryInfo()
|
||||||
if memory_str:
|
if memory_str:
|
||||||
@@ -166,35 +162,72 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
|
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
|
||||||
|
|
||||||
# 根据聊天内容添加新记忆
|
|
||||||
if new_memory:
|
|
||||||
# 使用异步方式添加新记忆,不阻塞主流程
|
|
||||||
logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
|
|
||||||
asyncio.create_task(self.add_memory_async(working_memory, chat_info))
|
|
||||||
|
|
||||||
if merge_memory:
|
if merge_memory:
|
||||||
for merge_pairs in merge_memory:
|
for merge_pairs in merge_memory:
|
||||||
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
|
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
|
||||||
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
|
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
|
||||||
if memory1 and memory2:
|
if memory1 and memory2:
|
||||||
memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n"
|
|
||||||
memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n"
|
|
||||||
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
|
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
|
||||||
|
|
||||||
return [working_memory_info]
|
return [working_memory_info]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return []
|
||||||
|
|
||||||
async def add_memory_async(self, working_memory: WorkingMemory, content: str):
|
async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
|
||||||
"""异步添加记忆,不阻塞主流程
|
"""压缩聊天记忆
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
working_memory: 工作记忆对象
|
working_memory: 工作记忆对象
|
||||||
content: 记忆内容
|
obs: 聊天观察对象
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
await working_memory.add_memory(content=content, from_source="chat_text")
|
summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
|
||||||
# logger.debug(f"{self.log_prefix} 异步添加新记忆成功: {content[:30]}...")
|
if not summary_result:
|
||||||
|
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
||||||
|
return
|
||||||
|
|
||||||
|
print(f"compressor_prompt: {obs.compressor_prompt}")
|
||||||
|
print(f"summary_result: {summary_result}")
|
||||||
|
|
||||||
|
# 修复并解析JSON
|
||||||
|
try:
|
||||||
|
fixed_json = repair_json(summary_result)
|
||||||
|
summary_data = json.loads(fixed_json)
|
||||||
|
|
||||||
|
if not isinstance(summary_data, dict):
|
||||||
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
||||||
|
return
|
||||||
|
|
||||||
|
theme = summary_data.get("theme", "")
|
||||||
|
content = summary_data.get("content", "")
|
||||||
|
|
||||||
|
if not theme or not content:
|
||||||
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建新记忆
|
||||||
|
await working_memory.add_memory(
|
||||||
|
from_source="chat_compress",
|
||||||
|
summary=content,
|
||||||
|
brief=theme
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
||||||
|
logger.error(traceback.format_exc())
|
||||||
|
return
|
||||||
|
|
||||||
|
# 清理压缩状态
|
||||||
|
obs.compressor_prompt = ""
|
||||||
|
obs.oldest_messages = []
|
||||||
|
obs.oldest_messages_str = ""
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|
||||||
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
|
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
|
||||||
@@ -202,13 +235,13 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
working_memory: 工作记忆对象
|
working_memory: 工作记忆对象
|
||||||
memory_str: 记忆内容
|
memory_id1: 第一个记忆ID
|
||||||
|
memory_id2: 第二个记忆ID
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
|
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
|
||||||
# logger.debug(f"{self.log_prefix} 异步合并记忆成功: {memory_id1} 和 {memory_id2}...")
|
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
|
||||||
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.summary.get('brief')}")
|
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
|
||||||
logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('points')}")
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
|
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
|
||||||
|
|||||||
@@ -7,12 +7,12 @@ import string
|
|||||||
class MemoryItem:
|
class MemoryItem:
|
||||||
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
||||||
|
|
||||||
def __init__(self, data: Any, from_source: str = "", brief: str = ""):
|
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
|
||||||
"""
|
"""
|
||||||
初始化记忆项
|
初始化记忆项
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: 记忆数据
|
summary: 记忆内容概括
|
||||||
from_source: 数据来源
|
from_source: 数据来源
|
||||||
brief: 记忆内容主题
|
brief: 记忆内容主题
|
||||||
"""
|
"""
|
||||||
@@ -20,18 +20,12 @@ class MemoryItem:
|
|||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||||
self.id = f"{timestamp}_{random_str}"
|
self.id = f"{timestamp}_{random_str}"
|
||||||
self.data = data
|
|
||||||
self.data_type = type(data)
|
|
||||||
self.from_source = from_source
|
self.from_source = from_source
|
||||||
self.brief = brief
|
self.brief = brief
|
||||||
self.timestamp = time.time()
|
self.timestamp = time.time()
|
||||||
# 修改summary的结构说明,用于存储可能的总结信息
|
|
||||||
# summary结构:{
|
# 记忆内容概括
|
||||||
# "detailed": "记忆内容概括",
|
self.summary = summary
|
||||||
# "keypoints": ["关键概念1", "关键概念2"],
|
|
||||||
# "events": ["事件1", "事件2"]
|
|
||||||
# }
|
|
||||||
self.summary = None
|
|
||||||
|
|
||||||
# 记忆精简次数
|
# 记忆精简次数
|
||||||
self.compress_count = 0
|
self.compress_count = 0
|
||||||
@@ -50,10 +44,6 @@ class MemoryItem:
|
|||||||
"""检查来源是否匹配"""
|
"""检查来源是否匹配"""
|
||||||
return self.from_source == source
|
return self.from_source == source
|
||||||
|
|
||||||
def set_summary(self, summary: Dict[str, Any]) -> None:
|
|
||||||
"""设置总结信息"""
|
|
||||||
self.summary = summary
|
|
||||||
|
|
||||||
def increase_strength(self, amount: float) -> None:
|
def increase_strength(self, amount: float) -> None:
|
||||||
"""增加记忆强度"""
|
"""增加记忆强度"""
|
||||||
self.memory_strength = min(10.0, self.memory_strength + amount)
|
self.memory_strength = min(10.0, self.memory_strength + amount)
|
||||||
@@ -85,9 +75,9 @@ class MemoryItem:
|
|||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
||||||
|
|
||||||
def to_tuple(self) -> Tuple[Any, str, float, str]:
|
def to_tuple(self) -> Tuple[str, str, float, str]:
|
||||||
"""转换为元组格式(为了兼容性)"""
|
"""转换为元组格式(为了兼容性)"""
|
||||||
return (self.data, self.from_source, self.timestamp, self.id)
|
return (self.summary, self.from_source, self.timestamp, self.id)
|
||||||
|
|
||||||
def is_memory_valid(self) -> bool:
|
def is_memory_valid(self) -> bool:
|
||||||
"""检查记忆是否有效(强度是否大于等于1)"""
|
"""检查记忆是否有效(强度是否大于等于1)"""
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ class MemoryManager:
|
|||||||
# 关联的聊天ID
|
# 关联的聊天ID
|
||||||
self._chat_id = chat_id
|
self._chat_id = chat_id
|
||||||
|
|
||||||
# 主存储: 数据类型 -> 记忆项列表
|
# 记忆项列表
|
||||||
self._memory: Dict[Type, List[MemoryItem]] = {}
|
self._memories: List[MemoryItem] = []
|
||||||
|
|
||||||
# ID到记忆项的映射
|
# ID到记忆项的映射
|
||||||
self._id_map: Dict[str, MemoryItem] = {}
|
self._id_map: Dict[str, MemoryItem] = {}
|
||||||
@@ -58,51 +58,12 @@ class MemoryManager:
|
|||||||
Returns:
|
Returns:
|
||||||
记忆项的ID
|
记忆项的ID
|
||||||
"""
|
"""
|
||||||
data_type = memory_item.data_type
|
|
||||||
|
|
||||||
# 确保存在该类型的存储列表
|
|
||||||
if data_type not in self._memory:
|
|
||||||
self._memory[data_type] = []
|
|
||||||
|
|
||||||
# 添加到内存和ID映射
|
# 添加到内存和ID映射
|
||||||
self._memory[data_type].append(memory_item)
|
self._memories.append(memory_item)
|
||||||
self._id_map[memory_item.id] = memory_item
|
self._id_map[memory_item.id] = memory_item
|
||||||
|
|
||||||
return memory_item.id
|
return memory_item.id
|
||||||
|
|
||||||
async def push_with_summary(self, data: T, from_source: str = "") -> MemoryItem:
|
|
||||||
"""
|
|
||||||
推送一段有类型的信息到工作记忆中,并自动生成总结
|
|
||||||
|
|
||||||
Args:
|
|
||||||
data: 要存储的数据
|
|
||||||
from_source: 数据来源
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含原始数据和总结信息的字典
|
|
||||||
"""
|
|
||||||
# 如果数据是字符串类型,则先进行总结
|
|
||||||
if isinstance(data, str):
|
|
||||||
# 先生成总结
|
|
||||||
summary = await self.summarize_memory_item(data)
|
|
||||||
|
|
||||||
# 创建记忆项
|
|
||||||
memory_item = MemoryItem(data, from_source, brief=summary.get("brief", ""))
|
|
||||||
|
|
||||||
# 将总结信息保存到记忆项中
|
|
||||||
memory_item.set_summary(summary)
|
|
||||||
|
|
||||||
# 推送记忆项
|
|
||||||
self.push_item(memory_item)
|
|
||||||
|
|
||||||
return memory_item
|
|
||||||
else:
|
|
||||||
# 非字符串类型,直接创建并推送记忆项
|
|
||||||
memory_item = MemoryItem(data, from_source)
|
|
||||||
self.push_item(memory_item)
|
|
||||||
|
|
||||||
return memory_item
|
|
||||||
|
|
||||||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||||||
"""
|
"""
|
||||||
通过ID获取记忆项
|
通过ID获取记忆项
|
||||||
@@ -129,7 +90,6 @@ class MemoryManager:
|
|||||||
|
|
||||||
def find_items(
|
def find_items(
|
||||||
self,
|
self,
|
||||||
data_type: Optional[Type] = None,
|
|
||||||
source: Optional[str] = None,
|
source: Optional[str] = None,
|
||||||
start_time: Optional[float] = None,
|
start_time: Optional[float] = None,
|
||||||
end_time: Optional[float] = None,
|
end_time: Optional[float] = None,
|
||||||
@@ -142,7 +102,6 @@ class MemoryManager:
|
|||||||
按条件查找记忆项
|
按条件查找记忆项
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data_type: 要查找的数据类型
|
|
||||||
source: 数据来源
|
source: 数据来源
|
||||||
start_time: 开始时间戳
|
start_time: 开始时间戳
|
||||||
end_time: 结束时间戳
|
end_time: 结束时间戳
|
||||||
@@ -161,16 +120,8 @@ class MemoryManager:
|
|||||||
|
|
||||||
results = []
|
results = []
|
||||||
|
|
||||||
# 确定要搜索的类型列表
|
# 获取所有项目
|
||||||
types_to_search = [data_type] if data_type else list(self._memory.keys())
|
items = self._memories
|
||||||
|
|
||||||
# 对每个类型进行搜索
|
|
||||||
for typ in types_to_search:
|
|
||||||
if typ not in self._memory:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取该类型的所有项目
|
|
||||||
items = self._memory[typ]
|
|
||||||
|
|
||||||
# 如果需要最新优先,则反转遍历顺序
|
# 如果需要最新优先,则反转遍历顺序
|
||||||
if newest_first:
|
if newest_first:
|
||||||
@@ -203,7 +154,7 @@ class MemoryManager:
|
|||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
||||||
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
|
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
|
||||||
"""
|
"""
|
||||||
使用LLM总结记忆项
|
使用LLM总结记忆项
|
||||||
|
|
||||||
@@ -211,11 +162,11 @@ class MemoryManager:
|
|||||||
content: 需要总结的内容
|
content: 需要总结的内容
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含总结、概括、关键概念和事件的字典
|
包含brief和summary的字典
|
||||||
"""
|
"""
|
||||||
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
||||||
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
||||||
2. content:一到三条,包含关键的概念、事件,每条都要包含解释或描述,谁在什么时候干了什么
|
2. 记忆内容概括:对内容进行概括,保留重要信息,200字以内
|
||||||
|
|
||||||
内容:
|
内容:
|
||||||
{content}
|
{content}
|
||||||
@@ -223,16 +174,13 @@ class MemoryManager:
|
|||||||
请按以下JSON格式输出:
|
请按以下JSON格式输出:
|
||||||
{{
|
{{
|
||||||
"brief": "记忆内容主题",
|
"brief": "记忆内容主题",
|
||||||
"points": [
|
"summary": "记忆内容概括"
|
||||||
"内容",
|
|
||||||
"内容"
|
|
||||||
]
|
|
||||||
}}
|
}}
|
||||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||||
"""
|
"""
|
||||||
default_summary = {
|
default_summary = {
|
||||||
"brief": "主题未知的记忆",
|
"brief": "主题未知的记忆",
|
||||||
"points": ["未知的要点"],
|
"summary": "无法概括的记忆内容",
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -264,132 +212,19 @@ class MemoryManager:
|
|||||||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||||||
json_result["brief"] = "主题未知的记忆"
|
json_result["brief"] = "主题未知的记忆"
|
||||||
|
|
||||||
# 处理关键要点
|
if "summary" not in json_result or not isinstance(json_result["summary"], str):
|
||||||
if "points" not in json_result or not isinstance(json_result["points"], list):
|
json_result["summary"] = "无法概括的记忆内容"
|
||||||
json_result["points"] = ["未知的要点"]
|
|
||||||
else:
|
|
||||||
# 确保points中的每个项目都是字符串
|
|
||||||
json_result["points"] = [str(point) for point in json_result["points"] if point is not None]
|
|
||||||
if not json_result["points"]:
|
|
||||||
json_result["points"] = ["未知的要点"]
|
|
||||||
|
|
||||||
return json_result
|
return json_result
|
||||||
|
|
||||||
except Exception as json_error:
|
except Exception as json_error:
|
||||||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||||||
# 返回默认结构
|
|
||||||
return default_summary
|
return default_summary
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# 出错时返回简单的结构
|
|
||||||
logger.error(f"生成总结时出错: {str(e)}")
|
logger.error(f"生成总结时出错: {str(e)}")
|
||||||
return default_summary
|
return default_summary
|
||||||
|
|
||||||
# async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
|
|
||||||
# """
|
|
||||||
# 对记忆进行精简操作,根据要求修改要点、总结和概括
|
|
||||||
|
|
||||||
# Args:
|
|
||||||
# memory_id: 记忆ID
|
|
||||||
# requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
|
|
||||||
|
|
||||||
# Returns:
|
|
||||||
# 修改后的记忆总结字典
|
|
||||||
# """
|
|
||||||
# # 获取指定ID的记忆项
|
|
||||||
# logger.info(f"精简记忆: {memory_id}")
|
|
||||||
# memory_item = self.get_by_id(memory_id)
|
|
||||||
# if not memory_item:
|
|
||||||
# raise ValueError(f"未找到ID为{memory_id}的记忆项")
|
|
||||||
|
|
||||||
# # 增加精简次数
|
|
||||||
# memory_item.increase_compress_count()
|
|
||||||
|
|
||||||
# summary = memory_item.summary
|
|
||||||
|
|
||||||
# # 使用LLM根据要求对总结、概括和要点进行精简修改
|
|
||||||
# prompt = f"""
|
|
||||||
# 请根据以下要求,对记忆内容的主题和关键要点进行精简,模拟记忆的遗忘过程:
|
|
||||||
# 要求:{requirements}
|
|
||||||
# 你可以随机对关键要点进行压缩,模糊或者丢弃,修改后,同样修改主题
|
|
||||||
|
|
||||||
# 目前主题:{summary["brief"]}
|
|
||||||
|
|
||||||
# 目前关键要点:
|
|
||||||
# {chr(10).join([f"- {point}" for point in summary.get("points", [])])}
|
|
||||||
|
|
||||||
# 请生成修改后的主题和关键要点,遵循以下格式:
|
|
||||||
# ```json
|
|
||||||
# {{
|
|
||||||
# "brief": "修改后的主题(20字以内)",
|
|
||||||
# "points": [
|
|
||||||
# "修改后的要点",
|
|
||||||
# "修改后的要点"
|
|
||||||
# ]
|
|
||||||
# }}
|
|
||||||
# ```
|
|
||||||
# 请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
|
||||||
# """
|
|
||||||
# # 定义默认的精简结果
|
|
||||||
# default_refined = {
|
|
||||||
# "brief": summary["brief"],
|
|
||||||
# "points": summary.get("points", ["未知的要点"])[:1], # 默认只保留第一个要点
|
|
||||||
# }
|
|
||||||
|
|
||||||
# try:
|
|
||||||
# # 调用LLM修改总结、概括和要点
|
|
||||||
# response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
|
||||||
# logger.debug(f"精简记忆响应: {response}")
|
|
||||||
# # 使用repair_json处理响应
|
|
||||||
# try:
|
|
||||||
# # 修复JSON格式
|
|
||||||
# fixed_json_string = repair_json(response)
|
|
||||||
|
|
||||||
# # 将修复后的字符串解析为Python对象
|
|
||||||
# if isinstance(fixed_json_string, str):
|
|
||||||
# try:
|
|
||||||
# refined_data = json.loads(fixed_json_string)
|
|
||||||
# except json.JSONDecodeError as decode_error:
|
|
||||||
# logger.error(f"JSON解析错误: {str(decode_error)}")
|
|
||||||
# refined_data = default_refined
|
|
||||||
# else:
|
|
||||||
# # 如果repair_json直接返回了字典对象,直接使用
|
|
||||||
# refined_data = fixed_json_string
|
|
||||||
|
|
||||||
# # 确保是字典类型
|
|
||||||
# if not isinstance(refined_data, dict):
|
|
||||||
# logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
|
|
||||||
# refined_data = default_refined
|
|
||||||
|
|
||||||
# # 更新总结
|
|
||||||
# summary["brief"] = refined_data.get("brief", "主题未知的记忆")
|
|
||||||
|
|
||||||
# # 更新关键要点
|
|
||||||
# points = refined_data.get("points", [])
|
|
||||||
# if isinstance(points, list) and points:
|
|
||||||
# # 确保所有要点都是字符串
|
|
||||||
# summary["points"] = [str(point) for point in points if point is not None]
|
|
||||||
# else:
|
|
||||||
# # 如果points不是列表或为空,使用默认值
|
|
||||||
# summary["points"] = ["主要要点已遗忘"]
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"精简记忆出错: {str(e)}")
|
|
||||||
# traceback.print_exc()
|
|
||||||
|
|
||||||
# # 出错时使用简化的默认精简
|
|
||||||
# summary["brief"] = summary["brief"] + " (已简化)"
|
|
||||||
# summary["points"] = summary.get("points", ["未知的要点"])[:1]
|
|
||||||
|
|
||||||
# except Exception as e:
|
|
||||||
# logger.error(f"精简记忆调用LLM出错: {str(e)}")
|
|
||||||
# traceback.print_exc()
|
|
||||||
|
|
||||||
# # 更新原记忆项的总结
|
|
||||||
# memory_item.set_summary(summary)
|
|
||||||
|
|
||||||
# return memory_item
|
|
||||||
|
|
||||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||||
"""
|
"""
|
||||||
使单个记忆衰减
|
使单个记忆衰减
|
||||||
@@ -431,32 +266,17 @@ class MemoryManager:
|
|||||||
item = self._id_map[memory_id]
|
item = self._id_map[memory_id]
|
||||||
|
|
||||||
# 从内存中删除
|
# 从内存中删除
|
||||||
data_type = item.data_type
|
self._memories = [i for i in self._memories if i.id != memory_id]
|
||||||
if data_type in self._memory:
|
|
||||||
self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
|
|
||||||
|
|
||||||
# 从ID映射中删除
|
# 从ID映射中删除
|
||||||
del self._id_map[memory_id]
|
del self._id_map[memory_id]
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def clear(self, data_type: Optional[Type] = None) -> None:
|
def clear(self) -> None:
|
||||||
"""
|
"""清除所有记忆"""
|
||||||
清除记忆中的数据
|
self._memories.clear()
|
||||||
|
|
||||||
Args:
|
|
||||||
data_type: 要清除的数据类型,如果为None则清除所有数据
|
|
||||||
"""
|
|
||||||
if data_type is None:
|
|
||||||
# 清除所有数据
|
|
||||||
self._memory.clear()
|
|
||||||
self._id_map.clear()
|
self._id_map.clear()
|
||||||
elif data_type in self._memory:
|
|
||||||
# 清除指定类型的数据
|
|
||||||
for item in self._memory[data_type]:
|
|
||||||
if item.id in self._id_map:
|
|
||||||
del self._id_map[item.id]
|
|
||||||
del self._memory[data_type]
|
|
||||||
|
|
||||||
async def merge_memories(
|
async def merge_memories(
|
||||||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||||
@@ -471,7 +291,7 @@ class MemoryManager:
|
|||||||
delete_originals: 是否删除原始记忆,默认为True
|
delete_originals: 是否删除原始记忆,默认为True
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含合并后的记忆信息的字典
|
合并后的记忆项
|
||||||
"""
|
"""
|
||||||
# 获取两个记忆项
|
# 获取两个记忆项
|
||||||
memory_item1 = self.get_by_id(memory_id1)
|
memory_item1 = self.get_by_id(memory_id1)
|
||||||
@@ -480,58 +300,33 @@ class MemoryManager:
|
|||||||
if not memory_item1 or not memory_item2:
|
if not memory_item1 or not memory_item2:
|
||||||
raise ValueError("无法找到指定的记忆项")
|
raise ValueError("无法找到指定的记忆项")
|
||||||
|
|
||||||
# 获取记忆的摘要信息(如果有)
|
|
||||||
summary1 = memory_item1.summary
|
|
||||||
summary2 = memory_item2.summary
|
|
||||||
|
|
||||||
# 构建合并提示
|
# 构建合并提示
|
||||||
prompt = f"""
|
prompt = f"""
|
||||||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||||||
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
||||||
|
|
||||||
合并原因:{reason}
|
合并原因:{reason}
|
||||||
"""
|
|
||||||
|
|
||||||
# 如果有摘要信息,添加到提示中
|
记忆1主题:{memory_item1.brief}
|
||||||
if summary1:
|
记忆1内容:{memory_item1.summary}
|
||||||
prompt += f"记忆1主题:{summary1['brief']}\n"
|
|
||||||
|
|
||||||
prompt += "记忆1关键要点:\n" + "\n".join([f"- {point}" for point in summary1.get("points", [])]) + "\n\n"
|
记忆2主题:{memory_item2.brief}
|
||||||
|
记忆2内容:{memory_item2.summary}
|
||||||
|
|
||||||
if summary2:
|
|
||||||
prompt += f"记忆2主题:{summary2['brief']}\n"
|
|
||||||
prompt += "记忆2关键要点:\n" + "\n".join([f"- {point}" for point in summary2.get("points", [])]) + "\n\n"
|
|
||||||
|
|
||||||
prompt += """
|
|
||||||
请按以下JSON格式输出合并结果:
|
请按以下JSON格式输出合并结果:
|
||||||
```json
|
{{
|
||||||
{
|
|
||||||
"brief": "合并后的主题(20字以内)",
|
"brief": "合并后的主题(20字以内)",
|
||||||
"points": [
|
"summary": "合并后的内容概括(200字以内)"
|
||||||
"合并后的要点",
|
}}
|
||||||
"合并后的要点"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
```
|
|
||||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 默认合并结果
|
# 默认合并结果
|
||||||
default_merged = {
|
default_merged = {
|
||||||
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
|
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
|
||||||
"points": [],
|
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
|
||||||
}
|
}
|
||||||
|
|
||||||
# 合并points
|
|
||||||
if "points" in summary1:
|
|
||||||
default_merged["points"].extend(summary1["points"])
|
|
||||||
if "points" in summary2:
|
|
||||||
default_merged["points"].extend(summary2["points"])
|
|
||||||
|
|
||||||
# 确保列表不为空
|
|
||||||
if not default_merged["points"]:
|
|
||||||
default_merged["points"] = ["合并的要点"]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# 调用LLM合并记忆
|
# 调用LLM合并记忆
|
||||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||||
@@ -560,14 +355,8 @@ class MemoryManager:
|
|||||||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||||||
merged_data["brief"] = default_merged["brief"]
|
merged_data["brief"] = default_merged["brief"]
|
||||||
|
|
||||||
# 处理关键要点
|
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
|
||||||
if "points" not in merged_data or not isinstance(merged_data["points"], list):
|
merged_data["summary"] = default_merged["summary"]
|
||||||
merged_data["points"] = default_merged["points"]
|
|
||||||
else:
|
|
||||||
# 确保points中的每个项目都是字符串
|
|
||||||
merged_data["points"] = [str(point) for point in merged_data["points"] if point is not None]
|
|
||||||
if not merged_data["points"]:
|
|
||||||
merged_data["points"] = ["合并的要点"]
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||||||
@@ -586,15 +375,8 @@ class MemoryManager:
|
|||||||
else memory_item2.from_source
|
else memory_item2.from_source
|
||||||
)
|
)
|
||||||
|
|
||||||
# 创建新的记忆项,使用空字符串作为data
|
# 创建新的记忆项
|
||||||
merged_memory = MemoryItem(data="", from_source=merged_source, brief=merged_data["brief"])
|
merged_memory = MemoryItem(summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"])
|
||||||
|
|
||||||
# 设置合并后的摘要
|
|
||||||
summary = {
|
|
||||||
"brief": merged_data["brief"],
|
|
||||||
"points": merged_data["points"],
|
|
||||||
}
|
|
||||||
merged_memory.set_summary(summary)
|
|
||||||
|
|
||||||
# 记忆强度取两者最大值
|
# 记忆强度取两者最大值
|
||||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||||
|
|||||||
@@ -54,18 +54,25 @@ class WorkingMemory:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"自动衰减记忆时出错: {str(e)}")
|
print(f"自动衰减记忆时出错: {str(e)}")
|
||||||
|
|
||||||
async def add_memory(self, content: Any, from_source: str = ""):
|
async def add_memory(self, summary: Any, from_source: str = "",brief: str = ""):
|
||||||
"""
|
"""
|
||||||
添加一段记忆到指定聊天
|
添加一段记忆到指定聊天
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
content: 记忆内容
|
summary: 记忆内容
|
||||||
from_source: 数据来源
|
from_source: 数据来源
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
包含记忆信息的字典
|
记忆项
|
||||||
"""
|
"""
|
||||||
memory = await self.memory_manager.push_with_summary(content, from_source)
|
# 如果是字符串类型,生成总结
|
||||||
|
|
||||||
|
memory = MemoryItem(summary, from_source, brief)
|
||||||
|
|
||||||
|
# 添加到管理器
|
||||||
|
self.memory_manager.push_item(memory)
|
||||||
|
|
||||||
|
# 如果超过最大记忆数量,删除最早的记忆
|
||||||
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
||||||
self.remove_earliest_memory()
|
self.remove_earliest_memory()
|
||||||
|
|
||||||
|
|||||||
@@ -8,37 +8,48 @@ from src.chat.utils.chat_message_builder import (
|
|||||||
num_new_messages_since,
|
num_new_messages_since,
|
||||||
get_person_id_list,
|
get_person_id_list,
|
||||||
)
|
)
|
||||||
from src.chat.utils.prompt_builder import global_prompt_manager
|
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
import difflib
|
import difflib
|
||||||
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
|
from src.chat.message_receive.message import MessageRecv
|
||||||
from src.chat.heart_flow.observation.observation import Observation
|
from src.chat.heart_flow.observation.observation import Observation
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||||
from src.chat.utils.prompt_builder import Prompt
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("observation")
|
logger = get_logger("observation")
|
||||||
|
|
||||||
|
# 定义提示模板
|
||||||
Prompt(
|
Prompt(
|
||||||
"""这是qq群聊的聊天记录,请总结以下聊天记录的主题:
|
"""这是qq群聊的聊天记录,请总结以下聊天记录的主题:
|
||||||
{chat_logs}
|
{chat_logs}
|
||||||
请用一句话概括,包括人物、事件和主要信息,不要分点。""",
|
请概括这段聊天记录的主题和主要内容
|
||||||
|
主题:简短的概括,包括时间,人物和事件,不要超过10个字
|
||||||
|
内容:具体的信息内容,包括人物、事件和信息,不要超过100个字,不要分点。
|
||||||
|
|
||||||
|
请用json格式返回,格式如下:
|
||||||
|
{{
|
||||||
|
"theme": "主题",
|
||||||
|
"content": "内容"
|
||||||
|
}}
|
||||||
|
""",
|
||||||
"chat_summary_group_prompt", # Template for group chat
|
"chat_summary_group_prompt", # Template for group chat
|
||||||
)
|
)
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
|
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
|
||||||
{chat_logs}
|
{chat_logs}
|
||||||
请用一句话概括,包括事件,时间,和主要信息,不要分点。""",
|
请用一句话概括,包括事件,时间,和主要信息,不要分点。
|
||||||
|
主题:简短的介绍,不要超过10个字
|
||||||
|
内容:包括人物、事件和主要信息,不要分点。
|
||||||
|
|
||||||
|
请用json格式返回,格式如下:
|
||||||
|
{{
|
||||||
|
"theme": "主题",
|
||||||
|
"content": "内容"
|
||||||
|
}}""",
|
||||||
"chat_summary_private_prompt", # Template for private chat
|
"chat_summary_private_prompt", # Template for private chat
|
||||||
)
|
)
|
||||||
# --- End Prompt Template Definition ---
|
|
||||||
|
|
||||||
|
|
||||||
# 聊天观察
|
|
||||||
class ChattingObservation(Observation):
|
class ChattingObservation(Observation):
|
||||||
def __init__(self, chat_id):
|
def __init__(self, chat_id):
|
||||||
super().__init__(chat_id)
|
super().__init__(chat_id)
|
||||||
@@ -47,7 +58,6 @@ class ChattingObservation(Observation):
|
|||||||
|
|
||||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||||
|
|
||||||
# --- Other attributes initialized in __init__ ---
|
|
||||||
self.talking_message = []
|
self.talking_message = []
|
||||||
self.talking_message_str = ""
|
self.talking_message_str = ""
|
||||||
self.talking_message_str_truncate = ""
|
self.talking_message_str_truncate = ""
|
||||||
@@ -55,13 +65,10 @@ class ChattingObservation(Observation):
|
|||||||
self.nick_name = global_config.bot.alias_names
|
self.nick_name = global_config.bot.alias_names
|
||||||
self.max_now_obs_len = global_config.focus_chat.observation_context_size
|
self.max_now_obs_len = global_config.focus_chat.observation_context_size
|
||||||
self.overlap_len = global_config.focus_chat.compressed_length
|
self.overlap_len = global_config.focus_chat.compressed_length
|
||||||
self.mid_memories = []
|
|
||||||
self.max_mid_memory_len = global_config.focus_chat.compress_length_limit
|
|
||||||
self.mid_memory_info = ""
|
|
||||||
self.person_list = []
|
self.person_list = []
|
||||||
|
self.compressor_prompt = ""
|
||||||
self.oldest_messages = []
|
self.oldest_messages = []
|
||||||
self.oldest_messages_str = ""
|
self.oldest_messages_str = ""
|
||||||
self.compressor_prompt = ""
|
|
||||||
|
|
||||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||||
@@ -79,42 +86,12 @@ class ChattingObservation(Observation):
|
|||||||
"talking_message_str_truncate": self.talking_message_str_truncate,
|
"talking_message_str_truncate": self.talking_message_str_truncate,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"nick_name": self.nick_name,
|
"nick_name": self.nick_name,
|
||||||
"mid_memory_info": self.mid_memory_info,
|
|
||||||
"person_list": self.person_list,
|
|
||||||
"oldest_messages_str": self.oldest_messages_str,
|
|
||||||
"compressor_prompt": self.compressor_prompt,
|
|
||||||
"last_observe_time": self.last_observe_time,
|
"last_observe_time": self.last_observe_time,
|
||||||
}
|
}
|
||||||
|
|
||||||
# 进行一次观察 返回观察结果observe_info
|
|
||||||
def get_observe_info(self, ids=None):
|
def get_observe_info(self, ids=None):
|
||||||
mid_memory_str = ""
|
|
||||||
if ids:
|
|
||||||
for id in ids:
|
|
||||||
print(f"id:{id}")
|
|
||||||
try:
|
|
||||||
for mid_memory in self.mid_memories:
|
|
||||||
if mid_memory["id"] == id:
|
|
||||||
mid_memory_by_id = mid_memory
|
|
||||||
msg_str = ""
|
|
||||||
for msg in mid_memory_by_id["messages"]:
|
|
||||||
msg_str += f"{msg['detailed_plain_text']}"
|
|
||||||
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
|
|
||||||
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
|
|
||||||
mid_memory_str += f"{msg_str}\n"
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取mid_memory_id失败: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return self.talking_message_str
|
return self.talking_message_str
|
||||||
|
|
||||||
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
|
|
||||||
|
|
||||||
else:
|
|
||||||
mid_memory_str = "之前的聊天内容:\n"
|
|
||||||
for mid_memory in self.mid_memories:
|
|
||||||
mid_memory_str += f"{mid_memory['theme']}\n"
|
|
||||||
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
|
|
||||||
|
|
||||||
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
|
def search_message_by_text(self, text: str) -> Optional[MessageRecv]:
|
||||||
"""
|
"""
|
||||||
根据回复的纯文本
|
根据回复的纯文本
|
||||||
@@ -128,7 +105,6 @@ class ChattingObservation(Observation):
|
|||||||
for message in reverse_talking_message:
|
for message in reverse_talking_message:
|
||||||
if message["processed_plain_text"] == text:
|
if message["processed_plain_text"] == text:
|
||||||
find_msg = message
|
find_msg = message
|
||||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
raw_message = message.get("raw_message")
|
raw_message = message.get("raw_message")
|
||||||
@@ -137,11 +113,11 @@ class ChattingObservation(Observation):
|
|||||||
else:
|
else:
|
||||||
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
|
similarity = difflib.SequenceMatcher(None, text, message.get("processed_plain_text", "")).ratio()
|
||||||
msg_list.append({"message": message, "similarity": similarity})
|
msg_list.append({"message": message, "similarity": similarity})
|
||||||
# logger.debug(f"对锚定消息检查:message: {message['processed_plain_text']},similarity: {similarity}")
|
|
||||||
if not find_msg:
|
if not find_msg:
|
||||||
if msg_list:
|
if msg_list:
|
||||||
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
|
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
|
||||||
if msg_list[0]["similarity"] >= 0.9: # 只返回相似度大于等于0.5的消息
|
if msg_list[0]["similarity"] >= 0.9:
|
||||||
find_msg = msg_list[0]["message"]
|
find_msg = msg_list[0]["message"]
|
||||||
else:
|
else:
|
||||||
logger.debug("没有找到锚定消息,相似度低")
|
logger.debug("没有找到锚定消息,相似度低")
|
||||||
@@ -150,9 +126,6 @@ class ChattingObservation(Observation):
|
|||||||
logger.debug("没有找到锚定消息,没有消息捕获")
|
logger.debug("没有找到锚定消息,没有消息捕获")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# logger.debug(f"找到的锚定消息:find_msg: {find_msg}")
|
|
||||||
|
|
||||||
# 创建所需的user_info字段
|
|
||||||
user_info = {
|
user_info = {
|
||||||
"platform": find_msg.get("user_platform", ""),
|
"platform": find_msg.get("user_platform", ""),
|
||||||
"user_id": find_msg.get("user_id", ""),
|
"user_id": find_msg.get("user_id", ""),
|
||||||
@@ -160,7 +133,6 @@ class ChattingObservation(Observation):
|
|||||||
"user_cardname": find_msg.get("user_cardname", ""),
|
"user_cardname": find_msg.get("user_cardname", ""),
|
||||||
}
|
}
|
||||||
|
|
||||||
# 创建所需的group_info字段,如果是群聊的话
|
|
||||||
group_info = {}
|
group_info = {}
|
||||||
if find_msg.get("chat_info_group_id"):
|
if find_msg.get("chat_info_group_id"):
|
||||||
group_info = {
|
group_info = {
|
||||||
@@ -194,9 +166,7 @@ class ChattingObservation(Observation):
|
|||||||
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
"detailed_plain_text": find_msg.get("processed_plain_text"),
|
||||||
"processed_plain_text": find_msg.get("processed_plain_text"),
|
"processed_plain_text": find_msg.get("processed_plain_text"),
|
||||||
}
|
}
|
||||||
# print(f"message_dict: {message_dict}")
|
|
||||||
find_rec_msg = MessageRecv(message_dict)
|
find_rec_msg = MessageRecv(message_dict)
|
||||||
# logger.debug(f"锚定消息处理后:find_rec_msg: {find_rec_msg}")
|
|
||||||
return find_rec_msg
|
return find_rec_msg
|
||||||
|
|
||||||
async def observe(self):
|
async def observe(self):
|
||||||
@@ -209,8 +179,6 @@ class ChattingObservation(Observation):
|
|||||||
limit_mode="latest",
|
limit_mode="latest",
|
||||||
)
|
)
|
||||||
|
|
||||||
# print(f"new_messages_list: {new_messages_list}")
|
|
||||||
|
|
||||||
last_obs_time_mark = self.last_observe_time
|
last_obs_time_mark = self.last_observe_time
|
||||||
if new_messages_list:
|
if new_messages_list:
|
||||||
self.last_observe_time = new_messages_list[-1]["time"]
|
self.last_observe_time = new_messages_list[-1]["time"]
|
||||||
@@ -220,60 +188,47 @@ class ChattingObservation(Observation):
|
|||||||
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
|
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
|
||||||
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
|
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
|
||||||
oldest_messages = self.talking_message[:messages_to_remove_count]
|
oldest_messages = self.talking_message[:messages_to_remove_count]
|
||||||
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
|
self.talking_message = self.talking_message[messages_to_remove_count:]
|
||||||
|
|
||||||
# print(f"压缩中:oldest_messages: {oldest_messages}")
|
# 构建压缩提示
|
||||||
oldest_messages_str = build_readable_messages(
|
oldest_messages_str = build_readable_messages(
|
||||||
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
|
messages=oldest_messages,
|
||||||
|
timestamp_mode="normal_no_YMD",
|
||||||
|
read_mark=0,
|
||||||
|
show_actions=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- Build prompt using template ---
|
# 根据聊天类型选择提示模板
|
||||||
prompt = None # Initialize prompt as None
|
|
||||||
try:
|
|
||||||
# 构建 Prompt - 根据 is_group_chat 选择模板
|
|
||||||
if self.is_group_chat:
|
if self.is_group_chat:
|
||||||
prompt_template_name = "chat_summary_group_prompt"
|
prompt_template_name = "chat_summary_group_prompt"
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
prompt_template_name, chat_logs=oldest_messages_str
|
prompt_template_name,
|
||||||
|
chat_logs=oldest_messages_str
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# For private chat, add chat_target to the prompt variables
|
|
||||||
prompt_template_name = "chat_summary_private_prompt"
|
prompt_template_name = "chat_summary_private_prompt"
|
||||||
# Determine the target name for the prompt
|
chat_target_name = "对方"
|
||||||
chat_target_name = "对方" # Default fallback
|
|
||||||
if self.chat_target_info:
|
if self.chat_target_info:
|
||||||
# Prioritize person_name, then nickname
|
|
||||||
chat_target_name = (
|
chat_target_name = (
|
||||||
self.chat_target_info.get("person_name")
|
self.chat_target_info.get("person_name")
|
||||||
or self.chat_target_info.get("user_nickname")
|
or self.chat_target_info.get("user_nickname")
|
||||||
or chat_target_name
|
or chat_target_name
|
||||||
)
|
)
|
||||||
|
|
||||||
# Format the private chat prompt
|
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
prompt_template_name,
|
prompt_template_name,
|
||||||
# Assuming the private prompt template uses {chat_target}
|
|
||||||
chat_target=chat_target_name,
|
chat_target=chat_target_name,
|
||||||
chat_logs=oldest_messages_str,
|
chat_logs=oldest_messages_str,
|
||||||
)
|
)
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"构建总结 Prompt 失败 for chat {self.chat_id}: {e}")
|
|
||||||
# prompt remains None
|
|
||||||
|
|
||||||
if prompt: # Check if prompt was built successfully
|
|
||||||
self.compressor_prompt = prompt
|
self.compressor_prompt = prompt
|
||||||
self.oldest_messages = oldest_messages
|
|
||||||
self.oldest_messages_str = oldest_messages_str
|
|
||||||
|
|
||||||
# 构建中
|
# 构建当前消息
|
||||||
# print(f"构建中:self.talking_message: {self.talking_message}")
|
|
||||||
self.talking_message_str = build_readable_messages(
|
self.talking_message_str = build_readable_messages(
|
||||||
messages=self.talking_message,
|
messages=self.talking_message,
|
||||||
timestamp_mode="lite",
|
timestamp_mode="lite",
|
||||||
read_mark=last_obs_time_mark,
|
read_mark=last_obs_time_mark,
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
# print(f"构建中:self.talking_message_str: {self.talking_message_str}")
|
|
||||||
self.talking_message_str_truncate = build_readable_messages(
|
self.talking_message_str_truncate = build_readable_messages(
|
||||||
messages=self.talking_message,
|
messages=self.talking_message,
|
||||||
timestamp_mode="normal_no_YMD",
|
timestamp_mode="normal_no_YMD",
|
||||||
@@ -281,15 +236,12 @@ class ChattingObservation(Observation):
|
|||||||
truncate=True,
|
truncate=True,
|
||||||
show_actions=True,
|
show_actions=True,
|
||||||
)
|
)
|
||||||
# print(f"构建中:self.talking_message_str_truncate: {self.talking_message_str_truncate}")
|
|
||||||
|
|
||||||
self.person_list = await get_person_id_list(self.talking_message)
|
self.person_list = await get_person_id_list(self.talking_message)
|
||||||
|
|
||||||
# print(f"构建中:self.person_list: {self.person_list}")
|
# logger.debug(
|
||||||
|
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
|
||||||
logger.debug(
|
# )
|
||||||
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
|
|
||||||
)
|
|
||||||
|
|
||||||
async def has_new_messages_since(self, timestamp: float) -> bool:
|
async def has_new_messages_since(self, timestamp: float) -> bool:
|
||||||
"""检查指定时间戳之后是否有新消息"""
|
"""检查指定时间戳之后是否有新消息"""
|
||||||
|
|||||||
@@ -32,17 +32,3 @@ class WorkingMemoryObservation:
|
|||||||
|
|
||||||
async def observe(self):
|
async def observe(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def to_dict(self) -> dict:
|
|
||||||
"""将观察对象转换为可序列化的字典"""
|
|
||||||
return {
|
|
||||||
"observe_info": self.observe_info,
|
|
||||||
"observe_id": self.observe_id,
|
|
||||||
"last_observe_time": self.last_observe_time,
|
|
||||||
"working_memory": self.working_memory.to_dict()
|
|
||||||
if hasattr(self.working_memory, "to_dict")
|
|
||||||
else str(self.working_memory),
|
|
||||||
"retrieved_working_memory": [
|
|
||||||
item.to_dict() if hasattr(item, "to_dict") else str(item) for item in self.retrieved_working_memory
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|||||||
Reference in New Issue
Block a user