🤖 自动格式化代码 [skip ci]
This commit is contained in:
@@ -8,7 +8,6 @@ from src.common.logger import get_logger
|
|||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||||
@@ -110,10 +109,10 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
content = ""
|
content = ""
|
||||||
try:
|
try:
|
||||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
|
||||||
print(f"prompt: {prompt}---------------------------------")
|
print(f"prompt: {prompt}---------------------------------")
|
||||||
print(f"content: {content}---------------------------------")
|
print(f"content: {content}---------------------------------")
|
||||||
|
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||||
return []
|
return []
|
||||||
@@ -138,12 +137,14 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}")
|
logger.debug(
|
||||||
|
f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}"
|
||||||
|
)
|
||||||
|
|
||||||
# 根据selected_memory_ids,调取记忆
|
# 根据selected_memory_ids,调取记忆
|
||||||
memory_str = ""
|
memory_str = ""
|
||||||
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
||||||
|
|
||||||
# 遍历所有记忆
|
# 遍历所有记忆
|
||||||
for memory in all_memory:
|
for memory in all_memory:
|
||||||
if memory.id in selected_ids:
|
if memory.id in selected_ids:
|
||||||
@@ -187,45 +188,41 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
if not summary_result:
|
if not summary_result:
|
||||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
||||||
return
|
return
|
||||||
|
|
||||||
print(f"compressor_prompt: {obs.compressor_prompt}")
|
print(f"compressor_prompt: {obs.compressor_prompt}")
|
||||||
print(f"summary_result: {summary_result}")
|
print(f"summary_result: {summary_result}")
|
||||||
|
|
||||||
# 修复并解析JSON
|
# 修复并解析JSON
|
||||||
try:
|
try:
|
||||||
fixed_json = repair_json(summary_result)
|
fixed_json = repair_json(summary_result)
|
||||||
summary_data = json.loads(fixed_json)
|
summary_data = json.loads(fixed_json)
|
||||||
|
|
||||||
if not isinstance(summary_data, dict):
|
if not isinstance(summary_data, dict):
|
||||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
||||||
return
|
return
|
||||||
|
|
||||||
theme = summary_data.get("theme", "")
|
theme = summary_data.get("theme", "")
|
||||||
content = summary_data.get("content", "")
|
content = summary_data.get("content", "")
|
||||||
|
|
||||||
if not theme or not content:
|
if not theme or not content:
|
||||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 创建新记忆
|
# 创建新记忆
|
||||||
await working_memory.add_memory(
|
await working_memory.add_memory(from_source="chat_compress", summary=content, brief=theme)
|
||||||
from_source="chat_compress",
|
|
||||||
summary=content,
|
|
||||||
brief=theme
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
return
|
return
|
||||||
|
|
||||||
# 清理压缩状态
|
# 清理压缩状态
|
||||||
obs.compressor_prompt = ""
|
obs.compressor_prompt = ""
|
||||||
obs.oldest_messages = []
|
obs.oldest_messages = []
|
||||||
obs.oldest_messages_str = ""
|
obs.oldest_messages_str = ""
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Any, Tuple
|
from typing import Tuple
|
||||||
import time
|
import time
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
@@ -23,7 +23,7 @@ class MemoryItem:
|
|||||||
self.from_source = from_source
|
self.from_source = from_source
|
||||||
self.brief = brief
|
self.brief = brief
|
||||||
self.timestamp = time.time()
|
self.timestamp = time.time()
|
||||||
|
|
||||||
# 记忆内容概括
|
# 记忆内容概括
|
||||||
self.summary = summary
|
self.summary = summary
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from typing import Dict, Any, Type, TypeVar, List, Optional
|
from typing import Dict, TypeVar, List, Optional
|
||||||
import traceback
|
import traceback
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -224,7 +224,7 @@ class MemoryManager:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"生成总结时出错: {str(e)}")
|
logger.error(f"生成总结时出错: {str(e)}")
|
||||||
return default_summary
|
return default_summary
|
||||||
|
|
||||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||||
"""
|
"""
|
||||||
使单个记忆衰减
|
使单个记忆衰减
|
||||||
@@ -263,7 +263,7 @@ class MemoryManager:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# 获取要删除的项
|
# 获取要删除的项
|
||||||
item = self._id_map[memory_id]
|
self._id_map[memory_id]
|
||||||
|
|
||||||
# 从内存中删除
|
# 从内存中删除
|
||||||
self._memories = [i for i in self._memories if i.id != memory_id]
|
self._memories = [i for i in self._memories if i.id != memory_id]
|
||||||
@@ -376,7 +376,9 @@ class MemoryManager:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 创建新的记忆项
|
# 创建新的记忆项
|
||||||
merged_memory = MemoryItem(summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"])
|
merged_memory = MemoryItem(
|
||||||
|
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
|
||||||
|
)
|
||||||
|
|
||||||
# 记忆强度取两者最大值
|
# 记忆强度取两者最大值
|
||||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||||
|
|||||||
@@ -54,7 +54,7 @@ class WorkingMemory:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"自动衰减记忆时出错: {str(e)}")
|
print(f"自动衰减记忆时出错: {str(e)}")
|
||||||
|
|
||||||
async def add_memory(self, summary: Any, from_source: str = "",brief: str = ""):
|
async def add_memory(self, summary: Any, from_source: str = "", brief: str = ""):
|
||||||
"""
|
"""
|
||||||
添加一段记忆到指定聊天
|
添加一段记忆到指定聊天
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
import traceback
|
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
build_readable_messages,
|
build_readable_messages,
|
||||||
@@ -50,6 +49,7 @@ Prompt(
|
|||||||
"chat_summary_private_prompt", # Template for private chat
|
"chat_summary_private_prompt", # Template for private chat
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ChattingObservation(Observation):
|
class ChattingObservation(Observation):
|
||||||
def __init__(self, chat_id):
|
def __init__(self, chat_id):
|
||||||
super().__init__(chat_id)
|
super().__init__(chat_id)
|
||||||
@@ -192,19 +192,13 @@ class ChattingObservation(Observation):
|
|||||||
|
|
||||||
# 构建压缩提示
|
# 构建压缩提示
|
||||||
oldest_messages_str = build_readable_messages(
|
oldest_messages_str = build_readable_messages(
|
||||||
messages=oldest_messages,
|
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
|
||||||
timestamp_mode="normal_no_YMD",
|
|
||||||
read_mark=0,
|
|
||||||
show_actions=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 根据聊天类型选择提示模板
|
# 根据聊天类型选择提示模板
|
||||||
if self.is_group_chat:
|
if self.is_group_chat:
|
||||||
prompt_template_name = "chat_summary_group_prompt"
|
prompt_template_name = "chat_summary_group_prompt"
|
||||||
prompt = await global_prompt_manager.format_prompt(
|
prompt = await global_prompt_manager.format_prompt(prompt_template_name, chat_logs=oldest_messages_str)
|
||||||
prompt_template_name,
|
|
||||||
chat_logs=oldest_messages_str
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
prompt_template_name = "chat_summary_private_prompt"
|
prompt_template_name = "chat_summary_private_prompt"
|
||||||
chat_target_name = "对方"
|
chat_target_name = "对方"
|
||||||
|
|||||||
@@ -31,4 +31,4 @@ class WorkingMemoryObservation:
|
|||||||
return self.retrieved_working_memory
|
return self.retrieved_working_memory
|
||||||
|
|
||||||
async def observe(self):
|
async def observe(self):
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user