better:优化了人格和其他配置文件,更加精简易懂
This commit is contained in:
@@ -91,9 +91,7 @@ class HeartFChatting:
|
||||
|
||||
# 新增:消息计数器和疲惫阈值
|
||||
self._message_count = 0 # 发送的消息计数
|
||||
# 基于exit_focus_threshold动态计算疲惫阈值
|
||||
# 基础值30条,通过exit_focus_threshold调节:threshold越小,越容易疲惫
|
||||
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
|
||||
self._message_threshold = max(10, int(30 * global_config.chat.focus_value))
|
||||
self._fatigue_triggered = False # 是否已触发疲惫退出
|
||||
|
||||
self.action_manager = ActionManager()
|
||||
@@ -127,7 +125,7 @@ class HeartFChatting:
|
||||
self.priority_manager = None
|
||||
|
||||
logger.info(
|
||||
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}条(基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算,仅在auto模式下生效)"
|
||||
f"{self.log_prefix} HeartFChatting 初始化完成"
|
||||
)
|
||||
|
||||
self.energy_value = 100
|
||||
@@ -195,7 +193,7 @@ class HeartFChatting:
|
||||
|
||||
async def _loopbody(self):
|
||||
if self.loop_mode == "focus":
|
||||
self.energy_value -= 5 * (1 / global_config.chat.exit_focus_threshold)
|
||||
self.energy_value -= 5 * global_config.chat.focus_value
|
||||
if self.energy_value <= 0:
|
||||
self.loop_mode = "normal"
|
||||
return True
|
||||
@@ -211,7 +209,7 @@ class HeartFChatting:
|
||||
filter_bot=True,
|
||||
)
|
||||
|
||||
if len(new_messages_data) > 4 * global_config.chat.auto_focus_threshold:
|
||||
if len(new_messages_data) > 4 * global_config.chat.focus_value:
|
||||
self.loop_mode = "focus"
|
||||
self.energy_value = 100
|
||||
return True
|
||||
|
||||
@@ -39,90 +39,3 @@ class SubHeartflow:
|
||||
async def initialize(self):
|
||||
"""异步初始化方法,创建兴趣流并确定聊天类型"""
|
||||
await self.heart_fc_instance.start()
|
||||
|
||||
# async def _stop_heart_fc_chat(self):
|
||||
# """停止并清理 HeartFChatting 实例"""
|
||||
# if self.heart_fc_instance.running:
|
||||
# logger.info(f"{self.log_prefix} 结束专注聊天...")
|
||||
# try:
|
||||
# await self.heart_fc_instance.shutdown()
|
||||
# except Exception as e:
|
||||
# logger.error(f"{self.log_prefix} 关闭 HeartFChatting 实例时出错: {e}")
|
||||
# logger.error(traceback.format_exc())
|
||||
# else:
|
||||
# logger.info(f"{self.log_prefix} 没有专注聊天实例,无需停止专注聊天")
|
||||
|
||||
# async def _start_heart_fc_chat(self) -> bool:
|
||||
# """启动 HeartFChatting 实例,确保 NormalChat 已停止"""
|
||||
# try:
|
||||
# # 如果任务已完成或不存在,则尝试重新启动
|
||||
# if self.heart_fc_instance._loop_task is None or self.heart_fc_instance._loop_task.done():
|
||||
# logger.info(f"{self.log_prefix} HeartFChatting 实例存在但循环未运行,尝试启动...")
|
||||
# try:
|
||||
# # 添加超时保护
|
||||
# await asyncio.wait_for(self.heart_fc_instance.start(), timeout=15.0)
|
||||
# logger.info(f"{self.log_prefix} HeartFChatting 循环已启动。")
|
||||
# return True
|
||||
# except Exception as e:
|
||||
# logger.error(f"{self.log_prefix} 尝试启动现有 HeartFChatting 循环时出错: {e}")
|
||||
# logger.error(traceback.format_exc())
|
||||
# # 出错时清理实例,准备重新创建
|
||||
# self.heart_fc_instance = None # type: ignore
|
||||
# return False
|
||||
# else:
|
||||
# # 任务正在运行
|
||||
# logger.debug(f"{self.log_prefix} HeartFChatting 已在运行中。")
|
||||
# return True # 已经在运行
|
||||
|
||||
# except Exception as e:
|
||||
# logger.error(f"{self.log_prefix} _start_heart_fc_chat 执行时出错: {e}")
|
||||
# logger.error(traceback.format_exc())
|
||||
# return False
|
||||
|
||||
# def is_in_focus_cooldown(self) -> bool:
|
||||
# """检查是否在focus模式的冷却期内
|
||||
|
||||
# Returns:
|
||||
# bool: 如果在冷却期内返回True,否则返回False
|
||||
# """
|
||||
# if self.last_focus_exit_time == 0:
|
||||
# return False
|
||||
|
||||
# # 基础冷却时间10分钟,受auto_focus_threshold调控
|
||||
# base_cooldown = 10 * 60 # 10分钟转换为秒
|
||||
# cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
|
||||
|
||||
# current_time = time.time()
|
||||
# elapsed_since_exit = current_time - self.last_focus_exit_time
|
||||
|
||||
# is_cooling = elapsed_since_exit < cooldown_duration
|
||||
|
||||
# if is_cooling:
|
||||
# remaining_time = cooldown_duration - elapsed_since_exit
|
||||
# remaining_minutes = remaining_time / 60
|
||||
# logger.debug(
|
||||
# f"[{self.log_prefix}] focus冷却中,剩余时间: {remaining_minutes:.1f}分钟 (阈值: {global_config.chat.auto_focus_threshold})"
|
||||
# )
|
||||
|
||||
# return is_cooling
|
||||
|
||||
# def get_cooldown_progress(self) -> float:
|
||||
# """获取冷却进度,返回0-1之间的值
|
||||
|
||||
# Returns:
|
||||
# float: 0表示刚开始冷却,1表示冷却完成
|
||||
# """
|
||||
# if self.last_focus_exit_time == 0:
|
||||
# return 1.0 # 没有冷却,返回1表示完全恢复
|
||||
|
||||
# # 基础冷却时间10分钟,受auto_focus_threshold调控
|
||||
# base_cooldown = 10 * 60 # 10分钟转换为秒
|
||||
# cooldown_duration = base_cooldown / global_config.chat.auto_focus_threshold
|
||||
|
||||
# current_time = time.time()
|
||||
# elapsed_since_exit = current_time - self.last_focus_exit_time
|
||||
|
||||
# if elapsed_since_exit >= cooldown_duration:
|
||||
# return 1.0 # 冷却完成
|
||||
|
||||
# return elapsed_since_exit / cooldown_duration
|
||||
|
||||
@@ -436,72 +436,72 @@ class ActionModifier:
|
||||
logger.debug(f"{self.log_prefix}动作 {action_name} 未匹配到任何关键词: {activation_keywords}")
|
||||
return False
|
||||
|
||||
async def analyze_loop_actions(self, history_loop: List[CycleDetail]) -> List[tuple[str, str]]:
|
||||
"""分析最近的循环内容并决定动作的移除
|
||||
# async def analyze_loop_actions(self, history_loop: List[CycleDetail]) -> List[tuple[str, str]]:
|
||||
# """分析最近的循环内容并决定动作的移除
|
||||
|
||||
Returns:
|
||||
List[Tuple[str, str]]: 包含要删除的动作及原因的元组列表
|
||||
[("action3", "some reason")]
|
||||
"""
|
||||
removals = []
|
||||
# Returns:
|
||||
# List[Tuple[str, str]]: 包含要删除的动作及原因的元组列表
|
||||
# [("action3", "some reason")]
|
||||
# """
|
||||
# removals = []
|
||||
|
||||
# 获取最近10次循环
|
||||
recent_cycles = history_loop[-10:] if len(history_loop) > 10 else history_loop
|
||||
if not recent_cycles:
|
||||
return removals
|
||||
# # 获取最近10次循环
|
||||
# recent_cycles = history_loop[-10:] if len(history_loop) > 10 else history_loop
|
||||
# if not recent_cycles:
|
||||
# return removals
|
||||
|
||||
reply_sequence = [] # 记录最近的动作序列
|
||||
# reply_sequence = [] # 记录最近的动作序列
|
||||
|
||||
for cycle in recent_cycles:
|
||||
action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
action_type = action_result.get("action_type", "unknown")
|
||||
reply_sequence.append(action_type == "reply")
|
||||
# for cycle in recent_cycles:
|
||||
# action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
# action_type = action_result.get("action_type", "unknown")
|
||||
# reply_sequence.append(action_type == "reply")
|
||||
|
||||
# 计算连续回复的相关阈值
|
||||
# # 计算连续回复的相关阈值
|
||||
|
||||
max_reply_num = int(global_config.focus_chat.consecutive_replies * 3.2)
|
||||
sec_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 2)
|
||||
one_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 1.5)
|
||||
# max_reply_num = int(global_config.focus_chat.consecutive_replies * 3.2)
|
||||
# sec_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 2)
|
||||
# one_thres_reply_num = int(global_config.focus_chat.consecutive_replies * 1.5)
|
||||
|
||||
# 获取最近max_reply_num次的reply状态
|
||||
if len(reply_sequence) >= max_reply_num:
|
||||
last_max_reply_num = reply_sequence[-max_reply_num:]
|
||||
else:
|
||||
last_max_reply_num = reply_sequence[:]
|
||||
# # 获取最近max_reply_num次的reply状态
|
||||
# if len(reply_sequence) >= max_reply_num:
|
||||
# last_max_reply_num = reply_sequence[-max_reply_num:]
|
||||
# else:
|
||||
# last_max_reply_num = reply_sequence[:]
|
||||
|
||||
# 详细打印阈值和序列信息,便于调试
|
||||
logger.info(
|
||||
f"连续回复阈值: max={max_reply_num}, sec={sec_thres_reply_num}, one={one_thres_reply_num},"
|
||||
f"最近reply序列: {last_max_reply_num}"
|
||||
)
|
||||
# print(f"consecutive_replies: {consecutive_replies}")
|
||||
# # 详细打印阈值和序列信息,便于调试
|
||||
# logger.info(
|
||||
# f"连续回复阈值: max={max_reply_num}, sec={sec_thres_reply_num}, one={one_thres_reply_num},"
|
||||
# f"最近reply序列: {last_max_reply_num}"
|
||||
# )
|
||||
# # print(f"consecutive_replies: {consecutive_replies}")
|
||||
|
||||
# 根据最近的reply情况决定是否移除reply动作
|
||||
if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# 如果最近max_reply_num次都是reply,直接移除
|
||||
reason = f"连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
removals.append(("reply", reason))
|
||||
# reply_count = len(last_max_reply_num) - no_reply_count
|
||||
elif len(last_max_reply_num) >= sec_thres_reply_num and all(last_max_reply_num[-sec_thres_reply_num:]):
|
||||
# 如果最近sec_thres_reply_num次都是reply,40%概率移除
|
||||
removal_probability = 0.4 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
reason = (
|
||||
f"连续回复较多(最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
removals.append(("reply", reason))
|
||||
elif len(last_max_reply_num) >= one_thres_reply_num and all(last_max_reply_num[-one_thres_reply_num:]):
|
||||
# 如果最近one_thres_reply_num次都是reply,20%概率移除
|
||||
removal_probability = 0.2 / global_config.focus_chat.consecutive_replies
|
||||
if random.random() < removal_probability:
|
||||
reason = (
|
||||
f"连续回复检测(最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
)
|
||||
removals.append(("reply", reason))
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
# # 根据最近的reply情况决定是否移除reply动作
|
||||
# if len(last_max_reply_num) >= max_reply_num and all(last_max_reply_num):
|
||||
# # 如果最近max_reply_num次都是reply,直接移除
|
||||
# reason = f"连续回复过多(最近{len(last_max_reply_num)}次全是reply,超过阈值{max_reply_num})"
|
||||
# removals.append(("reply", reason))
|
||||
# # reply_count = len(last_max_reply_num) - no_reply_count
|
||||
# elif len(last_max_reply_num) >= sec_thres_reply_num and all(last_max_reply_num[-sec_thres_reply_num:]):
|
||||
# # 如果最近sec_thres_reply_num次都是reply,40%概率移除
|
||||
# removal_probability = 0.4 / global_config.focus_chat.consecutive_replies
|
||||
# if random.random() < removal_probability:
|
||||
# reason = (
|
||||
# f"连续回复较多(最近{sec_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
# )
|
||||
# removals.append(("reply", reason))
|
||||
# elif len(last_max_reply_num) >= one_thres_reply_num and all(last_max_reply_num[-one_thres_reply_num:]):
|
||||
# # 如果最近one_thres_reply_num次都是reply,20%概率移除
|
||||
# removal_probability = 0.2 / global_config.focus_chat.consecutive_replies
|
||||
# if random.random() < removal_probability:
|
||||
# reason = (
|
||||
# f"连续回复检测(最近{one_thres_reply_num}次全是reply,{removal_probability:.2f}概率移除,触发移除)"
|
||||
# )
|
||||
# removals.append(("reply", reason))
|
||||
# else:
|
||||
# logger.debug(f"{self.log_prefix}连续回复检测:无需移除reply动作,最近回复模式正常")
|
||||
|
||||
return removals
|
||||
# return removals
|
||||
|
||||
# def get_available_actions_count(self, mode: str = "focus") -> int:
|
||||
# """获取当前可用动作数量(排除默认的no_action)"""
|
||||
|
||||
@@ -10,6 +10,7 @@ from datetime import datetime
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.individuality.individuality import get_individuality
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.chat.message_receive.message import UserInfo, Seg, MessageRecv, MessageSending
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
@@ -561,7 +562,6 @@ class DefaultReplyer:
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
extra_info_block = reply_data.get("extra_info", "") or reply_data.get("extra_info_block", "")
|
||||
@@ -661,31 +661,7 @@ class DefaultReplyer:
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
# logger.debug("开始构建 focus prompt")
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||
# 解析字符串形式的Python列表
|
||||
try:
|
||||
if isinstance(short_impression, str) and short_impression.strip():
|
||||
short_impression = ast.literal_eval(short_impression)
|
||||
elif not short_impression:
|
||||
logger.warning("short_impression为空,使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
# 确保short_impression是列表格式且有足够的元素
|
||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
personality = short_impression[0]
|
||||
identity = short_impression[1]
|
||||
prompt_personality = f"{personality},{identity}"
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
identity_block = get_individuality().get_personality_block()
|
||||
|
||||
moderation_prompt_block = (
|
||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||
@@ -732,24 +708,24 @@ class DefaultReplyer:
|
||||
"chat_target_private2", sender_name=chat_target_name
|
||||
)
|
||||
|
||||
target_user_id = ""
|
||||
if sender:
|
||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
|
||||
|
||||
|
||||
# 根据配置选择使用哪种 prompt 构建模式
|
||||
if global_config.chat.use_s4u_prompt_mode:
|
||||
if global_config.chat.use_s4u_prompt_mode and person_id:
|
||||
# 使用 s4u 对话构建模式:分离当前对话对象和其他对话
|
||||
|
||||
# 获取目标用户ID用于消息过滤
|
||||
target_user_id = ""
|
||||
if sender:
|
||||
# 根据sender通过person_info_manager反向查找person_id,再获取user_id
|
||||
person_id = person_info_manager.get_person_id_by_person_name(sender)
|
||||
if person_id:
|
||||
# 通过person_info_manager获取person_id对应的user_id字段
|
||||
try:
|
||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||
if user_id_value:
|
||||
target_user_id = str(user_id_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
try:
|
||||
user_id_value = await person_info_manager.get_value(person_id, "user_id")
|
||||
if user_id_value:
|
||||
target_user_id = str(user_id_value)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法从person_id {person_id} 获取user_id: {e}")
|
||||
target_user_id = ""
|
||||
|
||||
|
||||
# 构建分离的对话 prompt
|
||||
core_dialogue_prompt, background_dialogue_prompt = self.build_s4u_chat_history_prompts(
|
||||
@@ -811,8 +787,6 @@ class DefaultReplyer:
|
||||
) -> str:
|
||||
chat_stream = self.chat_stream
|
||||
chat_id = chat_stream.stream_id
|
||||
person_info_manager = get_person_info_manager()
|
||||
bot_person_id = person_info_manager.get_person_id("system", "bot_id")
|
||||
is_group_chat = bool(chat_stream.group_info)
|
||||
|
||||
reply_to = reply_data.get("reply_to", "none")
|
||||
@@ -844,29 +818,7 @@ class DefaultReplyer:
|
||||
|
||||
time_block = f"当前时间:{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}"
|
||||
|
||||
bot_name = global_config.bot.nickname
|
||||
if global_config.bot.alias_names:
|
||||
bot_nickname = f",也有人叫你{','.join(global_config.bot.alias_names)}"
|
||||
else:
|
||||
bot_nickname = ""
|
||||
short_impression = await person_info_manager.get_value(bot_person_id, "short_impression")
|
||||
try:
|
||||
if isinstance(short_impression, str) and short_impression.strip():
|
||||
short_impression = ast.literal_eval(short_impression)
|
||||
elif not short_impression:
|
||||
logger.warning("short_impression为空,使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
except (ValueError, SyntaxError) as e:
|
||||
logger.error(f"解析short_impression失败: {e}, 原始值: {short_impression}")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
# 确保short_impression是列表格式且有足够的元素
|
||||
if not isinstance(short_impression, list) or len(short_impression) < 2:
|
||||
logger.warning(f"short_impression格式不正确: {short_impression}, 使用默认值")
|
||||
short_impression = ["友好活泼", "人类"]
|
||||
personality = short_impression[0]
|
||||
identity = short_impression[1]
|
||||
prompt_personality = f"{personality},{identity}"
|
||||
identity_block = f"你的名字是{bot_name}{bot_nickname},你{prompt_personality}:"
|
||||
identity_block = get_individuality().get_personality_block()
|
||||
|
||||
moderation_prompt_block = (
|
||||
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"
|
||||
|
||||
@@ -1,84 +0,0 @@
|
||||
from typing import Tuple
|
||||
import time
|
||||
import random
|
||||
import string
|
||||
|
||||
|
||||
class MemoryItem:
|
||||
"""记忆项类,用于存储单个记忆的所有相关信息"""
|
||||
|
||||
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
初始化记忆项
|
||||
|
||||
Args:
|
||||
summary: 记忆内容概括
|
||||
from_source: 数据来源
|
||||
brief: 记忆内容主题
|
||||
"""
|
||||
# 生成可读ID:时间戳_随机字符串
|
||||
timestamp = int(time.time())
|
||||
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||
self.id = f"{timestamp}_{random_str}"
|
||||
self.from_source = from_source
|
||||
self.brief = brief
|
||||
self.timestamp = time.time()
|
||||
|
||||
# 记忆内容概括
|
||||
self.summary = summary
|
||||
|
||||
# 记忆精简次数
|
||||
self.compress_count = 0
|
||||
|
||||
# 记忆提取次数
|
||||
self.retrieval_count = 0
|
||||
|
||||
# 记忆强度 (初始为10)
|
||||
self.memory_strength = 10.0
|
||||
|
||||
# 记忆操作历史记录
|
||||
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
|
||||
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
|
||||
|
||||
def matches_source(self, source: str) -> bool:
|
||||
"""检查来源是否匹配"""
|
||||
return self.from_source == source
|
||||
|
||||
def increase_strength(self, amount: float) -> None:
|
||||
"""增加记忆强度"""
|
||||
self.memory_strength = min(10.0, self.memory_strength + amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("strengthen")
|
||||
|
||||
def decrease_strength(self, amount: float) -> None:
|
||||
"""减少记忆强度"""
|
||||
self.memory_strength = max(0.1, self.memory_strength - amount)
|
||||
# 记录操作历史
|
||||
self.record_operation("weaken")
|
||||
|
||||
def increase_compress_count(self) -> None:
|
||||
"""增加精简次数并减弱记忆强度"""
|
||||
self.compress_count += 1
|
||||
# 记录操作历史
|
||||
self.record_operation("compress")
|
||||
|
||||
def record_retrieval(self) -> None:
|
||||
"""记录记忆被提取的情况"""
|
||||
self.retrieval_count += 1
|
||||
# 提取后强度翻倍
|
||||
self.memory_strength = min(10.0, self.memory_strength * 2)
|
||||
# 记录操作历史
|
||||
self.record_operation("retrieval")
|
||||
|
||||
def record_operation(self, operation_type: str) -> None:
|
||||
"""记录操作历史"""
|
||||
current_time = time.time()
|
||||
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
|
||||
|
||||
def to_tuple(self) -> Tuple[str, str, float, str]:
|
||||
"""转换为元组格式(为了兼容性)"""
|
||||
return (self.summary, self.from_source, self.timestamp, self.id)
|
||||
|
||||
def is_memory_valid(self) -> bool:
|
||||
"""检查记忆是否有效(强度是否大于等于1)"""
|
||||
return self.memory_strength >= 1.0
|
||||
@@ -1,413 +0,0 @@
|
||||
from typing import Dict, TypeVar, List, Optional
|
||||
import traceback
|
||||
from json_repair import repair_json
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
import json # 添加json模块导入
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
logger = get_logger("working_memory")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class MemoryManager:
|
||||
def __init__(self, chat_id: str):
|
||||
"""
|
||||
初始化工作记忆
|
||||
|
||||
Args:
|
||||
chat_id: 关联的聊天ID,用于标识该工作记忆属于哪个聊天
|
||||
"""
|
||||
# 关联的聊天ID
|
||||
self._chat_id = chat_id
|
||||
|
||||
# 记忆项列表
|
||||
self._memories: List[MemoryItem] = []
|
||||
|
||||
# ID到记忆项的映射
|
||||
self._id_map: Dict[str, MemoryItem] = {}
|
||||
|
||||
self.llm_summarizer = LLMRequest(
|
||||
model=global_config.model.memory,
|
||||
temperature=0.3,
|
||||
request_type="working_memory",
|
||||
)
|
||||
|
||||
@property
|
||||
def chat_id(self) -> str:
|
||||
"""获取关联的聊天ID"""
|
||||
return self._chat_id
|
||||
|
||||
@chat_id.setter
|
||||
def chat_id(self, value: str):
|
||||
"""设置关联的聊天ID"""
|
||||
self._chat_id = value
|
||||
|
||||
def push_item(self, memory_item: MemoryItem) -> str:
|
||||
"""
|
||||
推送一个已创建的记忆项到工作记忆中
|
||||
|
||||
Args:
|
||||
memory_item: 要存储的记忆项
|
||||
|
||||
Returns:
|
||||
记忆项的ID
|
||||
"""
|
||||
# 添加到内存和ID映射
|
||||
self._memories.append(memory_item)
|
||||
self._id_map[memory_item.id] = memory_item
|
||||
|
||||
return memory_item.id
|
||||
|
||||
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
通过ID获取记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 记忆项ID
|
||||
|
||||
Returns:
|
||||
找到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self._id_map.get(memory_id)
|
||||
if memory_item:
|
||||
# 检查记忆强度,如果小于1则删除
|
||||
if not memory_item.is_memory_valid():
|
||||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||||
self.delete(memory_id)
|
||||
return None
|
||||
|
||||
return memory_item
|
||||
|
||||
def get_all_items(self) -> List[MemoryItem]:
|
||||
"""获取所有记忆项"""
|
||||
return list(self._id_map.values())
|
||||
|
||||
def find_items(
|
||||
self,
|
||||
source: Optional[str] = None,
|
||||
start_time: Optional[float] = None,
|
||||
end_time: Optional[float] = None,
|
||||
memory_id: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
newest_first: bool = False,
|
||||
min_strength: float = 0.0,
|
||||
) -> List[MemoryItem]:
|
||||
"""
|
||||
按条件查找记忆项
|
||||
|
||||
Args:
|
||||
source: 数据来源
|
||||
start_time: 开始时间戳
|
||||
end_time: 结束时间戳
|
||||
memory_id: 特定记忆项ID
|
||||
limit: 返回结果的最大数量
|
||||
newest_first: 是否按最新优先排序
|
||||
min_strength: 最小记忆强度
|
||||
|
||||
Returns:
|
||||
符合条件的记忆项列表
|
||||
"""
|
||||
# 如果提供了特定ID,直接查找
|
||||
if memory_id:
|
||||
item = self.get_by_id(memory_id)
|
||||
return [item] if item else []
|
||||
|
||||
results = []
|
||||
|
||||
# 获取所有项目
|
||||
items = self._memories
|
||||
|
||||
# 如果需要最新优先,则反转遍历顺序
|
||||
if newest_first:
|
||||
items_to_check = list(reversed(items))
|
||||
else:
|
||||
items_to_check = items
|
||||
|
||||
# 遍历项目
|
||||
for item in items_to_check:
|
||||
# 检查来源是否匹配
|
||||
if source is not None and not item.matches_source(source):
|
||||
continue
|
||||
|
||||
# 检查时间范围
|
||||
if start_time is not None and item.timestamp < start_time:
|
||||
continue
|
||||
if end_time is not None and item.timestamp > end_time:
|
||||
continue
|
||||
|
||||
# 检查记忆强度
|
||||
if min_strength > 0 and item.memory_strength < min_strength:
|
||||
continue
|
||||
|
||||
# 所有条件都满足,添加到结果中
|
||||
results.append(item)
|
||||
|
||||
# 如果达到限制数量,提前返回
|
||||
if limit is not None and len(results) >= limit:
|
||||
return results
|
||||
|
||||
return results
|
||||
|
||||
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
|
||||
"""
|
||||
使用LLM总结记忆项
|
||||
|
||||
Args:
|
||||
content: 需要总结的内容
|
||||
|
||||
Returns:
|
||||
包含brief和summary的字典
|
||||
"""
|
||||
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
|
||||
1. 记忆内容主题(精简,20字以内):让用户可以一眼看出记忆内容是什么
|
||||
2. 记忆内容概括:对内容进行概括,保留重要信息,200字以内
|
||||
|
||||
内容:
|
||||
{content}
|
||||
|
||||
请按以下JSON格式输出:
|
||||
{{
|
||||
"brief": "记忆内容主题",
|
||||
"summary": "记忆内容概括"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
default_summary = {
|
||||
"brief": "主题未知的记忆",
|
||||
"summary": "无法概括的记忆内容",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM生成总结
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 使用repair_json解析响应
|
||||
try:
|
||||
# 使用repair_json修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 如果repair_json返回的是字符串,需要解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
json_result = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
return default_summary
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
json_result = fixed_json_string
|
||||
|
||||
# 进行额外的类型检查
|
||||
if not isinstance(json_result, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
|
||||
return default_summary
|
||||
|
||||
# 确保所有必要字段都存在且类型正确
|
||||
if "brief" not in json_result or not isinstance(json_result["brief"], str):
|
||||
json_result["brief"] = "主题未知的记忆"
|
||||
|
||||
if "summary" not in json_result or not isinstance(json_result["summary"], str):
|
||||
json_result["summary"] = "无法概括的记忆内容"
|
||||
|
||||
return json_result
|
||||
|
||||
except Exception as json_error:
|
||||
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
|
||||
return default_summary
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成总结时出错: {str(e)}")
|
||||
return default_summary
|
||||
|
||||
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
|
||||
"""
|
||||
使单个记忆衰减
|
||||
|
||||
Args:
|
||||
memory_id: 记忆ID
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
|
||||
Returns:
|
||||
是否成功衰减
|
||||
"""
|
||||
memory_item = self.get_by_id(memory_id)
|
||||
if not memory_item:
|
||||
return False
|
||||
|
||||
# 计算衰减量(当前强度 * (1-衰减因子))
|
||||
old_strength = memory_item.memory_strength
|
||||
decay_amount = old_strength * (1 - decay_factor)
|
||||
|
||||
# 更新强度
|
||||
memory_item.memory_strength = decay_amount
|
||||
|
||||
return True
|
||||
|
||||
def delete(self, memory_id: str) -> bool:
|
||||
"""
|
||||
删除指定ID的记忆项
|
||||
|
||||
Args:
|
||||
memory_id: 要删除的记忆项ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
if memory_id not in self._id_map:
|
||||
return False
|
||||
|
||||
# 获取要删除的项
|
||||
self._id_map[memory_id]
|
||||
|
||||
# 从内存中删除
|
||||
self._memories = [i for i in self._memories if i.id != memory_id]
|
||||
|
||||
# 从ID映射中删除
|
||||
del self._id_map[memory_id]
|
||||
|
||||
return True
|
||||
|
||||
def clear(self) -> None:
|
||||
"""清除所有记忆"""
|
||||
self._memories.clear()
|
||||
self._id_map.clear()
|
||||
|
||||
async def merge_memories(
|
||||
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||
) -> MemoryItem:
|
||||
"""
|
||||
合并两个记忆项
|
||||
|
||||
Args:
|
||||
memory_id1: 第一个记忆项ID
|
||||
memory_id2: 第二个记忆项ID
|
||||
reason: 合并原因
|
||||
delete_originals: 是否删除原始记忆,默认为True
|
||||
|
||||
Returns:
|
||||
合并后的记忆项
|
||||
"""
|
||||
# 获取两个记忆项
|
||||
memory_item1 = self.get_by_id(memory_id1)
|
||||
memory_item2 = self.get_by_id(memory_id2)
|
||||
|
||||
if not memory_item1 or not memory_item2:
|
||||
raise ValueError("无法找到指定的记忆项")
|
||||
|
||||
# 构建合并提示
|
||||
prompt = f"""
|
||||
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
|
||||
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
|
||||
|
||||
合并原因:{reason}
|
||||
|
||||
记忆1主题:{memory_item1.brief}
|
||||
记忆1内容:{memory_item1.summary}
|
||||
|
||||
记忆2主题:{memory_item2.brief}
|
||||
记忆2内容:{memory_item2.summary}
|
||||
|
||||
请按以下JSON格式输出合并结果:
|
||||
{{
|
||||
"brief": "合并后的主题(20字以内)",
|
||||
"summary": "合并后的内容概括(200字以内)"
|
||||
}}
|
||||
请确保输出是有效的JSON格式,不要添加任何额外的说明或解释。
|
||||
"""
|
||||
|
||||
# 默认合并结果
|
||||
default_merged = {
|
||||
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
|
||||
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
|
||||
}
|
||||
|
||||
try:
|
||||
# 调用LLM合并记忆
|
||||
response, _ = await self.llm_summarizer.generate_response_async(prompt)
|
||||
|
||||
# 处理LLM返回的合并结果
|
||||
try:
|
||||
# 修复JSON格式
|
||||
fixed_json_string = repair_json(response)
|
||||
|
||||
# 将修复后的字符串解析为Python对象
|
||||
if isinstance(fixed_json_string, str):
|
||||
try:
|
||||
merged_data = json.loads(fixed_json_string)
|
||||
except json.JSONDecodeError as decode_error:
|
||||
logger.error(f"JSON解析错误: {str(decode_error)}")
|
||||
merged_data = default_merged
|
||||
else:
|
||||
# 如果repair_json直接返回了字典对象,直接使用
|
||||
merged_data = fixed_json_string
|
||||
|
||||
# 确保是字典类型
|
||||
if not isinstance(merged_data, dict):
|
||||
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
|
||||
merged_data = default_merged
|
||||
|
||||
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
|
||||
merged_data["brief"] = default_merged["brief"]
|
||||
|
||||
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
|
||||
merged_data["summary"] = default_merged["summary"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
except Exception as e:
|
||||
logger.error(f"合并记忆调用LLM出错: {str(e)}")
|
||||
traceback.print_exc()
|
||||
merged_data = default_merged
|
||||
|
||||
# 创建新的记忆项
|
||||
# 取两个记忆项中更强的来源
|
||||
merged_source = (
|
||||
memory_item1.from_source
|
||||
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||||
else memory_item2.from_source
|
||||
)
|
||||
|
||||
# 创建新的记忆项
|
||||
merged_memory = MemoryItem(
|
||||
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
|
||||
)
|
||||
|
||||
# 记忆强度取两者最大值
|
||||
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
|
||||
|
||||
# 添加到存储中
|
||||
self.push_item(merged_memory)
|
||||
|
||||
# 如果需要,删除原始记忆
|
||||
if delete_originals:
|
||||
self.delete(memory_id1)
|
||||
self.delete(memory_id2)
|
||||
|
||||
return merged_memory
|
||||
|
||||
def delete_earliest_memory(self) -> bool:
|
||||
"""
|
||||
删除最早的记忆项
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
# 获取所有记忆项
|
||||
all_memories = self.get_all_items()
|
||||
|
||||
if not all_memories:
|
||||
return False
|
||||
|
||||
# 按时间戳排序,找到最早的记忆项
|
||||
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
|
||||
|
||||
# 删除最早的记忆项
|
||||
return self.delete(earliest_memory.id)
|
||||
@@ -1,156 +0,0 @@
|
||||
from typing import List, Any, Optional
|
||||
import asyncio
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||
from src.config.config import global_config
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
||||
|
||||
|
||||
class WorkingMemory:
|
||||
"""
|
||||
工作记忆,负责协调和运作记忆
|
||||
从属于特定的流,用chat_id来标识
|
||||
"""
|
||||
|
||||
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||
"""
|
||||
初始化工作记忆管理器
|
||||
|
||||
Args:
|
||||
max_memories_per_chat: 每个聊天的最大记忆数量
|
||||
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
|
||||
"""
|
||||
self.memory_manager = MemoryManager(chat_id)
|
||||
|
||||
# 记忆容量上限
|
||||
self.max_memories_per_chat = max_memories_per_chat
|
||||
|
||||
# 自动衰减间隔
|
||||
self.auto_decay_interval = auto_decay_interval
|
||||
|
||||
# 衰减任务
|
||||
self.decay_task = None
|
||||
|
||||
# 只有在工作记忆处理器启用时才启动自动衰减任务
|
||||
if global_config.focus_chat_processor.working_memory_processor:
|
||||
self._start_auto_decay()
|
||||
else:
|
||||
logger.debug(f"工作记忆处理器已禁用,跳过启动自动衰减任务 (chat_id: {chat_id})")
|
||||
|
||||
def _start_auto_decay(self):
|
||||
"""启动自动衰减任务"""
|
||||
if self.decay_task is None:
|
||||
self.decay_task = asyncio.create_task(self._auto_decay_loop())
|
||||
|
||||
async def _auto_decay_loop(self):
|
||||
"""自动衰减循环"""
|
||||
while True:
|
||||
await asyncio.sleep(self.auto_decay_interval)
|
||||
try:
|
||||
await self.decay_all_memories()
|
||||
except Exception as e:
|
||||
print(f"自动衰减记忆时出错: {str(e)}")
|
||||
|
||||
async def add_memory(self, summary: Any, from_source: str = "", brief: str = ""):
|
||||
"""
|
||||
添加一段记忆到指定聊天
|
||||
|
||||
Args:
|
||||
summary: 记忆内容
|
||||
from_source: 数据来源
|
||||
|
||||
Returns:
|
||||
记忆项
|
||||
"""
|
||||
# 如果是字符串类型,生成总结
|
||||
|
||||
memory = MemoryItem(summary, from_source, brief)
|
||||
|
||||
# 添加到管理器
|
||||
self.memory_manager.push_item(memory)
|
||||
|
||||
# 如果超过最大记忆数量,删除最早的记忆
|
||||
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
|
||||
self.remove_earliest_memory()
|
||||
|
||||
return memory
|
||||
|
||||
def remove_earliest_memory(self):
|
||||
"""
|
||||
删除最早的记忆
|
||||
"""
|
||||
return self.memory_manager.delete_earliest_memory()
|
||||
|
||||
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
|
||||
"""
|
||||
检索记忆
|
||||
|
||||
Args:
|
||||
chat_id: 聊天ID
|
||||
memory_id: 记忆ID
|
||||
|
||||
Returns:
|
||||
检索到的记忆项,如果不存在则返回None
|
||||
"""
|
||||
memory_item = self.memory_manager.get_by_id(memory_id)
|
||||
if memory_item:
|
||||
memory_item.retrieval_count += 1
|
||||
memory_item.increase_strength(5)
|
||||
return memory_item
|
||||
return None
|
||||
|
||||
async def decay_all_memories(self, decay_factor: float = 0.5):
|
||||
"""
|
||||
对所有聊天的所有记忆进行衰减
|
||||
衰减:对记忆进行refine压缩,强度会变为原先的0.5
|
||||
|
||||
Args:
|
||||
decay_factor: 衰减因子(0-1之间)
|
||||
"""
|
||||
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
|
||||
|
||||
all_memories = self.memory_manager.get_all_items()
|
||||
|
||||
for memory_item in all_memories:
|
||||
# 如果压缩完小于1会被删除
|
||||
memory_id = memory_item.id
|
||||
self.memory_manager.decay_memory(memory_id, decay_factor)
|
||||
if memory_item.memory_strength < 1:
|
||||
self.memory_manager.delete(memory_id)
|
||||
continue
|
||||
# 计算衰减量
|
||||
# if memory_item.memory_strength < 5:
|
||||
# await self.memory_manager.refine_memory(
|
||||
# memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
|
||||
# )
|
||||
|
||||
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
||||
"""合并记忆
|
||||
|
||||
Args:
|
||||
memory_str: 记忆内容
|
||||
"""
|
||||
return await self.memory_manager.merge_memories(
|
||||
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
|
||||
)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""关闭管理器,停止所有任务"""
|
||||
if self.decay_task and not self.decay_task.done():
|
||||
self.decay_task.cancel()
|
||||
try:
|
||||
await self.decay_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
def get_all_memories(self) -> List[MemoryItem]:
|
||||
"""
|
||||
获取所有记忆项目
|
||||
|
||||
Returns:
|
||||
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
|
||||
"""
|
||||
return self.memory_manager.get_all_items()
|
||||
@@ -1,261 +0,0 @@
|
||||
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
|
||||
from src.chat.focus_chat.observation.observation import Observation
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.config.config import global_config
|
||||
import time
|
||||
import traceback
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from typing import List
|
||||
from src.chat.focus_chat.observation.working_observation import WorkingMemoryObservation
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.info.info_base import InfoBase
|
||||
from json_repair import repair_json
|
||||
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
logger = get_logger("processor")
|
||||
|
||||
|
||||
def init_prompt():
|
||||
memory_proces_prompt = """
|
||||
你的名字是{bot_name}
|
||||
|
||||
现在是{time_now},你正在上网,和qq群里的网友们聊天,以下是正在进行的聊天内容:
|
||||
{chat_observe_info}
|
||||
|
||||
以下是你已经总结的记忆摘要,你可以调取这些记忆查看内容来帮助你聊天,不要一次调取太多记忆,最多调取3个左右记忆:
|
||||
{memory_str}
|
||||
|
||||
观察聊天内容和已经总结的记忆,思考如果有相近的记忆,请合并记忆,输出merge_memory,
|
||||
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...],你可以进行多组合并,但是每组合并只能有两个记忆id,不要输出其他内容
|
||||
|
||||
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆,以JSON格式输出,格式如下:
|
||||
```json
|
||||
{{
|
||||
"selected_memory_ids": ["id1", "id2", ...]
|
||||
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
|
||||
}}
|
||||
```
|
||||
"""
|
||||
Prompt(memory_proces_prompt, "prompt_memory_proces")
|
||||
|
||||
|
||||
class WorkingMemoryProcessor:
|
||||
log_prefix = "工作记忆"
|
||||
|
||||
def __init__(self, subheartflow_id: str):
|
||||
self.subheartflow_id = subheartflow_id
|
||||
|
||||
self.llm_model = LLMRequest(
|
||||
model=global_config.model.planner,
|
||||
request_type="focus.processor.working_memory",
|
||||
)
|
||||
|
||||
name = get_chat_manager().get_stream_name(self.subheartflow_id)
|
||||
self.log_prefix = f"[{name}] "
|
||||
|
||||
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
|
||||
"""处理信息对象
|
||||
|
||||
Args:
|
||||
*infos: 可变数量的InfoBase类型的信息对象
|
||||
|
||||
Returns:
|
||||
List[InfoBase]: 处理后的结构化信息列表
|
||||
"""
|
||||
working_memory = None
|
||||
chat_info = ""
|
||||
chat_obs = None
|
||||
try:
|
||||
for observation in observations:
|
||||
if isinstance(observation, WorkingMemoryObservation):
|
||||
working_memory = observation.get_observe_info()
|
||||
if isinstance(observation, ChattingObservation):
|
||||
chat_info = observation.get_observe_info()
|
||||
chat_obs = observation
|
||||
# 检查是否有待压缩内容
|
||||
if chat_obs and chat_obs.compressor_prompt:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
|
||||
await self.compress_chat_memory(working_memory, chat_obs)
|
||||
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆观察,跳过处理")
|
||||
return []
|
||||
|
||||
all_memory = working_memory.get_all_memories()
|
||||
if not all_memory:
|
||||
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
|
||||
return []
|
||||
|
||||
memory_prompts = []
|
||||
for memory in all_memory:
|
||||
memory_id = memory.id
|
||||
memory_brief = memory.brief
|
||||
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
||||
memory_prompts.append(memory_single_prompt)
|
||||
|
||||
memory_choose_str = "".join(memory_prompts)
|
||||
|
||||
# 使用提示模板进行处理
|
||||
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
|
||||
bot_name=global_config.bot.nickname,
|
||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||
chat_observe_info=chat_info,
|
||||
memory_str=memory_choose_str,
|
||||
)
|
||||
|
||||
# 调用LLM处理记忆
|
||||
content = ""
|
||||
try:
|
||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||
|
||||
# print(f"prompt: {prompt}---------------------------------")
|
||||
# print(f"content: {content}---------------------------------")
|
||||
|
||||
if not content:
|
||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||
return []
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
# 解析LLM返回的JSON
|
||||
try:
|
||||
result = repair_json(content)
|
||||
if isinstance(result, str):
|
||||
result = json.loads(result)
|
||||
if not isinstance(result, dict):
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败,结果不是字典类型: {type(result)}")
|
||||
return []
|
||||
|
||||
selected_memory_ids = result.get("selected_memory_ids", [])
|
||||
merge_memory = result.get("merge_memory", [])
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
logger.debug(
|
||||
f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}"
|
||||
)
|
||||
|
||||
# 根据selected_memory_ids,调取记忆
|
||||
memory_str = ""
|
||||
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
|
||||
|
||||
# 遍历所有记忆
|
||||
for memory in all_memory:
|
||||
if memory.id in selected_ids:
|
||||
# 选中的记忆显示详细内容
|
||||
memory = await working_memory.retrieve_memory(memory.id)
|
||||
if memory:
|
||||
memory_str += f"{memory.summary}\n"
|
||||
else:
|
||||
# 未选中的记忆显示梗概
|
||||
memory_str += f"{memory.brief}\n"
|
||||
|
||||
working_memory_info = WorkingMemoryInfo()
|
||||
if memory_str:
|
||||
working_memory_info.add_working_memory(memory_str)
|
||||
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
|
||||
else:
|
||||
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
|
||||
|
||||
if merge_memory:
|
||||
for merge_pairs in merge_memory:
|
||||
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
|
||||
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
|
||||
if memory1 and memory2:
|
||||
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
|
||||
|
||||
return [working_memory_info]
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return []
|
||||
|
||||
async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
|
||||
"""压缩聊天记忆
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
obs: 聊天观察对象
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法压缩聊天记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
|
||||
if not summary_result:
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
|
||||
return
|
||||
|
||||
print(f"compressor_prompt: {obs.compressor_prompt}")
|
||||
print(f"summary_result: {summary_result}")
|
||||
|
||||
# 修复并解析JSON
|
||||
try:
|
||||
fixed_json = repair_json(summary_result)
|
||||
summary_data = json.loads(fixed_json)
|
||||
|
||||
if not isinstance(summary_data, dict):
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
|
||||
return
|
||||
|
||||
theme = summary_data.get("theme", "")
|
||||
content = summary_data.get("content", "")
|
||||
|
||||
if not theme or not content:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
|
||||
return
|
||||
|
||||
# 创建新记忆
|
||||
await working_memory.add_memory(from_source="chat_compress", summary=content, brief=theme)
|
||||
|
||||
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
return
|
||||
|
||||
# 清理压缩状态
|
||||
obs.compressor_prompt = ""
|
||||
obs.oldest_messages = []
|
||||
obs.oldest_messages_str = ""
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
|
||||
"""异步合并记忆,不阻塞主流程
|
||||
|
||||
Args:
|
||||
working_memory: 工作记忆对象
|
||||
memory_id1: 第一个记忆ID
|
||||
memory_id2: 第二个记忆ID
|
||||
"""
|
||||
# 检查working_memory是否为None
|
||||
if working_memory is None:
|
||||
logger.warning(f"{self.log_prefix} 工作记忆对象为None,无法合并记忆")
|
||||
return
|
||||
|
||||
try:
|
||||
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
|
||||
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
|
||||
init_prompt()
|
||||
Reference in New Issue
Block a user