remove:冗余的sbhf代码和focus代码

This commit is contained in:
SengokuCola
2025-07-06 20:14:09 +08:00
parent dc24a76413
commit 1365099fd4
44 changed files with 132 additions and 2210 deletions

View File

@@ -6,20 +6,16 @@ from src.chat.focus_chat.hfc_utils import CycleDetail
from typing import List
# Import the new utility function
logger = get_logger("observation")
logger = get_logger("loop_info")
# 所有观察的基类
class HFCloopObservation:
class FocusLoopInfo:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
def get_observe_info(self):
return self.observe_info
def add_loop_info(self, loop_info: CycleDetail):
self.history_loop.append(loop_info)
@@ -50,11 +46,6 @@ class HFCloopObservation:
action_taken_time_str = (
datetime.fromtimestamp(action_taken_time).strftime("%H:%M:%S") if action_taken_time > 0 else "未知时间"
)
# print(action_type)
# print(action_reasoning)
# print(is_taken)
# print(action_taken_time_str)
# print("--------------------------------")
if action_reasoning != cycle_last_reason:
cycle_last_reason = action_reasoning
action_reasoning_str = f"你选择这个action的原因是:{action_reasoning}"
@@ -71,9 +62,6 @@ class HFCloopObservation:
else:
action_detailed_str += f"{action_taken_time_str}时,你选择回复(action:{action_type},内容是:'{response_text}'),但是动作失败了。{action_reasoning_str}\n"
elif action_type == "no_reply":
# action_detailed_str += (
# f"{action_taken_time_str}时,你选择不回复(action:{action_type}){action_reasoning_str}\n"
# )
pass
else:
if is_taken:
@@ -88,17 +76,6 @@ class HFCloopObservation:
else:
cycle_info_block = "\n"
# 根据连续文本回复的数量构建提示信息
if consecutive_text_replies >= 3: # 如果最近的三个活动都是文本回复
cycle_info_block = f'你已经连续回复了三条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}",第三近: "{responses_for_prompt[2]}")。你回复的有点多了,请注意'
elif consecutive_text_replies == 2: # 如果最近的两个活动是文本回复
cycle_info_block = f'你已经连续回复了两条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}"),请注意'
# 包装提示块,增加可读性,即使没有连续回复也给个标记
# if cycle_info_block:
# cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
# else:
# cycle_info_block = "\n"
# 获取history_loop中最新添加的
if self.history_loop:
@@ -112,17 +89,4 @@ class HFCloopObservation:
else:
cycle_info_block += f"距离你上一次阅读消息并思考和规划,已经过去了{time_diff}\n"
else:
cycle_info_block += "你还没看过消息\n"
self.observe_info = cycle_info_block
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
# 只序列化基本信息,避免循环引用
return {
"observe_info": self.observe_info,
"observe_id": self.observe_id,
"last_observe_time": self.last_observe_time,
# 不序列化history_loop避免循环引用
"history_loop_count": len(self.history_loop),
}
cycle_info_block += "你还没看过消息\n"

View File

@@ -9,15 +9,7 @@ from rich.traceback import install
from src.chat.utils.prompt_builder import global_prompt_manager
from src.common.logger import get_logger
from src.chat.utils.timer_calculator import Timer
from src.chat.focus_chat.observation.observation import Observation
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.info_processors.chattinginfo_processor import ChattingInfoProcessor
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
from src.chat.focus_chat.observation.hfcloop_observation import HFCloopObservation
from src.chat.focus_chat.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.observation.actions_observation import ActionObservation
from src.chat.focus_chat.info_processors.base_processor import BaseProcessor
from src.chat.focus_chat.focus_loop_info import FocusLoopInfo
from src.chat.planner_actions.planner_focus import ActionPlanner
from src.chat.planner_actions.action_modifier import ActionModifier
from src.chat.planner_actions.action_manager import ActionManager
@@ -32,23 +24,8 @@ install(extra_lines=3)
# 注释:原来的动作修改超时常量已移除,因为改为顺序执行
# 定义观察器映射:键是观察器名称,值是 (观察器类, 初始化参数)
OBSERVATION_CLASSES = {
"ChattingObservation": (ChattingObservation, "chat_id"),
"WorkingMemoryObservation": (WorkingMemoryObservation, "observe_id"),
"HFCloopObservation": (HFCloopObservation, "observe_id"),
}
# 定义处理器映射:键是处理器名称,值是 (处理器类, 可选的配置键名)
PROCESSOR_CLASSES = {
"ChattingInfoProcessor": (ChattingInfoProcessor, None),
"WorkingMemoryProcessor": (WorkingMemoryProcessor, "working_memory_processor"),
}
logger = get_logger("hfc") # Logger Name Changed
class HeartFChatting:
"""
管理一个连续的Focus Chat循环
@@ -83,25 +60,14 @@ class HeartFChatting:
self._message_threshold = max(10, int(30 * global_config.chat.exit_focus_threshold))
self._fatigue_triggered = False # 是否已触发疲惫退出
# 初始化观察器
self.observations: List[Observation] = []
self._register_observations()
# 根据配置文件和默认规则确定启用的处理器
self.enabled_processor_names = ["ChattingInfoProcessor"]
if global_config.focus_chat.working_memory_processor:
self.enabled_processor_names.append("WorkingMemoryProcessor")
self.processors: List[BaseProcessor] = []
self._register_default_processors()
self.loop_info: FocusLoopInfo = FocusLoopInfo(observe_id=self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(
log_prefix=self.log_prefix, action_manager=self.action_manager
chat_id = self.stream_id,
action_manager=self.action_manager
)
self.action_modifier = ActionModifier(action_manager=self.action_manager, chat_id=self.stream_id)
self.action_observation = ActionObservation(observe_id=self.stream_id)
self.action_observation.set_action_manager(self.action_manager)
self._processing_lock = asyncio.Lock()
@@ -130,66 +96,8 @@ class HeartFChatting:
f"{self.log_prefix} HeartFChatting 初始化完成,消息疲惫阈值: {self._message_threshold}基于exit_focus_threshold={global_config.chat.exit_focus_threshold}计算仅在auto模式下生效"
)
def _register_observations(self):
"""注册所有观察器"""
self.observations = [] # 清空已有的
for name, (observation_class, param_name) in OBSERVATION_CLASSES.items():
try:
# 检查是否需要跳过WorkingMemoryObservation
if name == "WorkingMemoryObservation":
# 如果工作记忆处理器被禁用则跳过WorkingMemoryObservation
if not global_config.focus_chat.working_memory_processor:
logger.debug(f"{self.log_prefix} 工作记忆处理器已禁用,跳过注册观察器 {name}")
continue
# 根据参数名使用正确的参数
kwargs = {param_name: self.stream_id}
observation = observation_class(**kwargs)
self.observations.append(observation)
logger.debug(f"{self.log_prefix} 注册观察器 {name}")
except Exception as e:
logger.error(f"{self.log_prefix} 观察器 {name} 构造失败: {e}")
if self.observations:
logger.info(f"{self.log_prefix} 已注册观察器: {[o.__class__.__name__ for o in self.observations]}")
else:
logger.warning(f"{self.log_prefix} 没有注册任何观察器")
def _register_default_processors(self):
"""根据 self.enabled_processor_names 注册信息处理器"""
self.processors = [] # 清空已有的
for name in self.enabled_processor_names: # 'name' is "ChattingInfoProcessor", etc.
processor_info = PROCESSOR_CLASSES.get(name) # processor_info is (ProcessorClass, config_key)
if processor_info:
processor_actual_class = processor_info[0] # 获取实际的类定义
# 根据处理器类名判断构造参数
if name == "ChattingInfoProcessor":
self.processors.append(processor_actual_class())
elif name == "WorkingMemoryProcessor":
self.processors.append(processor_actual_class(subheartflow_id=self.stream_id))
else:
try:
self.processors.append(processor_actual_class()) # 尝试无参构造
logger.debug(f"{self.log_prefix} 注册处理器 {name} (尝试无参构造).")
except TypeError:
logger.error(
f"{self.log_prefix} 处理器 {name} 构造失败。它可能需要参数(如 subheartflow_id但未在注册逻辑中明确处理。"
)
else:
logger.warning(
f"{self.log_prefix} 在 PROCESSOR_CLASSES 中未找到名为 '{name}' 的处理器定义,将跳过注册。"
)
if self.processors:
logger.info(f"{self.log_prefix} 已注册处理器: {[p.__class__.__name__ for p in self.processors]}")
else:
logger.warning(f"{self.log_prefix} 没有注册任何处理器。这可能是由于配置错误或所有处理器都被禁用了。")
async def start(self):
"""检查是否需要启动主循环,如果未激活则启动。"""
logger.debug(f"{self.log_prefix} 开始启动 HeartFChatting")
# 如果循环已经激活,直接返回
if self._loop_active:
@@ -210,8 +118,6 @@ class HeartFChatting:
try:
# 等待旧任务确实被取消
await asyncio.wait_for(self._loop_task, timeout=5.0)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass # 忽略取消或超时错误
except Exception as e:
logger.warning(f"{self.log_prefix} 等待旧任务取消时出错: {e}")
self._loop_task = None # 清理旧任务引用
@@ -310,14 +216,11 @@ class HeartFChatting:
logger.error(f"{self.log_prefix} 处理上下文时出错: {e}")
# 为当前循环设置错误状态,防止后续重复报错
error_loop_info = {
"loop_observation_info": {},
"loop_processor_info": {},
"loop_plan_info": {
"action_result": {
"action_type": "error",
"action_data": {},
},
"observed_messages": "",
},
"loop_action_info": {
"action_taken": False,
@@ -335,14 +238,8 @@ class HeartFChatting:
self._current_cycle_detail.set_loop_info(loop_info)
# 从observations列表中获取HFCloopObservation
hfcloop_observation = next(
(obs for obs in self.observations if isinstance(obs, HFCloopObservation)), None
)
if hfcloop_observation:
hfcloop_observation.add_loop_info(self._current_cycle_detail)
else:
logger.warning(f"{self.log_prefix} 未找到HFCloopObservation实例")
self.loop_info.add_loop_info(self._current_cycle_detail)
self._current_cycle_detail.timers = cycle_timers
@@ -391,15 +288,12 @@ class HeartFChatting:
# 如果_current_cycle_detail存在但未完成为其设置错误状态
if self._current_cycle_detail and not hasattr(self._current_cycle_detail, "end_time"):
error_loop_info = {
"loop_observation_info": {},
"loop_processor_info": {},
"loop_plan_info": {
"action_result": {
"action_type": "error",
"action_data": {},
"reasoning": f"循环处理失败: {e}",
},
"observed_messages": "",
},
"loop_action_info": {
"action_taken": False,
@@ -445,65 +339,10 @@ class HeartFChatting:
if acquired and self._processing_lock.locked():
self._processing_lock.release()
async def _process_processors(self, observations: List[Observation]) -> tuple[List[InfoBase], Dict[str, float]]:
# 记录并行任务开始时间
parallel_start_time = time.time()
logger.debug(f"{self.log_prefix} 开始信息处理器并行任务")
processor_tasks = []
task_to_name_map = {}
for processor in self.processors:
processor_name = processor.__class__.log_prefix
async def run_with_timeout(proc=processor):
return await proc.process_info(observations=observations)
task = asyncio.create_task(run_with_timeout())
processor_tasks.append(task)
task_to_name_map[task] = processor_name
logger.debug(f"{self.log_prefix} 启动处理器任务: {processor_name}")
pending_tasks = set(processor_tasks)
all_plan_info: List[InfoBase] = []
while pending_tasks:
done, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_COMPLETED)
for task in done:
processor_name = task_to_name_map[task]
task_completed_time = time.time()
duration_since_parallel_start = task_completed_time - parallel_start_time
try:
result_list = await task
logger.debug(f"{self.log_prefix} 处理器 {processor_name} 已完成!")
if result_list is not None:
all_plan_info.extend(result_list)
else:
logger.warning(f"{self.log_prefix} 处理器 {processor_name} 返回了 None")
except Exception as e:
logger.error(
f"{self.log_prefix} 处理器 {processor_name} 执行失败,耗时 (自并行开始): {duration_since_parallel_start:.2f}秒. 错误: {e}",
exc_info=True,
)
traceback.print_exc()
return all_plan_info
async def _observe_process_plan_action_loop(self, cycle_timers: dict, thinking_id: str) -> dict:
try:
loop_start_time = time.time()
with Timer("观察", cycle_timers):
# 执行所有观察器的观察
for observation in self.observations:
await observation.observe()
loop_observation_info = {
"observations": self.observations,
}
await self.loop_info.observe()
await self.relationship_builder.build_relation()
@@ -513,37 +352,18 @@ class HeartFChatting:
try:
# 调用完整的动作修改流程
await self.action_modifier.modify_actions(
observations=self.observations,
loop_info = self.loop_info,
mode="focus",
)
await self.action_observation.observe()
self.observations.append(self.action_observation)
logger.debug(f"{self.log_prefix} 动作修改完成")
except Exception as e:
logger.error(f"{self.log_prefix} 动作修改失败: {e}")
# 继续执行,不中断流程
try:
all_plan_info = await self._process_processors(self.observations)
except Exception as e:
logger.error(f"{self.log_prefix} 信息处理器失败: {e}")
# 设置默认值以继续执行
all_plan_info = []
loop_processor_info = {
"all_plan_info": all_plan_info,
}
logger.debug(f"{self.log_prefix} 并行阶段完成准备进入规划器plan_info数量: {len(all_plan_info)}")
with Timer("规划器", cycle_timers):
plan_result = await self.action_planner.plan(all_plan_info, loop_start_time)
plan_result = await self.action_planner.plan()
loop_plan_info = {
"action_result": plan_result.get("action_result", {}),
"observed_messages": plan_result.get("observed_messages", ""),
}
action_type, action_data, reasoning = (
@@ -551,6 +371,8 @@ class HeartFChatting:
plan_result.get("action_result", {}).get("action_data", {}),
plan_result.get("action_result", {}).get("reasoning", "未提供理由"),
)
action_data["loop_start_time"] = loop_start_time
if action_type == "reply":
action_str = "回复"
@@ -559,7 +381,7 @@ class HeartFChatting:
else:
action_str = action_type
logger.debug(f"{self.log_prefix} 麦麦想要:'{action_str}'")
logger.debug(f"{self.log_prefix} 麦麦想要:'{action_str}',理由是:{reasoning}")
# 动作执行计时
with Timer("动作执行", cycle_timers):
@@ -575,8 +397,6 @@ class HeartFChatting:
}
loop_info = {
"loop_observation_info": loop_observation_info,
"loop_processor_info": loop_processor_info,
"loop_plan_info": loop_plan_info,
"loop_action_info": loop_action_info,
}
@@ -587,11 +407,8 @@ class HeartFChatting:
logger.error(f"{self.log_prefix} FOCUS聊天处理失败: {e}")
logger.error(traceback.format_exc())
return {
"loop_observation_info": {},
"loop_processor_info": {},
"loop_plan_info": {
"action_result": {"action_type": "error", "action_data": {}, "reasoning": f"处理失败: {e}"},
"observed_messages": "",
},
"loop_action_info": {"action_taken": False, "reply_text": "", "command": "", "taken_time": time.time()},
}
@@ -636,7 +453,7 @@ class HeartFChatting:
return False, "", ""
if not action_handler:
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}, 原因: {reasoning}")
logger.warning(f"{self.log_prefix} 未能创建动作处理器: {action}")
return False, "", ""
# 处理动作并获取结果

View File

@@ -1,125 +0,0 @@
import asyncio
from typing import Dict, Optional # 重新导入类型
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.common.logger import get_logger
from src.chat.utils.utils import calculate_typing_time
from rich.traceback import install
import traceback
install(extra_lines=3)
logger = get_logger("sender")
async def send_message(message: MessageSending) -> bool:
"""合并后的消息发送函数包含WS发送和日志记录"""
message_preview = truncate_message(message.processed_plain_text, max_length=40)
try:
# 直接调用API发送消息
await get_global_api().send_message(message)
logger.info(f"已将消息 '{message_preview}' 发往平台'{message.message_info.platform}'")
return True
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
traceback.print_exc()
raise e # 重新抛出其他异常
class HeartFCSender:
"""管理消息的注册、即时处理、发送和存储,并跟踪思考状态。"""
def __init__(self):
self.storage = MessageStorage()
# 用于存储活跃的思考消息
self.thinking_messages: Dict[str, Dict[str, MessageThinking]] = {}
self._thinking_lock = asyncio.Lock() # 保护 thinking_messages 的锁
async def register_thinking(self, thinking_message: MessageThinking):
"""注册一个思考中的消息。"""
if not thinking_message.chat_stream or not thinking_message.message_info.message_id:
logger.error("无法注册缺少 chat_stream 或 message_id 的思考消息")
return
chat_id = thinking_message.chat_stream.stream_id
message_id = thinking_message.message_info.message_id
async with self._thinking_lock:
if chat_id not in self.thinking_messages:
self.thinking_messages[chat_id] = {}
if message_id in self.thinking_messages[chat_id]:
logger.warning(f"[{chat_id}] 尝试注册已存在的思考消息 ID: {message_id}")
self.thinking_messages[chat_id][message_id] = thinking_message
logger.debug(f"[{chat_id}] Registered thinking message: {message_id}")
async def complete_thinking(self, chat_id: str, message_id: str):
"""完成并移除一个思考中的消息记录。"""
async with self._thinking_lock:
if chat_id in self.thinking_messages and message_id in self.thinking_messages[chat_id]:
del self.thinking_messages[chat_id][message_id]
logger.debug(f"[{chat_id}] Completed thinking message: {message_id}")
if not self.thinking_messages[chat_id]:
del self.thinking_messages[chat_id]
logger.debug(f"[{chat_id}] Removed empty thinking message container.")
async def get_thinking_start_time(self, chat_id: str, message_id: str) -> Optional[float]:
"""获取已注册思考消息的开始时间。"""
async with self._thinking_lock:
thinking_message = self.thinking_messages.get(chat_id, {}).get(message_id)
return thinking_message.thinking_start_time if thinking_message else None
async def send_message(self, message: MessageSending, typing=False, set_reply=False, storage_message=True):
"""
处理、发送并存储一条消息。
参数:
message: MessageSending 对象,待发送的消息。
typing: 是否模拟打字等待。
用法:
- typing=True 时,发送前会有打字等待。
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
raise Exception("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise Exception("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id
try:
if set_reply:
message.build_reply()
logger.debug(f"[{chat_id}] 选择回复引用消息: {message.processed_plain_text[:20]}...")
await message.process()
if typing:
typing_time = calculate_typing_time(
input_string=message.processed_plain_text,
thinking_start_time=message.thinking_start_time,
is_emoji=message.is_emoji,
)
await asyncio.sleep(typing_time)
sent_msg = await send_message(message)
if not sent_msg:
return False
if storage_message:
await self.storage.store_message(message, message.chat_stream)
return sent_msg
except Exception as e:
logger.error(f"[{chat_id}] 处理或存储消息 {message_id} 时出错: {e}")
raise e
finally:
await self.complete_thinking(chat_id, message_id)

View File

@@ -41,7 +41,6 @@ class HFCPerformanceLogger:
"action_type": cycle_data.get("action_type", "unknown"),
"total_time": cycle_data.get("total_time", 0),
"step_times": cycle_data.get("step_times", {}),
"processor_time_costs": cycle_data.get("processor_time_costs", {}), # 前处理器时间
"reasoning": cycle_data.get("reasoning", ""),
"success": cycle_data.get("success", False),
}

View File

@@ -5,7 +5,6 @@ from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import UserInfo
from src.common.logger import get_logger
import json
import os
from typing import Dict, Any
logger = get_logger(__name__)
@@ -24,9 +23,6 @@ class CycleDetail:
self.end_time: Optional[float] = None
self.timers: Dict[str, float] = {}
# 新字段
self.loop_observation_info: Dict[str, Any] = {}
self.loop_processor_info: Dict[str, Any] = {} # 前处理器信息
self.loop_plan_info: Dict[str, Any] = {}
self.loop_action_info: Dict[str, Any] = {}
@@ -79,8 +75,6 @@ class CycleDetail:
"end_time": self.end_time,
"timers": self.timers,
"thinking_id": self.thinking_id,
"loop_observation_info": convert_to_serializable(self.loop_observation_info),
"loop_processor_info": convert_to_serializable(self.loop_processor_info),
"loop_plan_info": convert_to_serializable(self.loop_plan_info),
"loop_action_info": convert_to_serializable(self.loop_action_info),
}
@@ -100,41 +94,12 @@ class CycleDetail:
or "group"
)
# current_time_minute = time.strftime("%Y%m%d_%H%M", time.localtime())
# try:
# self.log_cycle_to_file(
# log_dir + self.prefix + f"/{current_time_minute}_cycle_" + str(self.cycle_id) + ".json"
# )
# except Exception as e:
# logger.warning(f"写入文件日志,可能是群名称包含非法字符: {e}")
def log_cycle_to_file(self, file_path: str):
"""将循环信息写入文件"""
# 如果目录不存在,则创建目
dir_name = os.path.dirname(file_path)
# 去除特殊字符,保留字母、数字、下划线、中划线和中文
dir_name = "".join(
char for char in dir_name if char.isalnum() or char in ["_", "-", "/"] or "\u4e00" <= char <= "\u9fff"
)
# print("dir_name:", dir_name)
if dir_name and not os.path.exists(dir_name):
os.makedirs(dir_name, exist_ok=True)
# 写入文件
file_path = os.path.join(dir_name, os.path.basename(file_path))
# print("file_path:", file_path)
with open(file_path, "a", encoding="utf-8") as f:
f.write(json.dumps(self.to_dict(), ensure_ascii=False) + "\n")
def set_thinking_id(self, thinking_id: str):
"""设置思考消息ID"""
self.thinking_id = thinking_id
def set_loop_info(self, loop_info: Dict[str, Any]):
"""设置循环信息"""
self.loop_observation_info = loop_info["loop_observation_info"]
self.loop_processor_info = loop_info["loop_processor_info"]
self.loop_plan_info = loop_info["loop_plan_info"]
self.loop_action_info = loop_info["loop_action_info"]

View File

@@ -20,7 +20,7 @@ class HFCVersionManager:
"""HFC版本号管理器"""
# 默认版本号
DEFAULT_VERSION = "v5.0.0"
DEFAULT_VERSION = "v6.0.0"
# 当前运行时版本号
_current_version: Optional[str] = None

View File

@@ -1,83 +0,0 @@
from typing import Dict, Optional, Any, List
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class ActionInfo(InfoBase):
"""动作信息类
用于管理和记录动作的变更信息,包括需要添加或移除的动作。
继承自 InfoBase 类,使用字典存储具体数据。
Attributes:
type (str): 信息类型标识符,固定为 "action"
Data Fields:
add_actions (List[str]): 需要添加的动作列表
remove_actions (List[str]): 需要移除的动作列表
reason (str): 变更原因说明
"""
type: str = "action"
def get_type(self) -> str:
"""获取信息类型"""
return self.type
def get_data(self) -> Dict[str, Any]:
"""获取信息数据"""
return self.data
def set_action_changes(self, action_changes: Dict[str, List[str]]) -> None:
"""设置动作变更信息
Args:
action_changes (Dict[str, List[str]]): 包含要增加和删除的动作列表
{
"add": ["action1", "action2"],
"remove": ["action3"]
}
"""
self.data["add_actions"] = action_changes.get("add", [])
self.data["remove_actions"] = action_changes.get("remove", [])
def set_reason(self, reason: str) -> None:
"""设置变更原因
Args:
reason (str): 动作变更的原因说明
"""
self.data["reason"] = reason
def get_add_actions(self) -> List[str]:
"""获取需要添加的动作列表
Returns:
List[str]: 需要添加的动作列表
"""
return self.data.get("add_actions", [])
def get_remove_actions(self) -> List[str]:
"""获取需要移除的动作列表
Returns:
List[str]: 需要移除的动作列表
"""
return self.data.get("remove_actions", [])
def get_reason(self) -> Optional[str]:
"""获取变更原因
Returns:
Optional[str]: 动作变更的原因说明,如果未设置则返回 None
"""
return self.data.get("reason")
def has_changes(self) -> bool:
"""检查是否有动作变更
Returns:
bool: 如果有任何动作需要添加或移除则返回True
"""
return bool(self.get_add_actions() or self.get_remove_actions())

View File

@@ -1,157 +0,0 @@
from typing import Dict, Optional, Any
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class CycleInfo(InfoBase):
"""循环信息类
用于记录和管理心跳循环的相关信息包括循环ID、时间信息、动作信息等。
继承自 InfoBase 类,使用字典存储具体数据。
Attributes:
type (str): 信息类型标识符,固定为 "cycle"
Data Fields:
cycle_id (str): 当前循环的唯一标识符
start_time (str): 循环开始的时间
end_time (str): 循环结束的时间
action (str): 在循环中采取的动作
action_data (Dict[str, Any]): 动作相关的详细数据
reason (str): 触发循环的原因
observe_info (str): 当前的回复信息
"""
type: str = "cycle"
def get_type(self) -> str:
"""获取信息类型"""
return self.type
def get_data(self) -> Dict[str, str]:
"""获取信息数据"""
return self.data
def get_info(self, key: str) -> Optional[str]:
"""获取特定属性的信息
Args:
key: 要获取的属性键名
Returns:
属性值,如果键不存在则返回 None
"""
return self.data.get(key)
def set_cycle_id(self, cycle_id: str) -> None:
"""设置循环ID
Args:
cycle_id (str): 循环的唯一标识符
"""
self.data["cycle_id"] = cycle_id
def set_start_time(self, start_time: str) -> None:
"""设置开始时间
Args:
start_time (str): 循环开始的时间,建议使用标准时间格式
"""
self.data["start_time"] = start_time
def set_end_time(self, end_time: str) -> None:
"""设置结束时间
Args:
end_time (str): 循环结束的时间,建议使用标准时间格式
"""
self.data["end_time"] = end_time
def set_action(self, action: str) -> None:
"""设置采取的动作
Args:
action (str): 在循环中执行的动作名称
"""
self.data["action"] = action
def set_action_data(self, action_data: Dict[str, Any]) -> None:
"""设置动作数据
Args:
action_data (Dict[str, Any]): 动作相关的详细数据,将被转换为字符串存储
"""
self.data["action_data"] = str(action_data)
def set_reason(self, reason: str) -> None:
"""设置原因
Args:
reason (str): 触发循环的原因说明
"""
self.data["reason"] = reason
def set_observe_info(self, observe_info: str) -> None:
"""设置回复信息
Args:
observe_info (str): 当前的回复信息
"""
self.data["observe_info"] = observe_info
def get_cycle_id(self) -> Optional[str]:
"""获取循环ID
Returns:
Optional[str]: 循环的唯一标识符,如果未设置则返回 None
"""
return self.get_info("cycle_id")
def get_start_time(self) -> Optional[str]:
"""获取开始时间
Returns:
Optional[str]: 循环开始的时间,如果未设置则返回 None
"""
return self.get_info("start_time")
def get_end_time(self) -> Optional[str]:
"""获取结束时间
Returns:
Optional[str]: 循环结束的时间,如果未设置则返回 None
"""
return self.get_info("end_time")
def get_action(self) -> Optional[str]:
"""获取采取的动作
Returns:
Optional[str]: 在循环中执行的动作名称,如果未设置则返回 None
"""
return self.get_info("action")
def get_action_data(self) -> Optional[str]:
"""获取动作数据
Returns:
Optional[str]: 动作相关的详细数据(字符串形式),如果未设置则返回 None
"""
return self.get_info("action_data")
def get_reason(self) -> Optional[str]:
"""获取原因
Returns:
Optional[str]: 触发循环的原因说明,如果未设置则返回 None
"""
return self.get_info("reason")
def get_observe_info(self) -> Optional[str]:
"""获取回复信息
Returns:
Optional[str]: 当前的回复信息,如果未设置则返回 None
"""
return self.get_info("observe_info")

View File

@@ -1,69 +0,0 @@
from typing import Dict, Optional, Any, List
from dataclasses import dataclass, field
@dataclass
class InfoBase:
"""信息基类
这是一个基础信息类,用于存储和管理各种类型的信息数据。
所有具体的信息类都应该继承自这个基类。
Attributes:
type (str): 信息类型标识符,默认为 "base"
data (Dict[str, Union[str, Dict, list]]): 存储具体信息数据的字典,
支持存储字符串、字典、列表等嵌套数据结构
"""
type: str = "base"
data: Dict[str, Any] = field(default_factory=dict)
processed_info: str = ""
def get_type(self) -> str:
"""获取信息类型
Returns:
str: 当前信息对象的类型标识符
"""
return self.type
def get_data(self) -> Dict[str, Any]:
"""获取所有信息数据
Returns:
Dict[str, Any]: 包含所有信息数据的字典
"""
return self.data
def get_info(self, key: str) -> Optional[Any]:
"""获取特定属性的信息
Args:
key: 要获取的属性键名
Returns:
Optional[Any]: 属性值,如果键不存在则返回 None
"""
return self.data.get(key)
def get_info_list(self, key: str) -> List[Any]:
"""获取特定属性的信息列表
Args:
key: 要获取的属性键名
Returns:
List[Any]: 属性值列表,如果键不存在则返回空列表
"""
value = self.data.get(key)
if isinstance(value, list):
return value
return []
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息字符串
"""
return self.processed_info

View File

@@ -1,165 +0,0 @@
from typing import Dict, Optional
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class ObsInfo(InfoBase):
"""OBS信息类
用于记录和管理OBS相关的信息包括说话消息、截断后的说话消息和聊天类型。
继承自 InfoBase 类,使用字典存储具体数据。
Attributes:
type (str): 信息类型标识符,固定为 "obs"
Data Fields:
talking_message (str): 说话消息内容
talking_message_str_truncate (str): 截断后的说话消息内容
talking_message_str_short (str): 简短版本的说话消息内容(使用最新一半消息)
talking_message_str_truncate_short (str): 截断简短版本的说话消息内容(使用最新一半消息)
chat_type (str): 聊天类型,可以是 "private"(私聊)、"group"(群聊)或 "other"(其他)
"""
type: str = "obs"
def set_talking_message(self, message: str) -> None:
"""设置说话消息
Args:
message (str): 说话消息内容
"""
self.data["talking_message"] = message
def set_talking_message_str_truncate(self, message: str) -> None:
"""设置截断后的说话消息
Args:
message (str): 截断后的说话消息内容
"""
self.data["talking_message_str_truncate"] = message
def set_talking_message_str_short(self, message: str) -> None:
"""设置简短版本的说话消息
Args:
message (str): 简短版本的说话消息内容
"""
self.data["talking_message_str_short"] = message
def set_talking_message_str_truncate_short(self, message: str) -> None:
"""设置截断简短版本的说话消息
Args:
message (str): 截断简短版本的说话消息内容
"""
self.data["talking_message_str_truncate_short"] = message
def set_previous_chat_info(self, message: str) -> None:
"""设置之前聊天信息
Args:
message (str): 之前聊天信息内容
"""
self.data["previous_chat_info"] = message
def set_chat_type(self, chat_type: str) -> None:
"""设置聊天类型
Args:
chat_type (str): 聊天类型,可以是 "private"(私聊)、"group"(群聊)或 "other"(其他)
"""
if chat_type not in ["private", "group", "other"]:
chat_type = "other"
self.data["chat_type"] = chat_type
def set_chat_target(self, chat_target: str) -> None:
"""设置聊天目标
Args:
chat_target (str): 聊天目标,可以是 "private"(私聊)、"group"(群聊)或 "other"(其他)
"""
self.data["chat_target"] = chat_target
def set_chat_id(self, chat_id: str) -> None:
"""设置聊天ID
Args:
chat_id (str): 聊天ID
"""
self.data["chat_id"] = chat_id
def get_chat_id(self) -> Optional[str]:
"""获取聊天ID
Returns:
Optional[str]: 聊天ID如果未设置则返回 None
"""
return self.get_info("chat_id")
def get_talking_message(self) -> Optional[str]:
"""获取说话消息
Returns:
Optional[str]: 说话消息内容,如果未设置则返回 None
"""
return self.get_info("talking_message")
def get_talking_message_str_truncate(self) -> Optional[str]:
"""获取截断后的说话消息
Returns:
Optional[str]: 截断后的说话消息内容,如果未设置则返回 None
"""
return self.get_info("talking_message_str_truncate")
def get_talking_message_str_short(self) -> Optional[str]:
"""获取简短版本的说话消息
Returns:
Optional[str]: 简短版本的说话消息内容,如果未设置则返回 None
"""
return self.get_info("talking_message_str_short")
def get_talking_message_str_truncate_short(self) -> Optional[str]:
"""获取截断简短版本的说话消息
Returns:
Optional[str]: 截断简短版本的说话消息内容,如果未设置则返回 None
"""
return self.get_info("talking_message_str_truncate_short")
def get_chat_type(self) -> str:
"""获取聊天类型
Returns:
str: 聊天类型,默认为 "other"
"""
return self.get_info("chat_type") or "other"
def get_type(self) -> str:
"""获取信息类型
Returns:
str: 当前信息对象的类型标识符
"""
return self.type
def get_data(self) -> Dict[str, str]:
"""获取所有信息数据
Returns:
Dict[str, str]: 包含所有信息数据的字典
"""
return self.data
def get_info(self, key: str) -> Optional[str]:
"""获取特定属性的信息
Args:
key: 要获取的属性键名
Returns:
Optional[str]: 属性值,如果键不存在则返回 None
"""
return self.data.get(key)

View File

@@ -1,86 +0,0 @@
from typing import Dict, Optional, List
from dataclasses import dataclass
from .info_base import InfoBase
@dataclass
class WorkingMemoryInfo(InfoBase):
type: str = "workingmemory"
processed_info: str = ""
def set_talking_message(self, message: str) -> None:
"""设置说话消息
Args:
message (str): 说话消息内容
"""
self.data["talking_message"] = message
def set_working_memory(self, working_memory: List[str]) -> None:
"""设置工作记忆列表
Args:
working_memory (List[str]): 工作记忆内容列表
"""
self.data["working_memory"] = working_memory
def add_working_memory(self, working_memory: str) -> None:
"""添加一条工作记忆
Args:
working_memory (str): 工作记忆内容,格式为"记忆要点:xxx"
"""
working_memory_list = self.data.get("working_memory", [])
working_memory_list.append(working_memory)
self.data["working_memory"] = working_memory_list
def get_working_memory(self) -> List[str]:
"""获取所有工作记忆
Returns:
List[str]: 工作记忆内容列表,每条记忆格式为"记忆要点:xxx"
"""
return self.data.get("working_memory", [])
def get_type(self) -> str:
"""获取信息类型
Returns:
str: 当前信息对象的类型标识符
"""
return self.type
def get_data(self) -> Dict[str, List[str]]:
"""获取所有信息数据
Returns:
Dict[str, List[str]]: 包含所有信息数据的字典
"""
return self.data
def get_info(self, key: str) -> Optional[List[str]]:
"""获取特定属性的信息
Args:
key: 要获取的属性键名
Returns:
Optional[List[str]]: 属性值,如果键不存在则返回 None
"""
return self.data.get(key)
def get_processed_info(self) -> str:
"""获取处理后的信息
Returns:
str: 处理后的信息数据,所有记忆要点按行拼接
"""
all_memory = self.get_working_memory()
memory_str = ""
for memory in all_memory:
memory_str += f"{memory}\n"
self.processed_info = memory_str
return self.processed_info

View File

@@ -1,51 +0,0 @@
from abc import ABC, abstractmethod
from typing import List, Any
from src.chat.focus_chat.info.info_base import InfoBase
from src.chat.focus_chat.observation.observation import Observation
from src.common.logger import get_logger
logger = get_logger("base_processor")
class BaseProcessor(ABC):
"""信息处理器基类
所有具体的信息处理器都应该继承这个基类并实现process_info方法。
支持处理InfoBase和Observation类型的输入。
"""
log_prefix = "Base信息处理器"
@abstractmethod
def __init__(self):
"""初始化处理器"""
@abstractmethod
async def process_info(
self,
observations: List[Observation] = None,
**kwargs: Any,
) -> List[InfoBase]:
"""处理信息对象的抽象方法
Args:
infos: InfoBase对象列表
observations: 可选的Observation对象列表
**kwargs: 其他可选参数
Returns:
List[InfoBase]: 处理后的InfoBase实例列表
"""
pass
def _create_processed_item(self, info_type: str, info_data: Any) -> dict:
"""创建处理后的信息项
Args:
info_type: 信息类型
info_data: 信息数据
Returns:
dict: 处理后的信息项
"""
return {"type": info_type, "id": f"info_{info_type}", "content": info_data, "ttl": 3}

View File

@@ -1,142 +0,0 @@
from typing import List, Any
from src.chat.focus_chat.info.obs_info import ObsInfo
from src.chat.focus_chat.observation.observation import Observation
from src.chat.focus_chat.info.info_base import InfoBase
from .base_processor import BaseProcessor
from src.common.logger import get_logger
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
from datetime import datetime
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
logger = get_logger("processor")
class ChattingInfoProcessor(BaseProcessor):
"""观察处理器
用于处理Observation对象将其转换为ObsInfo对象。
"""
log_prefix = "聊天信息处理"
def __init__(self):
"""初始化观察处理器"""
super().__init__()
# TODO: API-Adapter修改标记
self.model_summary = LLMRequest(
model=global_config.model.utils_small,
temperature=0.7,
request_type="focus.observation.chat",
)
async def process_info(
self,
observations: List[Observation] = None,
**kwargs: Any,
) -> List[InfoBase]:
"""处理Observation对象
Args:
infos: InfoBase对象列表
observations: 可选的Observation对象列表
**kwargs: 其他可选参数
Returns:
List[InfoBase]: 处理后的ObsInfo实例列表
"""
# print(f"observations: {observations}")
processed_infos = []
# 处理Observation对象
if observations:
for obs in observations:
# print(f"obs: {obs}")
if isinstance(obs, ChattingObservation):
obs_info = ObsInfo()
# 设置聊天ID
if hasattr(obs, "chat_id"):
obs_info.set_chat_id(obs.chat_id)
# 设置说话消息
if hasattr(obs, "talking_message_str"):
# print(f"设置说话消息obs.talking_message_str: {obs.talking_message_str}")
obs_info.set_talking_message(obs.talking_message_str)
# 设置截断后的说话消息
if hasattr(obs, "talking_message_str_truncate"):
# print(f"设置截断后的说话消息obs.talking_message_str_truncate: {obs.talking_message_str_truncate}")
obs_info.set_talking_message_str_truncate(obs.talking_message_str_truncate)
# 设置简短版本的说话消息
if hasattr(obs, "talking_message_str_short"):
obs_info.set_talking_message_str_short(obs.talking_message_str_short)
# 设置截断简短版本的说话消息
if hasattr(obs, "talking_message_str_truncate_short"):
obs_info.set_talking_message_str_truncate_short(obs.talking_message_str_truncate_short)
if hasattr(obs, "mid_memory_info"):
# print(f"设置之前聊天信息obs.mid_memory_info: {obs.mid_memory_info}")
obs_info.set_previous_chat_info(obs.mid_memory_info)
# 设置聊天类型
is_group_chat = obs.is_group_chat
if is_group_chat:
chat_type = "group"
else:
chat_type = "private"
if hasattr(obs, "chat_target_info") and obs.chat_target_info:
obs_info.set_chat_target(obs.chat_target_info.get("person_name", "某人"))
obs_info.set_chat_type(chat_type)
# logger.debug(f"聊天信息处理器处理后的信息: {obs_info}")
processed_infos.append(obs_info)
return processed_infos
async def chat_compress(self, obs: ChattingObservation):
log_msg = ""
if obs.compressor_prompt:
summary = ""
try:
summary_result, _ = await self.model_summary.generate_response_async(obs.compressor_prompt)
summary = "没有主题的闲聊"
if summary_result:
summary = summary_result
except Exception as e:
log_msg = f"总结主题失败 for chat {obs.chat_id}: {e}"
logger.error(log_msg)
else:
log_msg = f"chat_compress 完成 for chat {obs.chat_id}, summary: {summary}"
logger.info(log_msg)
mid_memory = {
"id": str(int(datetime.now().timestamp())),
"theme": summary,
"messages": obs.oldest_messages, # 存储原始消息对象
"readable_messages": obs.oldest_messages_str,
# "timestamps": oldest_timestamps,
"chat_id": obs.chat_id,
"created_at": datetime.now().timestamp(),
}
obs.mid_memories.append(mid_memory)
if len(obs.mid_memories) > obs.max_mid_memory_len:
obs.mid_memories.pop(0) # 移除最旧的
mid_memory_str = "之前聊天的内容概述是:\n"
for mid_memory_item in obs.mid_memories: # 重命名循环变量以示区分
time_diff = int((datetime.now().timestamp() - mid_memory_item["created_at"]) / 60)
mid_memory_str += (
f"距离现在{time_diff}分钟前(聊天记录id:{mid_memory_item['id']}){mid_memory_item['theme']}\n"
)
obs.mid_memory_info = mid_memory_str
obs.compressor_prompt = ""
obs.oldest_messages = []
obs.oldest_messages_str = ""
return log_msg

View File

@@ -1,264 +0,0 @@
from src.chat.focus_chat.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.observation.observation import Observation
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
import time
import traceback
from src.common.logger import get_logger
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from .base_processor import BaseProcessor
from typing import List
from src.chat.focus_chat.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.info.info_base import InfoBase
from json_repair import repair_json
from src.chat.focus_chat.info.workingmemory_info import WorkingMemoryInfo
import asyncio
import json
logger = get_logger("processor")
def init_prompt():
memory_proces_prompt = """
你的名字是{bot_name}
现在是{time_now}你正在上网和qq群里的网友们聊天以下是正在进行的聊天内容
{chat_observe_info}
以下是你已经总结的记忆摘要你可以调取这些记忆查看内容来帮助你聊天不要一次调取太多记忆最多调取3个左右记忆
{memory_str}
观察聊天内容和已经总结的记忆思考如果有相近的记忆请合并记忆输出merge_memory
合并记忆的格式为[["id1", "id2"], ["id3", "id4"],...]你可以进行多组合并但是每组合并只能有两个记忆id不要输出其他内容
请根据聊天内容选择你需要调取的记忆并考虑是否添加新记忆以JSON格式输出格式如下
```json
{{
"selected_memory_ids": ["id1", "id2", ...]
"merge_memory": [["id1", "id2"], ["id3", "id4"],...]
}}
```
"""
Prompt(memory_proces_prompt, "prompt_memory_proces")
class WorkingMemoryProcessor(BaseProcessor):
log_prefix = "工作记忆"
def __init__(self, subheartflow_id: str):
super().__init__()
self.subheartflow_id = subheartflow_id
self.llm_model = LLMRequest(
model=global_config.model.planner,
request_type="focus.processor.working_memory",
)
name = get_chat_manager().get_stream_name(self.subheartflow_id)
self.log_prefix = f"[{name}] "
async def process_info(self, observations: List[Observation] = None, *infos) -> List[InfoBase]:
"""处理信息对象
Args:
*infos: 可变数量的InfoBase类型的信息对象
Returns:
List[InfoBase]: 处理后的结构化信息列表
"""
working_memory = None
chat_info = ""
chat_obs = None
try:
for observation in observations:
if isinstance(observation, WorkingMemoryObservation):
working_memory = observation.get_observe_info()
if isinstance(observation, ChattingObservation):
chat_info = observation.get_observe_info()
chat_obs = observation
# 检查是否有待压缩内容
if chat_obs and chat_obs.compressor_prompt:
logger.debug(f"{self.log_prefix} 压缩聊天记忆")
await self.compress_chat_memory(working_memory, chat_obs)
# 检查working_memory是否为None
if working_memory is None:
logger.debug(f"{self.log_prefix} 没有找到工作记忆观察,跳过处理")
return []
all_memory = working_memory.get_all_memories()
if not all_memory:
logger.debug(f"{self.log_prefix} 目前没有工作记忆,跳过提取")
return []
memory_prompts = []
for memory in all_memory:
memory_id = memory.id
memory_brief = memory.brief
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
memory_prompts.append(memory_single_prompt)
memory_choose_str = "".join(memory_prompts)
# 使用提示模板进行处理
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
bot_name=global_config.bot.nickname,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_info,
memory_str=memory_choose_str,
)
# 调用LLM处理记忆
content = ""
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
# print(f"prompt: {prompt}---------------------------------")
# print(f"content: {content}---------------------------------")
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
return []
except Exception as e:
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
return []
# 解析LLM返回的JSON
try:
result = repair_json(content)
if isinstance(result, str):
result = json.loads(result)
if not isinstance(result, dict):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
return []
selected_memory_ids = result.get("selected_memory_ids", [])
merge_memory = result.get("merge_memory", [])
except Exception as e:
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
logger.error(traceback.format_exc())
return []
logger.debug(
f"{self.log_prefix} 解析LLM返回的JSON,selected_memory_ids: {selected_memory_ids}, merge_memory: {merge_memory}"
)
# 根据selected_memory_ids调取记忆
memory_str = ""
selected_ids = set(selected_memory_ids) # 转换为集合以便快速查找
# 遍历所有记忆
for memory in all_memory:
if memory.id in selected_ids:
# 选中的记忆显示详细内容
memory = await working_memory.retrieve_memory(memory.id)
if memory:
memory_str += f"{memory.summary}\n"
else:
# 未选中的记忆显示梗概
memory_str += f"{memory.brief}\n"
working_memory_info = WorkingMemoryInfo()
if memory_str:
working_memory_info.add_working_memory(memory_str)
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
else:
logger.debug(f"{self.log_prefix} 没有找到工作记忆")
if merge_memory:
for merge_pairs in merge_memory:
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
memory2 = await working_memory.retrieve_memory(merge_pairs[1])
if memory1 and memory2:
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
return [working_memory_info]
except Exception as e:
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
logger.error(traceback.format_exc())
return []
async def compress_chat_memory(self, working_memory: WorkingMemory, obs: ChattingObservation):
"""压缩聊天记忆
Args:
working_memory: 工作记忆对象
obs: 聊天观察对象
"""
# 检查working_memory是否为None
if working_memory is None:
logger.warning(f"{self.log_prefix} 工作记忆对象为None无法压缩聊天记忆")
return
try:
summary_result, _ = await self.llm_model.generate_response_async(obs.compressor_prompt)
if not summary_result:
logger.debug(f"{self.log_prefix} 压缩聊天记忆失败: 没有生成摘要")
return
print(f"compressor_prompt: {obs.compressor_prompt}")
print(f"summary_result: {summary_result}")
# 修复并解析JSON
try:
fixed_json = repair_json(summary_result)
summary_data = json.loads(fixed_json)
if not isinstance(summary_data, dict):
logger.error(f"{self.log_prefix} 解析压缩结果失败: 不是有效的JSON对象")
return
theme = summary_data.get("theme", "")
content = summary_data.get("content", "")
if not theme or not content:
logger.error(f"{self.log_prefix} 解析压缩结果失败: 缺少必要字段")
return
# 创建新记忆
await working_memory.add_memory(from_source="chat_compress", summary=content, brief=theme)
logger.debug(f"{self.log_prefix} 压缩聊天记忆成功: {theme} - {content}")
except Exception as e:
logger.error(f"{self.log_prefix} 解析压缩结果失败: {e}")
logger.error(traceback.format_exc())
return
# 清理压缩状态
obs.compressor_prompt = ""
obs.oldest_messages = []
obs.oldest_messages_str = ""
except Exception as e:
logger.error(f"{self.log_prefix} 压缩聊天记忆失败: {e}")
logger.error(traceback.format_exc())
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
"""异步合并记忆,不阻塞主流程
Args:
working_memory: 工作记忆对象
memory_id1: 第一个记忆ID
memory_id2: 第二个记忆ID
"""
# 检查working_memory是否为None
if working_memory is None:
logger.warning(f"{self.log_prefix} 工作记忆对象为None无法合并记忆")
return
try:
merged_memory = await working_memory.merge_memory(memory_id1, memory_id2)
logger.debug(f"{self.log_prefix} 合并后的记忆梗概: {merged_memory.brief}")
logger.debug(f"{self.log_prefix} 合并后的记忆内容: {merged_memory.summary}")
except Exception as e:
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
logger.error(traceback.format_exc())
init_prompt()

View File

@@ -1,46 +0,0 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger import get_logger
from src.chat.planner_actions.action_manager import ActionManager
logger = get_logger("observation")
# 特殊的观察,专门用于观察动作
# 所有观察的基类
class ActionObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.action_manager: ActionManager = None
self.all_actions = {}
self.all_using_actions = {}
def get_observe_info(self):
return self.observe_info
def set_action_manager(self, action_manager: ActionManager):
self.action_manager = action_manager
self.all_actions = self.action_manager.get_registered_actions()
async def observe(self):
action_info_block = ""
self.all_using_actions = self.action_manager.get_using_actions()
for action_name, action_info in self.all_using_actions.items():
action_info_block += f"\n{action_name}: {action_info.get('description', '')}"
action_info_block += "\n注意,除了上面动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"
self.observe_info = action_info_block
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {
"observe_info": self.observe_info,
"observe_id": self.observe_id,
"last_observe_time": self.last_observe_time,
"all_actions": self.all_actions,
"all_using_actions": self.all_using_actions,
}

View File

@@ -1,183 +0,0 @@
from datetime import datetime
from src.config.config import global_config
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
num_new_messages_since,
get_person_id_list,
)
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
from src.chat.focus_chat.observation.observation import Observation
from src.common.logger import get_logger
from src.chat.utils.utils import get_chat_type_and_target_info
logger = get_logger("observation")
# 定义提示模板
Prompt(
"""这是{chat_type_description},请总结以下聊天记录的主题:
{chat_logs}
请概括这段聊天记录的主题和主要内容
主题简短的概括包括时间人物和事件不要超过20个字
内容具体的信息内容包括人物、事件和信息不要超过200个字不要分点。
请用json格式返回格式如下
{{
"theme": "主题,例如 2025-06-14 10:00:00 群聊 麦麦 和 网友 讨论了 游戏 的话题",
"content": "内容,可以是对聊天记录的概括,也可以是聊天记录的详细内容"
}}
""",
"chat_summary_prompt",
)
class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
self.chat_id = chat_id
self.platform = "qq"
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
self.talking_message_str_short = ""
self.talking_message_str_truncate_short = ""
self.name = global_config.bot.nickname
self.nick_name = global_config.bot.alias_names
self.max_now_obs_len = global_config.chat.max_context_size
self.overlap_len = global_config.focus_chat.compressed_length
self.person_list = []
self.compressor_prompt = ""
self.oldest_messages = []
self.oldest_messages_str = ""
self.last_observe_time = datetime.now().timestamp()
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
initial_messages_short = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 5)
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
self.talking_message = initial_messages
self.talking_message_short = initial_messages_short
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)
self.talking_message_str_truncate = build_readable_messages(
self.talking_message, show_actions=True, truncate=True
)
self.talking_message_str_short = build_readable_messages(self.talking_message_short, show_actions=True)
self.talking_message_str_truncate_short = build_readable_messages(
self.talking_message_short, show_actions=True, truncate=True
)
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {
"chat_id": self.chat_id,
"platform": self.platform,
"is_group_chat": self.is_group_chat,
"chat_target_info": self.chat_target_info,
"talking_message_str": self.talking_message_str,
"talking_message_str_truncate": self.talking_message_str_truncate,
"talking_message_str_short": self.talking_message_str_short,
"talking_message_str_truncate_short": self.talking_message_str_truncate_short,
"name": self.name,
"nick_name": self.nick_name,
"last_observe_time": self.last_observe_time,
}
def get_observe_info(self, ids=None):
return self.talking_message_str
async def observe(self):
# 自上一次观察的新消息
new_messages_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=self.last_observe_time,
timestamp_end=datetime.now().timestamp(),
limit=self.max_now_obs_len,
limit_mode="latest",
)
# print(f"new_messages_list: {new_messages_list}")
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
self.talking_message.extend(new_messages_list)
if len(self.talking_message) > self.max_now_obs_len:
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:]
# 构建压缩提示
oldest_messages_str = build_readable_messages(
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
)
# 根据聊天类型选择提示模板
prompt_template_name = "chat_summary_prompt"
if self.is_group_chat:
chat_type_description = "qq群聊的聊天记录"
else:
chat_target_name = "对方"
if self.chat_target_info:
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
chat_type_description = f"你和{chat_target_name}的私聊记录"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
chat_type_description=chat_type_description,
chat_logs=oldest_messages_str,
)
self.compressor_prompt = prompt
# 构建当前消息
self.talking_message_str = build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
show_actions=True,
)
self.talking_message_str_truncate = build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal_no_YMD",
read_mark=last_obs_time_mark,
truncate=True,
show_actions=True,
)
# 构建简短版本 - 使用最新一半的消息
half_count = len(self.talking_message) // 2
recent_messages = self.talking_message[-half_count:] if half_count > 0 else self.talking_message
self.talking_message_str_short = build_readable_messages(
messages=recent_messages,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
show_actions=True,
)
self.talking_message_str_truncate_short = build_readable_messages(
messages=recent_messages,
timestamp_mode="normal_no_YMD",
read_mark=last_obs_time_mark,
truncate=True,
show_actions=True,
)
self.person_list = await get_person_id_list(self.talking_message)
# logger.debug(
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
# )
async def has_new_messages_since(self, timestamp: float) -> bool:
"""检查指定时间戳之后是否有新消息"""
count = num_new_messages_since(chat_id=self.chat_id, timestamp_start=timestamp)
return count > 0

View File

@@ -1,25 +0,0 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("observation")
# 所有观察的基类
class Observation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
def to_dict(self) -> dict:
"""将观察对象转换为可序列化的字典"""
return {
"observe_info": self.observe_info,
"observe_id": self.observe_id,
"last_observe_time": self.last_observe_time,
}
async def observe(self):
pass

View File

@@ -1,34 +0,0 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger import get_logger
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
from typing import List
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class WorkingMemoryObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp()
self.working_memory = WorkingMemory(chat_id=observe_id)
self.retrieved_working_memory = []
def get_observe_info(self):
return self.working_memory
def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
self.retrieved_working_memory.append(retrieved_working_memory)
def get_retrieved_working_memory(self):
return self.retrieved_working_memory
async def observe(self):
pass

View File

@@ -1,84 +0,0 @@
from typing import Tuple
import time
import random
import string
class MemoryItem:
"""记忆项类,用于存储单个记忆的所有相关信息"""
def __init__(self, summary: str, from_source: str = "", brief: str = ""):
"""
初始化记忆项
Args:
summary: 记忆内容概括
from_source: 数据来源
brief: 记忆内容主题
"""
# 生成可读ID时间戳_随机字符串
timestamp = int(time.time())
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
self.id = f"{timestamp}_{random_str}"
self.from_source = from_source
self.brief = brief
self.timestamp = time.time()
# 记忆内容概括
self.summary = summary
# 记忆精简次数
self.compress_count = 0
# 记忆提取次数
self.retrieval_count = 0
# 记忆强度 (初始为10)
self.memory_strength = 10.0
# 记忆操作历史记录
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
def matches_source(self, source: str) -> bool:
"""检查来源是否匹配"""
return self.from_source == source
def increase_strength(self, amount: float) -> None:
"""增加记忆强度"""
self.memory_strength = min(10.0, self.memory_strength + amount)
# 记录操作历史
self.record_operation("strengthen")
def decrease_strength(self, amount: float) -> None:
"""减少记忆强度"""
self.memory_strength = max(0.1, self.memory_strength - amount)
# 记录操作历史
self.record_operation("weaken")
def increase_compress_count(self) -> None:
"""增加精简次数并减弱记忆强度"""
self.compress_count += 1
# 记录操作历史
self.record_operation("compress")
def record_retrieval(self) -> None:
"""记录记忆被提取的情况"""
self.retrieval_count += 1
# 提取后强度翻倍
self.memory_strength = min(10.0, self.memory_strength * 2)
# 记录操作历史
self.record_operation("retrieval")
def record_operation(self, operation_type: str) -> None:
"""记录操作历史"""
current_time = time.time()
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
def to_tuple(self) -> Tuple[str, str, float, str]:
"""转换为元组格式(为了兼容性)"""
return (self.summary, self.from_source, self.timestamp, self.id)
def is_memory_valid(self) -> bool:
"""检查记忆是否有效强度是否大于等于1"""
return self.memory_strength >= 1.0

View File

@@ -1,413 +0,0 @@
from typing import Dict, TypeVar, List, Optional
import traceback
from json_repair import repair_json
from rich.traceback import install
from src.common.logger import get_logger
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
import json # 添加json模块导入
install(extra_lines=3)
logger = get_logger("working_memory")
T = TypeVar("T")
class MemoryManager:
def __init__(self, chat_id: str):
"""
初始化工作记忆
Args:
chat_id: 关联的聊天ID用于标识该工作记忆属于哪个聊天
"""
# 关联的聊天ID
self._chat_id = chat_id
# 记忆项列表
self._memories: List[MemoryItem] = []
# ID到记忆项的映射
self._id_map: Dict[str, MemoryItem] = {}
self.llm_summarizer = LLMRequest(
model=global_config.model.focus_working_memory,
temperature=0.3,
request_type="focus.processor.working_memory",
)
@property
def chat_id(self) -> str:
"""获取关联的聊天ID"""
return self._chat_id
@chat_id.setter
def chat_id(self, value: str):
"""设置关联的聊天ID"""
self._chat_id = value
def push_item(self, memory_item: MemoryItem) -> str:
"""
推送一个已创建的记忆项到工作记忆中
Args:
memory_item: 要存储的记忆项
Returns:
记忆项的ID
"""
# 添加到内存和ID映射
self._memories.append(memory_item)
self._id_map[memory_item.id] = memory_item
return memory_item.id
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
"""
通过ID获取记忆项
Args:
memory_id: 记忆项ID
Returns:
找到的记忆项如果不存在则返回None
"""
memory_item = self._id_map.get(memory_id)
if memory_item:
# 检查记忆强度如果小于1则删除
if not memory_item.is_memory_valid():
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
self.delete(memory_id)
return None
return memory_item
def get_all_items(self) -> List[MemoryItem]:
"""获取所有记忆项"""
return list(self._id_map.values())
def find_items(
self,
source: Optional[str] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
memory_id: Optional[str] = None,
limit: Optional[int] = None,
newest_first: bool = False,
min_strength: float = 0.0,
) -> List[MemoryItem]:
"""
按条件查找记忆项
Args:
source: 数据来源
start_time: 开始时间戳
end_time: 结束时间戳
memory_id: 特定记忆项ID
limit: 返回结果的最大数量
newest_first: 是否按最新优先排序
min_strength: 最小记忆强度
Returns:
符合条件的记忆项列表
"""
# 如果提供了特定ID直接查找
if memory_id:
item = self.get_by_id(memory_id)
return [item] if item else []
results = []
# 获取所有项目
items = self._memories
# 如果需要最新优先,则反转遍历顺序
if newest_first:
items_to_check = list(reversed(items))
else:
items_to_check = items
# 遍历项目
for item in items_to_check:
# 检查来源是否匹配
if source is not None and not item.matches_source(source):
continue
# 检查时间范围
if start_time is not None and item.timestamp < start_time:
continue
if end_time is not None and item.timestamp > end_time:
continue
# 检查记忆强度
if min_strength > 0 and item.memory_strength < min_strength:
continue
# 所有条件都满足,添加到结果中
results.append(item)
# 如果达到限制数量,提前返回
if limit is not None and len(results) >= limit:
return results
return results
async def summarize_memory_item(self, content: str) -> Dict[str, str]:
"""
使用LLM总结记忆项
Args:
content: 需要总结的内容
Returns:
包含brief和summary的字典
"""
prompt = f"""请对以下内容进行总结,总结成记忆,输出两部分:
1. 记忆内容主题精简20字以内让用户可以一眼看出记忆内容是什么
2. 记忆内容概括对内容进行概括保留重要信息200字以内
内容:
{content}
请按以下JSON格式输出
{{
"brief": "记忆内容主题",
"summary": "记忆内容概括"
}}
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
default_summary = {
"brief": "主题未知的记忆",
"summary": "无法概括的记忆内容",
}
try:
# 调用LLM生成总结
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 使用repair_json解析响应
try:
# 使用repair_json修复JSON格式
fixed_json_string = repair_json(response)
# 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json_string, str):
try:
json_result = json.loads(fixed_json_string)
except json.JSONDecodeError as decode_error:
logger.error(f"JSON解析错误: {str(decode_error)}")
return default_summary
else:
# 如果repair_json直接返回了字典对象直接使用
json_result = fixed_json_string
# 进行额外的类型检查
if not isinstance(json_result, dict):
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
return default_summary
# 确保所有必要字段都存在且类型正确
if "brief" not in json_result or not isinstance(json_result["brief"], str):
json_result["brief"] = "主题未知的记忆"
if "summary" not in json_result or not isinstance(json_result["summary"], str):
json_result["summary"] = "无法概括的记忆内容"
return json_result
except Exception as json_error:
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
return default_summary
except Exception as e:
logger.error(f"生成总结时出错: {str(e)}")
return default_summary
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
"""
使单个记忆衰减
Args:
memory_id: 记忆ID
decay_factor: 衰减因子(0-1之间)
Returns:
是否成功衰减
"""
memory_item = self.get_by_id(memory_id)
if not memory_item:
return False
# 计算衰减量(当前强度 * (1-衰减因子)
old_strength = memory_item.memory_strength
decay_amount = old_strength * (1 - decay_factor)
# 更新强度
memory_item.memory_strength = decay_amount
return True
def delete(self, memory_id: str) -> bool:
"""
删除指定ID的记忆项
Args:
memory_id: 要删除的记忆项ID
Returns:
是否成功删除
"""
if memory_id not in self._id_map:
return False
# 获取要删除的项
self._id_map[memory_id]
# 从内存中删除
self._memories = [i for i in self._memories if i.id != memory_id]
# 从ID映射中删除
del self._id_map[memory_id]
return True
def clear(self) -> None:
"""清除所有记忆"""
self._memories.clear()
self._id_map.clear()
async def merge_memories(
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
) -> MemoryItem:
"""
合并两个记忆项
Args:
memory_id1: 第一个记忆项ID
memory_id2: 第二个记忆项ID
reason: 合并原因
delete_originals: 是否删除原始记忆默认为True
Returns:
合并后的记忆项
"""
# 获取两个记忆项
memory_item1 = self.get_by_id(memory_id1)
memory_item2 = self.get_by_id(memory_id2)
if not memory_item1 or not memory_item2:
raise ValueError("无法找到指定的记忆项")
# 构建合并提示
prompt = f"""
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
合并时保留两段记忆的重要信息,避免重复,确保生成的内容连贯、自然。
合并原因:{reason}
记忆1主题{memory_item1.brief}
记忆1内容{memory_item1.summary}
记忆2主题{memory_item2.brief}
记忆2内容{memory_item2.summary}
请按以下JSON格式输出合并结果
{{
"brief": "合并后的主题20字以内",
"summary": "合并后的内容概括200字以内"
}}
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
# 默认合并结果
default_merged = {
"brief": f"合并:{memory_item1.brief} + {memory_item2.brief}",
"summary": f"合并的记忆:{memory_item1.summary}\n{memory_item2.summary}",
}
try:
# 调用LLM合并记忆
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 处理LLM返回的合并结果
try:
# 修复JSON格式
fixed_json_string = repair_json(response)
# 将修复后的字符串解析为Python对象
if isinstance(fixed_json_string, str):
try:
merged_data = json.loads(fixed_json_string)
except json.JSONDecodeError as decode_error:
logger.error(f"JSON解析错误: {str(decode_error)}")
merged_data = default_merged
else:
# 如果repair_json直接返回了字典对象直接使用
merged_data = fixed_json_string
# 确保是字典类型
if not isinstance(merged_data, dict):
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
merged_data = default_merged
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
merged_data["brief"] = default_merged["brief"]
if "summary" not in merged_data or not isinstance(merged_data["summary"], str):
merged_data["summary"] = default_merged["summary"]
except Exception as e:
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
traceback.print_exc()
merged_data = default_merged
except Exception as e:
logger.error(f"合并记忆调用LLM出错: {str(e)}")
traceback.print_exc()
merged_data = default_merged
# 创建新的记忆项
# 取两个记忆项中更强的来源
merged_source = (
memory_item1.from_source
if memory_item1.memory_strength >= memory_item2.memory_strength
else memory_item2.from_source
)
# 创建新的记忆项
merged_memory = MemoryItem(
summary=merged_data["summary"], from_source=merged_source, brief=merged_data["brief"]
)
# 记忆强度取两者最大值
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
# 添加到存储中
self.push_item(merged_memory)
# 如果需要,删除原始记忆
if delete_originals:
self.delete(memory_id1)
self.delete(memory_id2)
return merged_memory
def delete_earliest_memory(self) -> bool:
"""
删除最早的记忆项
Returns:
是否成功删除
"""
# 获取所有记忆项
all_memories = self.get_all_items()
if not all_memories:
return False
# 按时间戳排序,找到最早的记忆项
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
# 删除最早的记忆项
return self.delete(earliest_memory.id)

View File

@@ -1,156 +0,0 @@
from typing import List, Any, Optional
import asyncio
from src.common.logger import get_logger
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
from src.config.config import global_config
logger = get_logger(__name__)
# 问题是我不知道这个manager是不是需要和其他manager统一管理因为这个manager是从属于每一个聊天流都有自己的定时任务
class WorkingMemory:
"""
工作记忆,负责协调和运作记忆
从属于特定的流用chat_id来标识
"""
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
"""
初始化工作记忆管理器
Args:
max_memories_per_chat: 每个聊天的最大记忆数量
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
"""
self.memory_manager = MemoryManager(chat_id)
# 记忆容量上限
self.max_memories_per_chat = max_memories_per_chat
# 自动衰减间隔
self.auto_decay_interval = auto_decay_interval
# 衰减任务
self.decay_task = None
# 只有在工作记忆处理器启用时才启动自动衰减任务
if global_config.focus_chat_processor.working_memory_processor:
self._start_auto_decay()
else:
logger.debug(f"工作记忆处理器已禁用,跳过启动自动衰减任务 (chat_id: {chat_id})")
def _start_auto_decay(self):
"""启动自动衰减任务"""
if self.decay_task is None:
self.decay_task = asyncio.create_task(self._auto_decay_loop())
async def _auto_decay_loop(self):
"""自动衰减循环"""
while True:
await asyncio.sleep(self.auto_decay_interval)
try:
await self.decay_all_memories()
except Exception as e:
print(f"自动衰减记忆时出错: {str(e)}")
async def add_memory(self, summary: Any, from_source: str = "", brief: str = ""):
"""
添加一段记忆到指定聊天
Args:
summary: 记忆内容
from_source: 数据来源
Returns:
记忆项
"""
# 如果是字符串类型,生成总结
memory = MemoryItem(summary, from_source, brief)
# 添加到管理器
self.memory_manager.push_item(memory)
# 如果超过最大记忆数量,删除最早的记忆
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
self.remove_earliest_memory()
return memory
def remove_earliest_memory(self):
"""
删除最早的记忆
"""
return self.memory_manager.delete_earliest_memory()
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
"""
检索记忆
Args:
chat_id: 聊天ID
memory_id: 记忆ID
Returns:
检索到的记忆项如果不存在则返回None
"""
memory_item = self.memory_manager.get_by_id(memory_id)
if memory_item:
memory_item.retrieval_count += 1
memory_item.increase_strength(5)
return memory_item
return None
async def decay_all_memories(self, decay_factor: float = 0.5):
"""
对所有聊天的所有记忆进行衰减
衰减对记忆进行refine压缩强度会变为原先的0.5
Args:
decay_factor: 衰减因子(0-1之间)
"""
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
all_memories = self.memory_manager.get_all_items()
for memory_item in all_memories:
# 如果压缩完小于1会被删除
memory_id = memory_item.id
self.memory_manager.decay_memory(memory_id, decay_factor)
if memory_item.memory_strength < 1:
self.memory_manager.delete(memory_id)
continue
# 计算衰减量
# if memory_item.memory_strength < 5:
# await self.memory_manager.refine_memory(
# memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
# )
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
"""合并记忆
Args:
memory_str: 记忆内容
"""
return await self.memory_manager.merge_memories(
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
)
async def shutdown(self) -> None:
"""关闭管理器,停止所有任务"""
if self.decay_task and not self.decay_task.done():
self.decay_task.cancel()
try:
await self.decay_task
except asyncio.CancelledError:
pass
def get_all_memories(self) -> List[MemoryItem]:
"""
获取所有记忆项目
Returns:
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
"""
return self.memory_manager.get_all_items()