SengokuCola
2025-05-16 17:15:55 +08:00
parent fdd4ac8b4f
commit d19d5fe885
37 changed files with 409 additions and 1388 deletions

View File

@@ -10,7 +10,6 @@ from src.config.config import global_config
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
from src.chat.emoji_system.emoji_manager import emoji_manager
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder,Prompt
from src.chat.focus_chat.heartFC_sender import HeartFCSender
from src.chat.utils.utils import process_llm_response
from src.chat.utils.info_catcher import info_catcher_manager
@@ -18,25 +17,16 @@ from src.manager.mood_manager import mood_manager
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
from src.config.config import global_config
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.person_info.relationship_manager import relationship_manager
from src.chat.utils.utils import get_embedding
import time
from typing import Union, Optional
from src.common.database import db
from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random
logger = get_logger("expressor")
def init_prompt():
Prompt(
"""
@@ -59,7 +49,7 @@ def init_prompt():
""",
"default_expressor_prompt",
)
Prompt(
"""
你可以参考以下的语言习惯,如果情景合适就使用,不要盲目使用,不要生硬使用,而是结合到表达中:
@@ -280,7 +270,7 @@ class DefaultExpressor:
logger.error(f"{self.log_prefix}回复生成意外失败: {e}")
traceback.print_exc()
return None
async def build_prompt_focus(
self,
reason,
@@ -357,7 +347,7 @@ class DefaultExpressor:
template_name,
style_habbits=style_habbits_str,
grammar_habbits=grammar_habbits_str,
chat_target=chat_target_1,
chat_target=chat_target_1,
chat_info=chat_talking_prompt,
bot_name=global_config.BOT_NICKNAME,
prompt_personality="",
@@ -377,9 +367,7 @@ class DefaultExpressor:
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
)
return prompt
return prompt
# --- 发送器 (Sender) --- #
@@ -402,7 +390,7 @@ class DefaultExpressor:
if thinking_id:
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
else:
thinking_id = "ds"+ str(round(time.time(),2))
thinking_id = "ds" + str(round(time.time(), 2))
thinking_start_time = time.time()
if thinking_start_time is None:
@@ -514,7 +502,6 @@ class DefaultExpressor:
return bot_message
def weighted_sample_no_replacement(items, weights, k) -> list:
"""
加权且不放回地随机抽取k个元素。
@@ -548,4 +535,5 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
break
return selected
init_prompt()
init_prompt()
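The hunk above only shows the tail of weighted_sample_no_replacement plus the duplicated init_prompt() call; its docstring describes drawing k elements at random, weighted and without replacement. A self-contained sketch of that behaviour, keeping the signature from the diff (the body is an assumption, not the repository's exact implementation):

import random

def weighted_sample_no_replacement(items, weights, k) -> list:
    """加权且不放回地随机抽取k个元素 (weighted draw of k items, no replacement)."""
    pool = list(items)            # work on copies so the caller's lists survive
    pool_weights = list(weights)
    selected = []
    while pool and len(selected) < k:
        # random.choices samples with replacement, so pop the winner out of the pool
        idx = random.choices(range(len(pool)), weights=pool_weights, k=1)[0]
        selected.append(pool.pop(idx))
        pool_weights.pop(idx)
    return selected

Calling weighted_sample_no_replacement(["a", "b", "c"], [8, 1, 1], 2) will usually pick "a" first, which is presumably how the expressor samples a handful of the 语言习惯 entries referenced in the prompt above.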

View File

@@ -17,7 +17,6 @@ from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
from src.chat.focus_chat.memory_activator import MemoryActivator
@@ -26,6 +25,7 @@ from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
from src.chat.focus_chat.planners.planner import ActionPlanner
from src.chat.focus_chat.planners.action_manager import ActionManager
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
install(extra_lines=3)
@@ -85,17 +85,19 @@ class HeartFChatting:
self.log_prefix: str = str(chat_id) # Initial default, will be updated
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
self.chatting_observation = observations[0]
self.memory_activator = MemoryActivator()
self.working_memory = WorkingMemory(chat_id=self.stream_id)
self.working_observation = WorkingMemoryObservation(observe_id=self.stream_id, working_memory=self.working_memory)
self.working_observation = WorkingMemoryObservation(
observe_id=self.stream_id, working_memory=self.working_memory
)
self.expressor = DefaultExpressor(chat_id=self.stream_id)
self.action_manager = ActionManager()
self.action_planner = ActionPlanner(log_prefix=self.log_prefix, action_manager=self.action_manager)
self.hfcloop_observation.set_action_manager(self.action_manager)
self.all_observations = observations
# --- 处理器列表 ---
self.processors: List[BaseProcessor] = []
@@ -369,7 +371,7 @@ class HeartFChatting:
}
self.all_observations = observations
with Timer("回忆", cycle_timers):
running_memorys = await self.memory_activator.activate_memory(observations)
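Timer from src.chat.utils.timer_calculator is used here as `with Timer("回忆", cycle_timers):`, and the same dict is later read back via `cycle_timers.get("等待新消息", 0.0)` in no_reply_action, so it evidently records the elapsed time of the block under the given label. A minimal sketch of such a context manager (an assumption about the real class, shown only to make the call sites readable):

import time

class Timer:
    """Records the elapsed wall-clock time of a `with` block into a shared dict."""

    def __init__(self, label: str, sink: dict):
        self.label = label
        self.sink = sink

    def __enter__(self):
        self._start = time.monotonic()
        return self

    def __exit__(self, exc_type, exc, tb):
        # e.g. cycle_timers["回忆"] = 0.42 after the block finishes
        self.sink[self.label] = time.monotonic() - self._start
        return False  # never swallow exceptions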

View File

@@ -12,7 +12,6 @@ from src.chat.utils.utils import get_recent_group_speaker
from src.manager.mood_manager import mood_manager
from src.chat.memory_system.Hippocampus import HippocampusManager
from src.chat.knowledge.knowledge_lib import qa_manager
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
import random
@@ -20,7 +19,6 @@ logger = get_logger("prompt")
def init_prompt():
Prompt(
"""
你有以下信息可供参考:
@@ -521,5 +519,6 @@ class PromptBuilder:
# 返回所有找到的内容,用换行分隔
return "\n".join(str(result["content"]) for result in results)
init_prompt()
prompt_builder = PromptBuilder()

View File

@@ -17,7 +17,7 @@ class InfoBase:
type: str = "base"
data: Dict[str, Any] = field(default_factory=dict)
processed_info:str = ""
processed_info: str = ""
def get_type(self) -> str:
"""获取信息类型

View File

@@ -1,5 +1,4 @@
from typing import Dict, Any
from dataclasses import dataclass, field
from dataclasses import dataclass
from .info_base import InfoBase
@@ -31,7 +30,7 @@ class SelfInfo(InfoBase):
self_info: 要设置的思维状态
"""
self.data["self_info"] = self_info
def get_processed_info(self) -> str:
"""获取处理后的信息

View File

@@ -5,10 +5,9 @@ from .info_base import InfoBase
@dataclass
class WorkingMemoryInfo(InfoBase):
type: str = "workingmemory"
processed_info:str = ""
processed_info: str = ""
def set_talking_message(self, message: str) -> None:
"""设置说话消息
@@ -25,7 +24,7 @@ class WorkingMemoryInfo(InfoBase):
working_memory (str): 工作记忆内容
"""
self.data["working_memory"] = working_memory
def add_working_memory(self, working_memory: str) -> None:
"""添加工作记忆
@@ -37,7 +36,7 @@ class WorkingMemoryInfo(InfoBase):
working_memory_list.append(working_memory)
# print(f"working_memory_list: {working_memory_list}")
self.data["working_memory"] = working_memory_list
def get_working_memory(self) -> List[str]:
"""获取工作记忆
@@ -72,7 +71,7 @@ class WorkingMemoryInfo(InfoBase):
Optional[str]: 属性值,如果键不存在则返回 None
"""
return self.data.get(key)
def get_processed_info(self) -> Dict[str, str]:
"""获取处理后的信息
@@ -84,7 +83,7 @@ class WorkingMemoryInfo(InfoBase):
memory_str = ""
for memory in all_memory:
memory_str += f"{memory}\n"
self.processed_info = memory_str
return self.processed_info
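Read together, these InfoBase subclasses are thin dataclass wrappers around a data dict: add_working_memory appends to data["working_memory"], and get_processed_info flattens that list into a newline-joined string cached on processed_info. A self-contained sketch assembled from the fields visible in the hunks (the real class carries more methods):

from dataclasses import dataclass, field
from typing import Any, Dict, List

@dataclass
class WorkingMemoryInfoSketch:
    type: str = "workingmemory"
    data: Dict[str, Any] = field(default_factory=dict)
    processed_info: str = ""

    def add_working_memory(self, working_memory: str) -> None:
        memories = self.data.get("working_memory", [])
        memories.append(working_memory)
        self.data["working_memory"] = memories

    def get_working_memory(self) -> List[str]:
        return self.data.get("working_memory", [])

    def get_processed_info(self) -> str:
        # Join every stored memory, one per line, and cache the result.
        self.processed_info = "".join(f"{m}\n" for m in self.get_working_memory())
        return self.processed_info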

View File

@@ -55,7 +55,7 @@ class ChattingInfoProcessor(BaseProcessor):
# print(f"obs: {obs}")
if isinstance(obs, ChattingObservation):
# print("1111111111111111111111读取111111111111111")
obs_info = ObsInfo()
await self.chat_compress(obs)

View File

@@ -6,11 +6,9 @@ import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
import difflib
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
@@ -202,7 +200,6 @@ class MindProcessor(BaseProcessor):
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
@@ -218,7 +215,6 @@ class MindProcessor(BaseProcessor):
chat_target_name=chat_target_name,
)
content = "(不知道该想些什么...)"
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)

View File

@@ -6,14 +6,10 @@ import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
import difflib
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from typing import Dict
@@ -44,7 +40,6 @@ def init_prompt():
Prompt(indentify_prompt, "indentify_prompt")
class SelfProcessor(BaseProcessor):
log_prefix = "自我认同"
@@ -63,7 +58,6 @@ class SelfProcessor(BaseProcessor):
name = chat_manager.get_stream_name(self.subheartflow_id)
self.log_prefix = f"[{name}] "
async def process_info(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
) -> List[InfoBase]:
@@ -76,7 +70,7 @@ class SelfProcessor(BaseProcessor):
List[InfoBase]: 处理后的结构化信息列表
"""
self_info_str = await self.self_indentify(observations, running_memorys)
if self_info_str:
self_info = SelfInfo()
self_info.set_self_info(self_info_str)
@@ -102,14 +96,12 @@ class SelfProcessor(BaseProcessor):
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
"""
memory_str = ""
if running_memorys:
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
for running_memory in running_memorys:
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
if observations is None:
observations = []
for observation in observations:
@@ -127,8 +119,8 @@ class SelfProcessor(BaseProcessor):
chat_observe_info = observation.get_observe_info()
person_list = observation.person_list
if isinstance(observation, HFCloopObservation):
hfcloop_observe_info = observation.get_observe_info()
# hfcloop_observe_info = observation.get_observe_info()
pass
individuality = Individuality.get_instance()
personality_block = individuality.get_prompt(x_person=2, level=2)
@@ -137,7 +129,6 @@ class SelfProcessor(BaseProcessor):
for person in person_list:
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
bot_name=individuality.name,
prompt_personality=personality_block,
@@ -147,7 +138,6 @@ class SelfProcessor(BaseProcessor):
chat_observe_info=chat_observe_info,
)
content = ""
try:
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
@@ -159,7 +149,7 @@ class SelfProcessor(BaseProcessor):
logger.error(traceback.format_exc())
content = "自我识别过程中出现错误"
if content == 'None':
if content == "None":
content = ""
# 记录初步思考结果
logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
@@ -168,5 +158,4 @@ class SelfProcessor(BaseProcessor):
return content
init_prompt()

View File

@@ -4,7 +4,7 @@ from src.config.config import global_config
import time
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.tools.tool_use import ToolUser
from src.chat.utils.json_utils import process_llm_tool_calls
from src.chat.person_info.relationship_manager import relationship_manager
@@ -68,7 +68,7 @@ class ToolProcessor(BaseProcessor):
"""
working_infos = []
if observations:
for observation in observations:
if isinstance(observation, ChattingObservation):
@@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor):
# 获取个性信息
individuality = Individuality.get_instance()
prompt_personality = individuality.get_prompt(x_person=2, level=2)
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
# 获取时间信息
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())

View File

@@ -5,17 +5,11 @@ from src.config.config import global_config
import time
import traceback
from src.common.logger_manager import get_logger
from src.individuality.individuality import Individuality
import random
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.chat.utils.json_utils import safe_json_dumps
from src.chat.message_receive.chat_stream import chat_manager
import difflib
from src.chat.person_info.relationship_manager import relationship_manager
from .base_processor import BaseProcessor
from src.chat.focus_chat.info.mind_info import MindInfo
from typing import List, Optional
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from typing import Dict
@@ -76,8 +70,6 @@ class WorkingMemoryProcessor(BaseProcessor):
name = chat_manager.get_stream_name(self.subheartflow_id)
self.log_prefix = f"[{name}] "
async def process_info(
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
) -> List[InfoBase]:
@@ -95,11 +87,11 @@ class WorkingMemoryProcessor(BaseProcessor):
for observation in observations:
if isinstance(observation, WorkingMemoryObservation):
working_memory = observation.get_observe_info()
working_memory_obs = observation
# working_memory_obs = observation
if isinstance(observation, ChattingObservation):
chat_info = observation.get_observe_info()
# chat_info_truncate = observation.talking_message_str_truncate
if not working_memory:
logger.warning(f"{self.log_prefix} 没有找到工作记忆对象")
mind_info = MindInfo()
@@ -108,44 +100,42 @@ class WorkingMemoryProcessor(BaseProcessor):
logger.error(f"{self.log_prefix} 处理观察时出错: {e}")
logger.error(traceback.format_exc())
return []
all_memory = working_memory.get_all_memories()
memory_prompts = []
for memory in all_memory:
memory_content = memory.data
# memory_content = memory.data
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
memory_detailed = memory_summary.get("detailed")
# memory_detailed = memory_summary.get("detailed")
memory_keypoints = memory_summary.get("keypoints")
memory_events = memory_summary.get("events")
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
memory_prompts.append(memory_single_prompt)
memory_choose_str = "".join(memory_prompts)
# 使用提示模板进行处理
prompt = (await global_prompt_manager.get_prompt_async("prompt_memory_proces")).format(
bot_name=global_config.BOT_NICKNAME,
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
chat_observe_info=chat_info,
memory_str=memory_choose_str
memory_str=memory_choose_str,
)
# 调用LLM处理记忆
content = ""
try:
logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
if not content:
logger.warning(f"{self.log_prefix} LLM返回空结果处理工作记忆失败。")
except Exception as e:
logger.error(f"{self.log_prefix} 执行LLM请求或处理响应时出错: {e}")
logger.error(traceback.format_exc())
# 解析LLM返回的JSON
try:
result = repair_json(content)
@@ -154,7 +144,7 @@ class WorkingMemoryProcessor(BaseProcessor):
if not isinstance(result, dict):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败结果不是字典类型: {type(result)}")
return []
selected_memory_ids = result.get("selected_memory_ids", [])
new_memory = result.get("new_memory", "")
merge_memory = result.get("merge_memory", [])
@@ -162,20 +152,20 @@ class WorkingMemoryProcessor(BaseProcessor):
logger.error(f"{self.log_prefix} 解析LLM返回的JSON失败: {e}")
logger.error(traceback.format_exc())
return []
logger.debug(f"{self.log_prefix} 解析LLM返回的JSON成功: {result}")
# 根据selected_memory_ids调取记忆
memory_str = ""
if selected_memory_ids:
for memory_id in selected_memory_ids:
memory = await working_memory.retrieve_memory(memory_id)
if memory:
memory_content = memory.data
# memory_content = memory.data
memory_summary = memory.summary
memory_id = memory.id
memory_brief = memory_summary.get("brief")
memory_detailed = memory_summary.get("detailed")
# memory_detailed = memory_summary.get("detailed")
memory_keypoints = memory_summary.get("keypoints")
memory_events = memory_summary.get("events")
for keypoint in memory_keypoints:
@@ -184,21 +174,20 @@ class WorkingMemoryProcessor(BaseProcessor):
memory_str += f"记忆事件:{event}\n"
# memory_str += f"记忆摘要:{memory_detailed}\n"
# memory_str += f"记忆主题:{memory_brief}\n"
working_memory_info = WorkingMemoryInfo()
if memory_str:
working_memory_info.add_working_memory(memory_str)
logger.debug(f"{self.log_prefix} 取得工作记忆: {memory_str}")
else:
logger.warning(f"{self.log_prefix} 没有找到工作记忆")
# 根据聊天内容添加新记忆
if new_memory:
# 使用异步方式添加新记忆,不阻塞主流程
logger.debug(f"{self.log_prefix} {new_memory}新记忆: ")
asyncio.create_task(self.add_memory_async(working_memory, chat_info))
if merge_memory:
for merge_pairs in merge_memory:
memory1 = await working_memory.retrieve_memory(merge_pairs[0])
@@ -207,12 +196,12 @@ class WorkingMemoryProcessor(BaseProcessor):
memory_str = f"记忆id:{memory1.id},记忆摘要:{memory1.summary.get('brief')}\n"
memory_str += f"记忆id:{memory2.id},记忆摘要:{memory2.summary.get('brief')}\n"
asyncio.create_task(self.merge_memory_async(working_memory, merge_pairs[0], merge_pairs[1]))
return [working_memory_info]
async def add_memory_async(self, working_memory: WorkingMemory, content: str):
"""异步添加记忆,不阻塞主流程
Args:
working_memory: 工作记忆对象
content: 记忆内容
@@ -223,10 +212,10 @@ class WorkingMemoryProcessor(BaseProcessor):
except Exception as e:
logger.error(f"{self.log_prefix} 异步添加新记忆失败: {e}")
logger.error(traceback.format_exc())
async def merge_memory_async(self, working_memory: WorkingMemory, memory_id1: str, memory_id2: str):
"""异步合并记忆,不阻塞主流程
Args:
working_memory: 工作记忆对象
memory_str: 记忆内容
@@ -238,7 +227,7 @@ class WorkingMemoryProcessor(BaseProcessor):
logger.debug(f"{self.log_prefix} 合并后的记忆详情: {merged_memory.summary.get('detailed')}")
logger.debug(f"{self.log_prefix} 合并后的记忆要点: {merged_memory.summary.get('keypoints')}")
logger.debug(f"{self.log_prefix} 合并后的记忆事件: {merged_memory.summary.get('events')}")
except Exception as e:
logger.error(f"{self.log_prefix} 异步合并记忆失败: {e}")
logger.error(traceback.format_exc())
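The processor expects the LLM to answer with a JSON object carrying the three fields seen in this hunk: selected_memory_ids (ids to recall), new_memory (whether new content should be stored) and merge_memory (pairs of ids to merge). A sketch of the parsing step built on json_repair, mirroring the error handling above (the helper name parse_memory_plan is invented for illustration):

import json
from json_repair import repair_json

def parse_memory_plan(content: str) -> dict:
    """Repair the LLM output and pull out the three fields the processor acts on."""
    result = repair_json(content)
    if isinstance(result, str):
        # repair_json may return a repaired JSON string rather than an object
        result = json.loads(result)
    if not isinstance(result, dict):
        raise ValueError(f"解析LLM返回的JSON失败,结果不是字典类型: {type(result)}")
    return {
        "selected_memory_ids": result.get("selected_memory_ids", []),
        "new_memory": result.get("new_memory", ""),
        "merge_memory": result.get("merge_memory", []),
    }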

View File

@@ -37,7 +37,7 @@ class ActionManager:
# 加载所有已注册动作
self._load_registered_actions()
# 加载插件动作
self._load_plugin_actions()
@@ -52,11 +52,11 @@ class ActionManager:
# 从_ACTION_REGISTRY获取所有已注册动作
for action_name, action_class in _ACTION_REGISTRY.items():
# 获取动作相关信息
# 不读取插件动作和基类
if action_name == "base_action" or action_name == "plugin_action":
continue
action_description: str = getattr(action_class, "action_description", "")
action_parameters: dict[str:str] = getattr(action_class, "action_parameters", {})
action_require: list[str] = getattr(action_class, "action_require", [])
@@ -80,11 +80,11 @@ class ActionManager:
# logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
# logger.info(f"默认动作: {list(self._default_actions.keys())}")
# for action_name, action_info in self._default_actions.items():
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
except Exception as e:
logger.error(f"加载已注册动作失败: {e}")
def _load_plugin_actions(self) -> None:
"""
加载所有插件目录中的动作
@@ -92,23 +92,25 @@ class ActionManager:
try:
# 检查插件目录是否存在
plugin_path = "src.plugins"
plugin_dir = plugin_path.replace('.', os.path.sep)
plugin_dir = plugin_path.replace(".", os.path.sep)
if not os.path.exists(plugin_dir):
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
return
# 导入插件包
try:
plugins_package = importlib.import_module(plugin_path)
except ImportError as e:
logger.error(f"导入插件包失败: {e}")
return
# 遍历插件包中的所有子包
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + '.'):
for _, plugin_name, is_pkg in pkgutil.iter_modules(
plugins_package.__path__, plugins_package.__name__ + "."
):
if not is_pkg:
continue
# 检查插件是否有actions子包
plugin_actions_path = f"{plugin_name}.actions"
try:
@@ -118,10 +120,10 @@ class ActionManager:
except ImportError as e:
logger.debug(f"插件 {plugin_name} 没有actions子包或导入失败: {e}")
continue
# 再次从_ACTION_REGISTRY获取所有动作包括刚刚从插件加载的
self._load_registered_actions()
except Exception as e:
logger.error(f"加载插件动作失败: {e}")
@@ -316,4 +318,3 @@ class ActionManager:
Optional[Type[BaseAction]]: 动作处理器类如果不存在则返回None
"""
return _ACTION_REGISTRY.get(action_name)
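_load_plugin_actions boils down to a standard pkgutil/importlib walk: import src.plugins, iterate its sub-packages under their fully qualified names, and import each one's actions sub-package so the contained action classes land in _ACTION_REGISTRY as an import side effect. A condensed, self-contained sketch of that discovery loop (logging and the registry reload are omitted):

import importlib
import pkgutil
from typing import List

def discover_plugin_action_modules(plugin_path: str = "src.plugins") -> List[str]:
    """Import every src.plugins.<name>.actions sub-package that exists."""
    loaded: List[str] = []
    try:
        plugins_package = importlib.import_module(plugin_path)
    except ImportError:
        return loaded
    # Passing the package name + "." as prefix makes iter_modules yield fully
    # qualified names such as "src.plugins.my_plugin".
    for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + "."):
        if not is_pkg:
            continue
        try:
            importlib.import_module(f"{plugin_name}.actions")
            loaded.append(plugin_name)
        except ImportError:
            continue  # the plugin simply has no actions sub-package
    return loaded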

View File

@@ -2,4 +2,4 @@
from . import reply_action # noqa
from . import no_reply_action # noqa
# 在此处添加更多动作模块导入
# 在此处添加更多动作模块导入

View File

@@ -94,8 +94,7 @@ class NoReplyAction(BaseAction):
# 等待新消息、超时或关闭信号,并获取结果
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
# 从计时器获取实际等待时间
current_waiting = self.cycle_timers.get("等待新消息", 0.0)
_current_waiting = self.cycle_timers.get("等待新消息", 0.0)
return True, "" # 不回复动作没有回复文本

View File

@@ -1,6 +1,6 @@
import traceback
from typing import Tuple, Dict, List, Any, Optional
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
from src.chat.focus_chat.planners.actions.base_action import BaseAction
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
from src.common.logger_manager import get_logger
@@ -9,19 +9,20 @@ from abc import abstractmethod
logger = get_logger("plugin_action")
class PluginAction(BaseAction):
"""插件动作基类
封装了主程序内部依赖提供简化的API接口给插件开发者
"""
def __init__(self, action_data: dict, reasoning: str, cycle_timers: dict, thinking_id: str, **kwargs):
"""初始化插件动作基类"""
super().__init__(action_data, reasoning, cycle_timers, thinking_id)
# 存储内部服务和对象引用
self._services = {}
# 从kwargs提取必要的内部服务
if "observations" in kwargs:
self._services["observations"] = kwargs["observations"]
@@ -31,48 +32,43 @@ class PluginAction(BaseAction):
self._services["chat_stream"] = kwargs["chat_stream"]
if "current_cycle" in kwargs:
self._services["current_cycle"] = kwargs["current_cycle"]
self.log_prefix = kwargs.get("log_prefix", "")
async def get_user_id_by_person_name(self, person_name: str) -> Tuple[str, str]:
"""根据用户名获取用户ID"""
person_id = person_info_manager.get_person_id_by_person_name(person_name)
user_id = await person_info_manager.get_value(person_id, "user_id")
platform = await person_info_manager.get_value(person_id, "platform")
return platform, user_id
# 提供简化的API方法
async def send_message(self, text: str, target: Optional[str] = None) -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {
"text": text,
"target": target or "",
"emojis": []
}
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations
if isinstance(obs, ChattingObservation)
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
@@ -84,55 +80,49 @@ class PluginAction(BaseAction):
)
else:
anchor_message.update_chat_stream(chat_stream)
response_set = [
("text", text),
]
# 调用内部方法发送消息
success = await expressor.send_response_messages(
anchor_message=anchor_message,
response_set=response_set,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
traceback.print_exc()
return False
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
"""发送消息的简化方法
Args:
text: 要发送的消息文本
target: 目标消息(可选)
Returns:
bool: 是否发送成功
"""
try:
expressor = self._services.get("expressor")
chat_stream = self._services.get("chat_stream")
if not expressor or not chat_stream:
logger.error(f"{self.log_prefix} 无法发送消息:缺少必要的内部服务")
return False
# 构造简化的动作数据
reply_data = {
"text": text,
"target": target or "",
"emojis": []
}
reply_data = {"text": text, "target": target or "", "emojis": []}
# 获取锚定消息(如果有)
observations = self._services.get("observations", [])
chatting_observation: ChattingObservation = next(
obs for obs in observations
if isinstance(obs, ChattingObservation)
obs for obs in observations if isinstance(obs, ChattingObservation)
)
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
@@ -144,24 +134,24 @@ class PluginAction(BaseAction):
)
else:
anchor_message.update_chat_stream(chat_stream)
# 调用内部方法发送消息
success, _ = await expressor.deal_reply(
cycle_timers=self.cycle_timers,
action_data=reply_data,
anchor_message=anchor_message,
reasoning=self.reasoning,
thinking_id=self.thinking_id
thinking_id=self.thinking_id,
)
return success
except Exception as e:
logger.error(f"{self.log_prefix} 发送消息时出错: {e}")
return False
def get_chat_type(self) -> str:
"""获取当前聊天类型
Returns:
str: 聊天类型 ("group" 或 "private")
"""
@@ -169,19 +159,19 @@ class PluginAction(BaseAction):
if chat_stream and hasattr(chat_stream, "group_info"):
return "group" if chat_stream.group_info else "private"
return "unknown"
def get_recent_messages(self, count: int = 5) -> List[Dict[str, Any]]:
"""获取最近的消息
Args:
count: 要获取的消息数量
Returns:
List[Dict]: 消息列表,每个消息包含发送者、内容等信息
"""
messages = []
observations = self._services.get("observations", [])
if observations and len(observations) > 0:
obs = observations[0]
if hasattr(obs, "get_talking_message"):
@@ -191,24 +181,24 @@ class PluginAction(BaseAction):
simple_msg = {
"sender": msg.get("sender", "未知"),
"content": msg.get("content", ""),
"timestamp": msg.get("timestamp", 0)
"timestamp": msg.get("timestamp", 0),
}
messages.append(simple_msg)
return messages
@abstractmethod
async def process(self) -> Tuple[bool, str]:
"""插件处理逻辑,子类必须实现此方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""
pass
async def handle_action(self) -> Tuple[bool, str]:
"""实现BaseAction的抽象方法调用子类的process方法
Returns:
Tuple[bool, str]: (是否执行成功, 回复文本)
"""

View File

@@ -105,8 +105,7 @@ class ReplyAction(BaseAction):
# 从聊天观察获取锚定消息
chatting_observation: ChattingObservation = next(
obs for obs in self.observations
if isinstance(obs, ChattingObservation)
obs for obs in self.observations if isinstance(obs, ChattingObservation)
)
if reply_data.get("target"):
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])

View File

@@ -109,7 +109,7 @@ class ActionPlanner:
cycle_info = info.get_observe_info()
elif isinstance(info, StructuredInfo):
# logger.debug(f"{self.log_prefix} 结构化信息: {info}")
structured_info = info.get_data()
_structured_info = info.get_data()
else:
logger.debug(f"{self.log_prefix} 其他信息: {info}")
extra_info.append(info.get_processed_info())
@@ -157,7 +157,7 @@ class ActionPlanner:
for key, value in parsed_json.items():
if key not in ["action", "reasoning"]:
action_data[key] = value
# 对于reply动作不需要额外处理因为相关字段已经在上面的循环中添加到action_data
if extracted_action not in current_available_actions:
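The loop above copies every key except action and reasoning straight into action_data, so reply-specific fields ride along without special casing. With illustrative values:

parsed_json = {"action": "reply", "reasoning": "用户在提问", "text": "好的", "target": "上一条消息"}
action_data = {k: v for k, v in parsed_json.items() if k not in ("action", "reasoning")}
# action_data == {"text": "好的", "target": "上一条消息"}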

View File

@@ -1,23 +1,16 @@
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
from typing import Dict, Any, List, Optional, Set, Tuple
import time
import uuid
import traceback
import random
import string
from json_repair import repair_json
from rich.traceback import install
from src.common.logger_manager import get_logger
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
class MemoryItem:
"""记忆项类,用于存储单个记忆的所有相关信息"""
def __init__(self, data: Any, from_source: str = "", tags: Optional[List[str]] = None):
"""
初始化记忆项
Args:
data: 记忆数据
from_source: 数据来源
@@ -25,7 +18,7 @@ class MemoryItem:
"""
# 生成可读ID时间戳_随机字符串
timestamp = int(time.time())
random_str = ''.join(random.choices(string.ascii_lowercase + string.digits, k=2))
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
self.id = f"{timestamp}_{random_str}"
self.data = data
self.data_type = type(data)
@@ -40,63 +33,63 @@ class MemoryItem:
# "events": ["事件1", "事件2"]
# }
self.summary = None
# 记忆精简次数
self.compress_count = 0
# 记忆提取次数
self.retrieval_count = 0
# 记忆强度 (初始为10)
self.memory_strength = 10.0
# 记忆操作历史记录
# 格式: [(操作类型, 时间戳, 当时精简次数, 当时强度), ...]
self.history = [("create", self.timestamp, self.compress_count, self.memory_strength)]
def add_tag(self, tag: str) -> None:
"""添加标签"""
self.tags.add(tag)
def remove_tag(self, tag: str) -> None:
"""移除标签"""
if tag in self.tags:
self.tags.remove(tag)
def has_tag(self, tag: str) -> bool:
"""检查是否有特定标签"""
return tag in self.tags
def has_all_tags(self, tags: List[str]) -> bool:
"""检查是否有所有指定的标签"""
return all(tag in self.tags for tag in tags)
def matches_source(self, source: str) -> bool:
"""检查来源是否匹配"""
return self.from_source == source
def set_summary(self, summary: Dict[str, Any]) -> None:
"""设置总结信息"""
self.summary = summary
def increase_strength(self, amount: float) -> None:
"""增加记忆强度"""
self.memory_strength = min(10.0, self.memory_strength + amount)
# 记录操作历史
self.record_operation("strengthen")
def decrease_strength(self, amount: float) -> None:
"""减少记忆强度"""
self.memory_strength = max(0.1, self.memory_strength - amount)
# 记录操作历史
self.record_operation("weaken")
def increase_compress_count(self) -> None:
"""增加精简次数并减弱记忆强度"""
self.compress_count += 1
# 记录操作历史
self.record_operation("compress")
def record_retrieval(self) -> None:
"""记录记忆被提取的情况"""
self.retrieval_count += 1
@@ -104,16 +97,16 @@ class MemoryItem:
self.memory_strength = min(10.0, self.memory_strength * 2)
# 记录操作历史
self.record_operation("retrieval")
def record_operation(self, operation_type: str) -> None:
"""记录操作历史"""
current_time = time.time()
self.history.append((operation_type, current_time, self.compress_count, self.memory_strength))
def to_tuple(self) -> Tuple[Any, str, Set[str], float, str]:
"""转换为元组格式(为了兼容性)"""
return (self.data, self.from_source, self.tags, self.timestamp, self.id)
def is_memory_valid(self) -> bool:
"""检查记忆是否有效强度是否大于等于1"""
return self.memory_strength >= 1.0
return self.memory_strength >= 1.0
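The strength model in MemoryItem is plain arithmetic: strength starts at 10.0, strengthen/weaken clamp it to the [0.1, 10.0] range, a retrieval doubles it (capped at 10.0), and the item only counts as valid while strength >= 1.0. A worked example under exactly those rules:

strength = 10.0                       # fresh MemoryItem
strength = max(0.1, strength - 6.0)   # decrease_strength(6.0)  -> 4.0
strength = min(10.0, strength * 2)    # record_retrieval()      -> 8.0
strength = max(0.1, strength - 7.5)   # decrease_strength(7.5)  -> 0.5
print(strength >= 1.0)                # is_memory_valid() -> False; get_by_id will drop it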

View File

@@ -1,6 +1,4 @@
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
import time
import uuid
from typing import Dict, Any, Type, TypeVar, List, Optional
import traceback
from json_repair import repair_json
from rich.traceback import install
@@ -14,74 +12,71 @@ import json # 添加json模块导入
install(extra_lines=3)
logger = get_logger("working_memory")
T = TypeVar('T')
T = TypeVar("T")
class MemoryManager:
def __init__(self, chat_id: str):
"""
初始化工作记忆
Args:
chat_id: 关联的聊天ID用于标识该工作记忆属于哪个聊天
"""
# 关联的聊天ID
self._chat_id = chat_id
# 主存储: 数据类型 -> 记忆项列表
self._memory: Dict[Type, List[MemoryItem]] = {}
# ID到记忆项的映射
self._id_map: Dict[str, MemoryItem] = {}
self.llm_summarizer = LLMRequest(
model=global_config.llm_summary,
temperature=0.3,
max_tokens=512,
request_type="memory_summarization"
model=global_config.llm_summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
)
@property
def chat_id(self) -> str:
"""获取关联的聊天ID"""
return self._chat_id
@chat_id.setter
def chat_id(self, value: str):
"""设置关联的聊天ID"""
self._chat_id = value
def push_item(self, memory_item: MemoryItem) -> str:
"""
推送一个已创建的记忆项到工作记忆中
Args:
memory_item: 要存储的记忆项
Returns:
记忆项的ID
"""
data_type = memory_item.data_type
# 确保存在该类型的存储列表
if data_type not in self._memory:
self._memory[data_type] = []
# 添加到内存和ID映射
self._memory[data_type].append(memory_item)
self._id_map[memory_item.id] = memory_item
return memory_item.id
async def push_with_summary(self, data: T, from_source: str = "", tags: Optional[List[str]] = None) -> MemoryItem:
"""
推送一段有类型的信息到工作记忆中,并自动生成总结
Args:
data: 要存储的数据
from_source: 数据来源
tags: 数据标签列表
Returns:
包含原始数据和总结信息的字典
"""
@@ -89,65 +84,66 @@ class MemoryManager:
if isinstance(data, str):
# 先生成总结
summary = await self.summarize_memory_item(data)
# 准备标签
memory_tags = list(tags) if tags else []
# 创建记忆项
memory_item = MemoryItem(data, from_source, memory_tags)
# 将总结信息保存到记忆项中
memory_item.set_summary(summary)
# 推送记忆项
self.push_item(memory_item)
return memory_item
else:
# 非字符串类型,直接创建并推送记忆项
memory_item = MemoryItem(data, from_source, tags)
self.push_item(memory_item)
return memory_item
def get_by_id(self, memory_id: str) -> Optional[MemoryItem]:
"""
通过ID获取记忆项
Args:
memory_id: 记忆项ID
Returns:
找到的记忆项如果不存在则返回None
"""
memory_item = self._id_map.get(memory_id)
if memory_item:
# 检查记忆强度如果小于1则删除
if not memory_item.is_memory_valid():
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
self.delete(memory_id)
return None
return memory_item
def get_all_items(self) -> List[MemoryItem]:
"""获取所有记忆项"""
return list(self._id_map.values())
def find_items(self,
data_type: Optional[Type] = None,
source: Optional[str] = None,
tags: Optional[List[str]] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
memory_id: Optional[str] = None,
limit: Optional[int] = None,
newest_first: bool = False,
min_strength: float = 0.0) -> List[MemoryItem]:
def find_items(
self,
data_type: Optional[Type] = None,
source: Optional[str] = None,
tags: Optional[List[str]] = None,
start_time: Optional[float] = None,
end_time: Optional[float] = None,
memory_id: Optional[str] = None,
limit: Optional[int] = None,
newest_first: bool = False,
min_strength: float = 0.0,
) -> List[MemoryItem]:
"""
按条件查找记忆项
Args:
data_type: 要查找的数据类型
source: 数据来源
@@ -158,7 +154,7 @@ class MemoryManager:
limit: 返回结果的最大数量
newest_first: 是否按最新优先排序
min_strength: 最小记忆强度
Returns:
符合条件的记忆项列表
"""
@@ -166,62 +162,62 @@ class MemoryManager:
if memory_id:
item = self.get_by_id(memory_id)
return [item] if item else []
results = []
# 确定要搜索的类型列表
types_to_search = [data_type] if data_type else list(self._memory.keys())
# 对每个类型进行搜索
for typ in types_to_search:
if typ not in self._memory:
continue
# 获取该类型的所有项目
items = self._memory[typ]
# 如果需要最新优先,则反转遍历顺序
if newest_first:
items_to_check = list(reversed(items))
else:
items_to_check = items
# 遍历项目
for item in items_to_check:
# 检查来源是否匹配
if source is not None and not item.matches_source(source):
continue
# 检查标签是否匹配
if tags is not None and not item.has_all_tags(tags):
continue
# 检查时间范围
if start_time is not None and item.timestamp < start_time:
continue
if end_time is not None and item.timestamp > end_time:
continue
# 检查记忆强度
if min_strength > 0 and item.memory_strength < min_strength:
continue
# 所有条件都满足,添加到结果中
results.append(item)
# 如果达到限制数量,提前返回
if limit is not None and len(results) >= limit:
return results
return results
async def summarize_memory_item(self, content: str) -> Dict[str, Any]:
"""
使用LLM总结记忆项
Args:
content: 需要总结的内容
Returns:
包含总结、概括、关键概念和事件的字典
"""
@@ -257,18 +253,18 @@ class MemoryManager:
"brief": "主题未知的记忆",
"detailed": "大致内容未知的记忆",
"keypoints": ["未知的概念"],
"events": ["未知的事件"]
"events": ["未知的事件"],
}
try:
# 调用LLM生成总结
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 使用repair_json解析响应
try:
# 使用repair_json修复JSON格式
fixed_json_string = repair_json(response)
# 如果repair_json返回的是字符串需要解析为Python对象
if isinstance(fixed_json_string, str):
try:
@@ -279,68 +275,60 @@ class MemoryManager:
else:
# 如果repair_json直接返回了字典对象直接使用
json_result = fixed_json_string
# 进行额外的类型检查
if not isinstance(json_result, dict):
logger.error(f"修复后的JSON不是字典类型: {type(json_result)}")
return default_summary
# 确保所有必要字段都存在且类型正确
if "brief" not in json_result or not isinstance(json_result["brief"], str):
json_result["brief"] = "主题未知的记忆"
if "detailed" not in json_result or not isinstance(json_result["detailed"], str):
json_result["detailed"] = "大致内容未知的记忆"
# 处理关键概念
if "keypoints" not in json_result or not isinstance(json_result["keypoints"], list):
json_result["keypoints"] = ["未知的概念"]
else:
# 确保keypoints中的每个项目都是字符串
json_result["keypoints"] = [
str(point) for point in json_result["keypoints"]
if point is not None
]
json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
if not json_result["keypoints"]:
json_result["keypoints"] = ["未知的概念"]
# 处理事件
if "events" not in json_result or not isinstance(json_result["events"], list):
json_result["events"] = ["未知的事件"]
else:
# 确保events中的每个项目都是字符串
json_result["events"] = [
str(event) for event in json_result["events"]
if event is not None
]
json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
if not json_result["events"]:
json_result["events"] = ["未知的事件"]
# 兼容旧版将keypoints和events合并到key_points中
json_result["key_points"] = json_result["keypoints"] + json_result["events"]
return json_result
except Exception as json_error:
logger.error(f"JSON处理失败: {str(json_error)},将使用默认摘要")
# 返回默认结构
return default_summary
except Exception as e:
# 出错时返回简单的结构
logger.error(f"生成总结时出错: {str(e)}")
return default_summary
async def refine_memory(self,
memory_id: str,
requirements: str = "") -> Dict[str, Any]:
async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
"""
对记忆进行精简操作,根据要求修改要点、总结和概括
Args:
memory_id: 记忆ID
requirements: 精简要求,描述如何修改记忆,包括可能需要移除的要点
Returns:
修改后的记忆总结字典
"""
@@ -349,12 +337,12 @@ class MemoryManager:
memory_item = self.get_by_id(memory_id)
if not memory_item:
raise ValueError(f"未找到ID为{memory_id}的记忆项")
# 增加精简次数
memory_item.increase_compress_count()
summary = memory_item.summary
# 使用LLM根据要求对总结、概括和要点进行精简修改
prompt = f"""
请根据以下要求,对记忆内容的主题、概括、关键概念和事件进行精简,模拟记忆的遗忘过程:
@@ -396,15 +384,15 @@ class MemoryManager:
halfway = len(key_points) // 2
summary["keypoints"] = key_points[:halfway] or ["未知的概念"]
summary["events"] = key_points[halfway:] or ["未知的事件"]
# 定义默认的精简结果
default_refined = {
"brief": summary["brief"],
"detailed": summary["detailed"],
"keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
"events": summary.get("events", ["未知的事件"])[:1] # 默认只保留第一个事件
"events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件
}
try:
# 调用LLM修改总结、概括和要点
response, _ = await self.llm_summarizer.generate_response_async(prompt)
@@ -413,7 +401,7 @@ class MemoryManager:
try:
# 修复JSON格式
fixed_json_string = repair_json(response)
# 将修复后的字符串解析为Python对象
if isinstance(fixed_json_string, str):
try:
@@ -424,16 +412,16 @@ class MemoryManager:
else:
# 如果repair_json直接返回了字典对象直接使用
refined_data = fixed_json_string
# 确保是字典类型
if not isinstance(refined_data, dict):
logger.error(f"修复后的JSON不是字典类型: {type(refined_data)}")
refined_data = default_refined
# 更新总结、概括
summary["brief"] = refined_data.get("brief", "主题未知的记忆")
summary["detailed"] = refined_data.get("detailed", "大致内容未知的记忆")
# 更新关键概念
keypoints = refined_data.get("keypoints", [])
if isinstance(keypoints, list) and keypoints:
@@ -442,7 +430,7 @@ class MemoryManager:
else:
# 如果keypoints不是列表或为空使用默认值
summary["keypoints"] = ["主要概念已遗忘"]
# 更新事件
events = refined_data.get("events", [])
if isinstance(events, list) and events:
@@ -451,84 +439,83 @@ class MemoryManager:
else:
# 如果events不是列表或为空使用默认值
summary["events"] = ["事件细节已遗忘"]
# 兼容旧版维护key_points
summary["key_points"] = summary["keypoints"] + summary["events"]
except Exception as e:
logger.error(f"精简记忆出错: {str(e)}")
traceback.print_exc()
# 出错时使用简化的默认精简
summary["brief"] = summary["brief"] + " (已简化)"
summary["keypoints"] = summary.get("keypoints", ["未知的概念"])[:1]
summary["events"] = summary.get("events", ["未知的事件"])[:1]
summary["key_points"] = summary["keypoints"] + summary["events"]
except Exception as e:
logger.error(f"精简记忆调用LLM出错: {str(e)}")
traceback.print_exc()
# 更新原记忆项的总结
memory_item.set_summary(summary)
return memory_item
def decay_memory(self, memory_id: str, decay_factor: float = 0.8) -> bool:
"""
使单个记忆衰减
Args:
memory_id: 记忆ID
decay_factor: 衰减因子(0-1之间)
Returns:
是否成功衰减
"""
memory_item = self.get_by_id(memory_id)
if not memory_item:
return False
# 计算衰减量(当前强度 * (1-衰减因子)
old_strength = memory_item.memory_strength
decay_amount = old_strength * (1 - decay_factor)
# 更新强度
memory_item.memory_strength = decay_amount
return True
def delete(self, memory_id: str) -> bool:
"""
删除指定ID的记忆项
Args:
memory_id: 要删除的记忆项ID
Returns:
是否成功删除
"""
if memory_id not in self._id_map:
return False
# 获取要删除的项
item = self._id_map[memory_id]
# 从内存中删除
data_type = item.data_type
if data_type in self._memory:
self._memory[data_type] = [i for i in self._memory[data_type] if i.id != memory_id]
# 从ID映射中删除
del self._id_map[memory_id]
return True
def clear(self, data_type: Optional[Type] = None) -> None:
"""
清除记忆中的数据
Args:
data_type: 要清除的数据类型如果为None则清除所有数据
"""
@@ -542,34 +529,36 @@ class MemoryManager:
if item.id in self._id_map:
del self._id_map[item.id]
del self._memory[data_type]
async def merge_memories(self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True) -> MemoryItem:
async def merge_memories(
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
) -> MemoryItem:
"""
合并两个记忆项
Args:
memory_id1: 第一个记忆项ID
memory_id2: 第二个记忆项ID
reason: 合并原因
delete_originals: 是否删除原始记忆默认为True
Returns:
包含合并后的记忆信息的字典
"""
# 获取两个记忆项
memory_item1 = self.get_by_id(memory_id1)
memory_item2 = self.get_by_id(memory_id2)
if not memory_item1 or not memory_item2:
raise ValueError("无法找到指定的记忆项")
content1 = memory_item1.data
content2 = memory_item2.data
# 获取记忆的摘要信息(如果有)
summary1 = memory_item1.summary
summary2 = memory_item2.summary
# 构建合并提示
prompt = f"""
请根据以下原因,将两段记忆内容有机合并成一段新的记忆内容。
@@ -577,32 +566,32 @@ class MemoryManager:
合并原因:{reason}
"""
# 如果有摘要信息,添加到提示中
if summary1:
prompt += f"记忆1主题{summary1['brief']}\n"
prompt += f"记忆1概括{summary1['detailed']}\n"
if "keypoints" in summary1:
prompt += f"记忆1关键概念\n" + "\n".join([f"- {point}" for point in summary1['keypoints']]) + "\n\n"
prompt += "记忆1关键概念\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
if "events" in summary1:
prompt += f"记忆1事件\n" + "\n".join([f"- {point}" for point in summary1['events']]) + "\n\n"
prompt += "记忆1事件\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
elif "key_points" in summary1:
prompt += f"记忆1要点\n" + "\n".join([f"- {point}" for point in summary1['key_points']]) + "\n\n"
prompt += "记忆1要点\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
if summary2:
prompt += f"记忆2主题{summary2['brief']}\n"
prompt += f"记忆2概括{summary2['detailed']}\n"
if "keypoints" in summary2:
prompt += f"记忆2关键概念\n" + "\n".join([f"- {point}" for point in summary2['keypoints']]) + "\n\n"
prompt += "记忆2关键概念\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
if "events" in summary2:
prompt += f"记忆2事件\n" + "\n".join([f"- {point}" for point in summary2['events']]) + "\n\n"
prompt += "记忆2事件\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
elif "key_points" in summary2:
prompt += f"记忆2要点\n" + "\n".join([f"- {point}" for point in summary2['key_points']]) + "\n\n"
prompt += "记忆2要点\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"
# 添加记忆原始内容
prompt += f"""
记忆1原始内容
@@ -630,16 +619,16 @@ class MemoryManager:
```
请确保输出是有效的JSON格式不要添加任何额外的说明或解释。
"""
# 默认合并结果
default_merged = {
"content": f"{content1}\n\n{content2}",
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
"detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
"keypoints": [],
"events": []
"events": [],
}
# 合并旧版key_points
if "key_points" in summary1:
default_merged["keypoints"].extend(summary1.get("keypoints", []))
@@ -650,7 +639,7 @@ class MemoryManager:
halfway = len(key_points) // 2
default_merged["keypoints"].extend(key_points[:halfway])
default_merged["events"].extend(key_points[halfway:])
if "key_points" in summary2:
default_merged["keypoints"].extend(summary2.get("keypoints", []))
default_merged["events"].extend(summary2.get("events", []))
@@ -660,25 +649,25 @@ class MemoryManager:
halfway = len(key_points) // 2
default_merged["keypoints"].extend(key_points[:halfway])
default_merged["events"].extend(key_points[halfway:])
# 确保列表不为空
if not default_merged["keypoints"]:
default_merged["keypoints"] = ["合并的关键概念"]
if not default_merged["events"]:
default_merged["events"] = ["合并的事件"]
# 添加key_points兼容
default_merged["key_points"] = default_merged["keypoints"] + default_merged["events"]
try:
# 调用LLM合并记忆
response, _ = await self.llm_summarizer.generate_response_async(prompt)
# 处理LLM返回的合并结果
try:
# 修复JSON格式
fixed_json_string = repair_json(response)
# 将修复后的字符串解析为Python对象
if isinstance(fixed_json_string, str):
try:
@@ -689,49 +678,43 @@ class MemoryManager:
else:
# 如果repair_json直接返回了字典对象直接使用
merged_data = fixed_json_string
# 确保是字典类型
if not isinstance(merged_data, dict):
logger.error(f"修复后的JSON不是字典类型: {type(merged_data)}")
merged_data = default_merged
# 确保所有必要字段都存在且类型正确
if "content" not in merged_data or not isinstance(merged_data["content"], str):
merged_data["content"] = default_merged["content"]
if "brief" not in merged_data or not isinstance(merged_data["brief"], str):
merged_data["brief"] = default_merged["brief"]
if "detailed" not in merged_data or not isinstance(merged_data["detailed"], str):
merged_data["detailed"] = default_merged["detailed"]
# 处理关键概念
if "keypoints" not in merged_data or not isinstance(merged_data["keypoints"], list):
merged_data["keypoints"] = default_merged["keypoints"]
else:
# 确保keypoints中的每个项目都是字符串
merged_data["keypoints"] = [
str(point) for point in merged_data["keypoints"]
if point is not None
]
merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
if not merged_data["keypoints"]:
merged_data["keypoints"] = ["合并的关键概念"]
# 处理事件
if "events" not in merged_data or not isinstance(merged_data["events"], list):
merged_data["events"] = default_merged["events"]
else:
# 确保events中的每个项目都是字符串
merged_data["events"] = [
str(event) for event in merged_data["events"]
if event is not None
]
merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
if not merged_data["events"]:
merged_data["events"] = ["合并的事件"]
# 添加key_points兼容
merged_data["key_points"] = merged_data["keypoints"] + merged_data["events"]
except Exception as e:
logger.error(f"合并记忆时处理JSON出错: {str(e)}")
traceback.print_exc()
@@ -740,59 +723,59 @@ class MemoryManager:
logger.error(f"合并记忆调用LLM出错: {str(e)}")
traceback.print_exc()
merged_data = default_merged
# 创建新的记忆项
# 合并记忆项的标签
merged_tags = memory_item1.tags.union(memory_item2.tags)
# 取两个记忆项中更强的来源
merged_source = memory_item1.from_source if memory_item1.memory_strength >= memory_item2.memory_strength else memory_item2.from_source
# 创建新的记忆项
merged_memory = MemoryItem(
data=merged_data["content"],
from_source=merged_source,
tags=list(merged_tags)
merged_source = (
memory_item1.from_source
if memory_item1.memory_strength >= memory_item2.memory_strength
else memory_item2.from_source
)
# 创建新的记忆项
merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
# 设置合并后的摘要
summary = {
"brief": merged_data["brief"],
"detailed": merged_data["detailed"],
"keypoints": merged_data["keypoints"],
"events": merged_data["events"],
"key_points": merged_data["key_points"]
"key_points": merged_data["key_points"],
}
merged_memory.set_summary(summary)
# 记忆强度取两者最大值
merged_memory.memory_strength = max(memory_item1.memory_strength, memory_item2.memory_strength)
# 添加到存储中
self.push_item(merged_memory)
# 如果需要,删除原始记忆
if delete_originals:
self.delete(memory_id1)
self.delete(memory_id2)
return merged_memory
def delete_earliest_memory(self) -> bool:
"""
删除最早的记忆项
Returns:
是否成功删除
"""
# 获取所有记忆项
all_memories = self.get_all_items()
if not all_memories:
return False
# 按时间戳排序,找到最早的记忆项
earliest_memory = min(all_memories, key=lambda item: item.timestamp)
# 删除最早的记忆项
return self.delete(earliest_memory.id)
return self.delete(earliest_memory.id)
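decay_memory keeps only old_strength * (1 - decay_factor) of the previous value, and get_by_id silently deletes any item whose strength has dropped below 1.0 (the validity threshold from MemoryItem). With the default decay_factor of 0.8, a fresh memory therefore survives exactly one decay pass:

strength = 10.0
strength *= (1 - 0.8)   # first decay_memory()  -> 2.0, still >= 1.0
strength *= (1 - 0.8)   # second decay_memory() -> 0.4, below the threshold
# the next get_by_id() call sees strength < 1.0 and removes the item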

View File

@@ -1,169 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import asyncio
from typing import List, Dict, Any, Optional
from pathlib import Path
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
from src.common.logger_manager import get_logger
logger = get_logger("memory_loader")
class MemoryFileLoader:
"""从文件加载记忆内容的工具类"""
def __init__(self, working_memory: WorkingMemory):
"""
初始化记忆文件加载器
Args:
working_memory: 工作记忆实例
"""
self.working_memory = working_memory
async def load_from_directory(self,
directory_path: str,
file_pattern: str = "*.txt",
common_tags: List[str] = None,
source_prefix: str = "文件") -> List[MemoryItem]:
"""
从指定目录加载符合模式的文件作为记忆
Args:
directory_path: 目录路径
file_pattern: 文件模式(默认为*.txt
common_tags: 所有记忆共有的标签
source_prefix: 来源前缀
Returns:
加载的记忆项列表
"""
directory = Path(directory_path)
if not directory.exists() or not directory.is_dir():
logger.error(f"目录不存在或不是有效目录: {directory_path}")
return []
# 获取文件列表
files = list(directory.glob(file_pattern))
if not files:
logger.warning(f"在目录 {directory_path} 中没有找到符合 {file_pattern} 的文件")
return []
logger.info(f"在目录 {directory_path} 中找到 {len(files)} 个符合条件的文件")
# 加载文件内容为记忆
loaded_memories = []
for file_path in files:
try:
memory_item = await self._load_single_file(
file_path=str(file_path),
common_tags=common_tags,
source_prefix=source_prefix
)
if memory_item:
loaded_memories.append(memory_item)
logger.info(f"成功加载记忆: {file_path.name}")
except Exception as e:
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
logger.info(f"完成加载,共加载了 {len(loaded_memories)} 个记忆")
return loaded_memories
async def _load_single_file(self,
file_path: str,
common_tags: Optional[List[str]] = None,
source_prefix: str = "文件") -> Optional[MemoryItem]:
"""
加载单个文件作为记忆
Args:
file_path: 文件路径
common_tags: 记忆共有的标签
source_prefix: 来源前缀
Returns:
记忆项加载失败则返回None
"""
try:
# 读取文件内容
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
if not content.strip():
logger.warning(f"文件 {file_path} 内容为空")
return None
# 准备标签和来源
file_name = os.path.basename(file_path)
tags = list(common_tags) if common_tags else []
tags.append(file_name) # 添加文件名作为标签
source = f"{source_prefix}_{file_name}"
# 添加到工作记忆
memory = await self.working_memory.add_memory(
content=content,
from_source=source,
tags=tags
)
return memory
except Exception as e:
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
return None
async def main():
"""示例使用"""
# 初始化工作记忆
chat_id = "demo_chat"
working_memory = WorkingMemory(chat_id=chat_id)
try:
# 初始化加载器
loader = MemoryFileLoader(working_memory)
# 加载当前目录中的txt文件
current_dir = Path(__file__).parent
memories = await loader.load_from_directory(
directory_path=str(current_dir),
file_pattern="*.txt",
common_tags=["测试数据", "自动加载"],
source_prefix="测试文件"
)
# 显示加载结果
print(f"共加载了 {len(memories)} 个记忆")
# 获取并显示所有记忆的概要
all_memories = working_memory.memory_manager.get_all_items()
for memory in all_memories:
print("\n" + "=" * 40)
print(f"记忆ID: {memory.id}")
print(f"来源: {memory.from_source}")
print(f"标签: {', '.join(memory.tags)}")
if memory.summary:
print(f"\n主题: {memory.summary.get('brief', '无主题')}")
print(f"概述: {memory.summary.get('detailed', '无概述')}")
print("\n要点:")
for point in memory.summary.get('key_points', []):
print(f"- {point}")
else:
print("\n无摘要信息")
print("=" * 40)
finally:
# 关闭工作记忆
await working_memory.shutdown()
if __name__ == "__main__":
# 运行示例
asyncio.run(main())

View File

@@ -1,92 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import os
import sys
from pathlib import Path
# 添加项目根目录到系统路径
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent.parent.parent.parent
sys.path.insert(0, str(project_root))
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
async def test_load_memories_from_files():
"""测试从文件加载记忆的功能"""
print("开始测试从文件加载记忆...")
# 初始化工作记忆
chat_id = "test_memory_load"
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
try:
# 获取测试文件列表
test_dir = Path(__file__).parent
test_files = [
os.path.join(test_dir, f)
for f in os.listdir(test_dir)
if f.endswith(".txt")
]
print(f"找到 {len(test_files)} 个测试文件")
# 从每个文件加载记忆
for file_path in test_files:
file_name = os.path.basename(file_path)
print(f"从文件 {file_name} 加载记忆...")
# 读取文件内容
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# 添加记忆
memory = await working_memory.add_memory(
content=content,
from_source=f"文件_{file_name}",
tags=["测试文件", file_name]
)
print(f"已添加记忆: ID={memory.id}")
if memory.summary:
print(f"记忆概要: {memory.summary.get('brief', '无概要')}")
print(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点']))}")
print("-" * 50)
# 获取所有记忆
all_memories = working_memory.memory_manager.get_all_items()
print(f"\n成功加载 {len(all_memories)} 个记忆")
# 测试检索记忆
if all_memories:
print("\n测试检索第一个记忆...")
first_memory = all_memories[0]
retrieved = await working_memory.retrieve_memory(first_memory.id)
if retrieved:
print(f"成功检索记忆: ID={retrieved.id}")
print(f"检索后强度: {retrieved.memory_strength} (初始为10.0)")
print(f"检索次数: {retrieved.retrieval_count}")
else:
print("检索失败")
# 测试记忆衰减
print("\n测试记忆衰减...")
for memory in all_memories:
print(f"记忆 {memory.id} 衰减前强度: {memory.memory_strength}")
await working_memory.decay_all_memories(decay_factor=0.5)
all_memories_after = working_memory.memory_manager.get_all_items()
for memory in all_memories_after:
print(f"记忆 {memory.id} 衰减后强度: {memory.memory_strength}")
finally:
# 关闭工作记忆
await working_memory.shutdown()
print("\n测试完成,已关闭工作记忆")
if __name__ == "__main__":
# 运行测试
asyncio.run(test_load_memories_from_files())

View File

@@ -1,197 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import os
import sys
import time
import random
from pathlib import Path
from datetime import datetime
# 添加项目根目录到系统路径
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent.parent.parent.parent
sys.path.insert(0, str(project_root))
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
from src.common.logger_manager import get_logger
logger = get_logger("real_usage_simulation")
class WorkingMemorySimulator:
"""模拟工作记忆的真实使用场景"""
def __init__(self, chat_id="real_usage_test", cycle_interval=20):
"""
初始化模拟器
Args:
chat_id: 聊天ID
cycle_interval: 循环间隔时间(秒)
"""
self.chat_id = chat_id
self.cycle_interval = cycle_interval
self.working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=20, auto_decay_interval=60)
self.cycle_count = 0
self.running = False
# 获取测试文件路径
self.test_files = self._get_test_files()
if not self.test_files:
raise FileNotFoundError("找不到测试文件请确保test目录中有.txt文件")
# 存储所有添加的记忆ID
self.memory_ids = []
async def start(self, total_cycles=5):
"""
开始模拟循环
Args:
total_cycles: 总循环次数设为None表示无限循环
"""
self.running = True
logger.info(f"开始模拟真实使用场景,循环间隔: {self.cycle_interval}")
try:
while self.running and (total_cycles is None or self.cycle_count < total_cycles):
self.cycle_count += 1
logger.info(f"\n===== 开始第 {self.cycle_count} 次循环 =====")
# 执行一次循环
await self._run_one_cycle()
# 如果还有更多循环,则等待
if self.running and (total_cycles is None or self.cycle_count < total_cycles):
wait_time = self.cycle_interval
logger.info(f"等待 {wait_time} 秒后开始下一循环...")
await asyncio.sleep(wait_time)
logger.info(f"模拟完成,共执行了 {self.cycle_count} 次循环")
except KeyboardInterrupt:
logger.info("接收到中断信号,停止模拟")
except Exception as e:
logger.error(f"模拟过程中出错: {str(e)}", exc_info=True)
finally:
# 关闭工作记忆
await self.working_memory.shutdown()
def stop(self):
"""停止模拟循环"""
self.running = False
logger.info("正在停止模拟...")
async def _run_one_cycle(self):
"""运行一次完整循环:先检索记忆,再添加新记忆"""
start_time = time.time()
# 1. 先检索已有记忆(如果有)
await self._retrieve_memories()
# 2. 添加新记忆
await self._add_new_memory()
# 3. 显示工作记忆状态
await self._show_memory_status()
# 计算循环耗时
cycle_duration = time.time() - start_time
logger.info(f"{self.cycle_count} 次循环完成,耗时: {cycle_duration:.2f}")
async def _retrieve_memories(self):
"""检索现有记忆"""
# 如果有已保存的记忆ID随机选择1-3个进行检索
if self.memory_ids:
num_to_retrieve = min(len(self.memory_ids), random.randint(1, 3))
retrieval_ids = random.sample(self.memory_ids, num_to_retrieve)
logger.info(f"正在检索 {num_to_retrieve} 条记忆...")
for memory_id in retrieval_ids:
memory = await self.working_memory.retrieve_memory(memory_id)
if memory:
logger.info(f"成功检索记忆 ID: {memory_id}")
logger.info(f" - 强度: {memory.memory_strength:.2f},检索次数: {memory.retrieval_count}")
if memory.summary:
logger.info(f" - 主题: {memory.summary.get('brief', '无主题')}")
else:
logger.warning(f"记忆 ID: {memory_id} 不存在或已被移除")
# 从ID列表中移除
if memory_id in self.memory_ids:
self.memory_ids.remove(memory_id)
else:
logger.info("当前没有可检索的记忆")
async def _add_new_memory(self):
"""添加新记忆"""
# 随机选择一个测试文件作为记忆内容
file_path = random.choice(self.test_files)
file_name = os.path.basename(file_path)
try:
# 读取文件内容
with open(file_path, "r", encoding="utf-8") as f:
content = f.read()
# 添加时间戳,模拟不同内容
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
content_with_timestamp = f"[{timestamp}] {content}"
# 添加记忆
logger.info(f"正在添加新记忆,来源: {file_name}")
memory = await self.working_memory.add_memory(
content=content_with_timestamp,
from_source=f"模拟_{file_name}",
tags=["模拟测试", f"循环{self.cycle_count}", file_name]
)
# 保存记忆ID
self.memory_ids.append(memory.id)
# 显示记忆信息
logger.info(f"已添加新记忆 ID: {memory.id}")
if memory.summary:
logger.info(f"记忆主题: {memory.summary.get('brief', '无主题')}")
logger.info(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点'])[:2])}...")
except Exception as e:
logger.error(f"添加记忆失败: {str(e)}")
async def _show_memory_status(self):
"""显示当前工作记忆状态"""
all_memories = self.working_memory.memory_manager.get_all_items()
logger.info(f"\n当前工作记忆状态:")
logger.info(f"记忆总数: {len(all_memories)}")
# 按强度排序
sorted_memories = sorted(all_memories, key=lambda x: x.memory_strength, reverse=True)
logger.info("记忆强度排名 (前5项):")
for i, memory in enumerate(sorted_memories[:5], 1):
logger.info(f"{i}. ID: {memory.id}, 强度: {memory.memory_strength:.2f}, "
f"检索次数: {memory.retrieval_count}, "
f"主题: {memory.summary.get('brief', '无主题') if memory.summary else '无摘要'}")
def _get_test_files(self):
"""获取测试文件列表"""
test_dir = Path(__file__).parent
return [
os.path.join(test_dir, f)
for f in os.listdir(test_dir)
if f.endswith(".txt")
]
async def main():
"""主函数"""
# 创建模拟器
simulator = WorkingMemorySimulator(cycle_interval=20) # 设置20秒的循环间隔
# 设置运行5个循环
await simulator.start(total_cycles=5)
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,323 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import asyncio
import os
import sys
import time
from pathlib import Path
# 添加项目根目录到系统路径
current_dir = Path(__file__).parent
project_root = current_dir.parent.parent.parent.parent.parent
sys.path.insert(0, str(project_root))
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.test.memory_file_loader import MemoryFileLoader
from src.common.logger_manager import get_logger
logger = get_logger("memory_decay_test")
async def test_manual_decay_until_removal():
"""测试手动衰减直到记忆被自动移除"""
print("\n===== 测试手动衰减直到记忆被自动移除 =====")
# 初始化工作记忆,设置较大的衰减间隔,避免自动衰减影响测试
chat_id = "decay_test_manual"
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=3600)
try:
# 创建加载器并加载测试文件
loader = MemoryFileLoader(working_memory)
test_dir = current_dir
# 加载第一个测试文件作为记忆
memories = await loader.load_from_directory(
directory_path=str(test_dir),
file_pattern="test1.txt", # 只加载test1.txt
common_tags=["测试", "衰减", "自动移除"],
source_prefix="衰减测试"
)
if not memories:
print("未能加载记忆文件,测试结束")
return
# 获取加载的记忆
memory = memories[0]
memory_id = memory.id
print(f"已加载测试记忆ID: {memory_id}")
print(f"初始强度: {memory.memory_strength}")
if memory.summary:
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
# 执行多次衰减,直到记忆被移除
decay_count = 0
decay_factor = 0.5 # 每次衰减为原来的一半
while True:
# 获取当前记忆
current_memory = working_memory.memory_manager.get_by_id(memory_id)
# 如果记忆已被移除,退出循环
if current_memory is None:
print(f"记忆已在第 {decay_count} 次衰减后被自动移除!")
break
# 输出当前强度
print(f"衰减 {decay_count} 次后强度: {current_memory.memory_strength}")
# 执行衰减
await working_memory.decay_all_memories(decay_factor=decay_factor)
decay_count += 1
# 输出衰减后的详细信息
after_memory = working_memory.memory_manager.get_by_id(memory_id)
if after_memory:
print(f"{decay_count} 次衰减结果: 强度={after_memory.memory_strength},压缩次数={after_memory.compress_count}")
if after_memory.summary:
print(f"记忆概要: {after_memory.summary.get('brief', '无概要')}")
print(f"记忆要点数量: {len(after_memory.summary.get('key_points', []))}")
else:
print(f"{decay_count} 次衰减结果: 记忆已被移除")
# 防止无限循环
if decay_count > 20:
print("达到最大衰减次数(20),退出测试。")
break
# 短暂等待
await asyncio.sleep(0.5)
# 验证记忆是否真的被移除
all_memories = working_memory.memory_manager.get_all_items()
print(f"剩余记忆数量: {len(all_memories)}")
if len(all_memories) == 0:
print("测试通过: 记忆在强度低于阈值后被成功移除。")
else:
print("测试失败: 记忆应该被移除但仍然存在。")
finally:
await working_memory.shutdown()
async def test_auto_decay():
"""测试自动衰减功能"""
print("\n===== 测试自动衰减功能 =====")
# 初始化工作记忆,设置短的衰减间隔,便于测试
chat_id = "decay_test_auto"
decay_interval = 3 # 3秒
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=decay_interval)
try:
# 创建加载器并加载测试文件
loader = MemoryFileLoader(working_memory)
test_dir = current_dir
# 加载测试文件作为记忆
memories = await loader.load_from_directory(
directory_path=str(test_dir),
file_pattern="test1.txt", # 只加载test2.txt
common_tags=["测试", "自动衰减"],
source_prefix="自动衰减测试"
)
if not memories:
print("未能加载记忆文件,测试结束")
return
# 获取加载的记忆
memory = memories[0]
memory_id = memory.id
print(f"已加载测试记忆ID: {memory_id}")
print(f"初始强度: {memory.memory_strength}")
if memory.summary:
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
print(f"记忆概要: {memory.summary.get('detailed', '无概要')}")
print(f"记忆要点: {memory.summary.get('keypoints', '无要点')}")
print(f"记忆事件: {memory.summary.get('events', '无事件')}")
# 观察自动衰减
print(f"等待自动衰减任务执行 (间隔 {decay_interval} 秒)...")
for i in range(3): # 观察3次自动衰减
# 等待自动衰减发生
await asyncio.sleep(decay_interval + 1) # 多等1秒确保任务执行
# 获取当前记忆
current_memory = working_memory.memory_manager.get_by_id(memory_id)
# 如果记忆已被移除,退出循环
if current_memory is None:
print(f"记忆已在第 {i+1} 次自动衰减后被移除!")
break
# 输出当前强度和详细信息
print(f"{i+1} 次自动衰减后强度: {current_memory.memory_strength}")
print(f"自动衰减详细结果: 压缩次数={current_memory.compress_count}, 提取次数={current_memory.retrieval_count}")
if current_memory.summary:
print(f"记忆概要: {current_memory.summary.get('brief', '无概要')}")
print(f"\n自动衰减测试结束。")
# 验证自动衰减是否发生
final_memory = working_memory.memory_manager.get_by_id(memory_id)
if final_memory is None:
print("记忆已被自动衰减移除。")
elif final_memory.memory_strength < memory.memory_strength:
print(f"自动衰减有效:初始强度 {memory.memory_strength} -> 最终强度 {final_memory.memory_strength}")
print(f"衰减历史记录: {final_memory.history}")
else:
print("测试失败:记忆强度未减少,自动衰减可能未生效。")
finally:
await working_memory.shutdown()
async def test_decay_and_retrieval_balance():
"""测试记忆衰减和检索的平衡"""
print("\n===== 测试记忆衰减和检索的平衡 =====")
# 初始化工作记忆
chat_id = "decay_retrieval_balance"
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
try:
# 创建加载器并加载测试文件
loader = MemoryFileLoader(working_memory)
test_dir = current_dir
# 加载第三个测试文件作为记忆
memories = await loader.load_from_directory(
directory_path=str(test_dir),
file_pattern="test3.txt", # 只加载test3.txt
common_tags=["测试", "衰减", "检索"],
source_prefix="平衡测试"
)
if not memories:
print("未能加载记忆文件,测试结束")
return
# 获取加载的记忆
memory = memories[0]
memory_id = memory.id
print(f"已加载测试记忆ID: {memory_id}")
print(f"初始强度: {memory.memory_strength}")
if memory.summary:
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
# 先衰减几次
print("\n开始衰减:")
for i in range(3):
await working_memory.decay_all_memories(decay_factor=0.5)
current = working_memory.memory_manager.get_by_id(memory_id)
if current:
print(f"衰减 {i+1} 次后强度: {current.memory_strength}")
print(f"衰减详细信息: 压缩次数={current.compress_count}, 历史操作数={len(current.history)}")
if current.summary:
print(f"记忆概要: {current.summary.get('brief', '无概要')}")
else:
print(f"记忆已在第 {i+1} 次衰减后被移除。")
break
# 如果记忆还存在,则检索几次增强它
current = working_memory.memory_manager.get_by_id(memory_id)
if current:
print("\n开始检索增强:")
for i in range(2):
retrieved = await working_memory.retrieve_memory(memory_id)
print(f"检索 {i+1} 次后强度: {retrieved.memory_strength}")
print(f"检索后详细信息: 提取次数={retrieved.retrieval_count}, 历史记录长度={len(retrieved.history)}")
# 再次衰减几次,测试是否会被移除
print("\n再次衰减:")
for i in range(5):
await working_memory.decay_all_memories(decay_factor=0.5)
current = working_memory.memory_manager.get_by_id(memory_id)
if current:
print(f"最终衰减 {i+1} 次后强度: {current.memory_strength}")
print(f"衰减详细结果: 压缩次数={current.compress_count}")
else:
print(f"记忆已在最终衰减第 {i+1} 次后被移除。")
break
print("\n测试结束。")
finally:
await working_memory.shutdown()
async def test_multi_memories_decay():
"""测试多条记忆同时衰减"""
print("\n===== 测试多条记忆同时衰减 =====")
# 初始化工作记忆
chat_id = "multi_decay_test"
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
try:
# 创建加载器并加载所有测试文件
loader = MemoryFileLoader(working_memory)
test_dir = current_dir
# 加载所有测试文件作为记忆
memories = await loader.load_from_directory(
directory_path=str(test_dir),
file_pattern="*.txt",
common_tags=["测试", "多记忆衰减"],
source_prefix="多记忆测试"
)
if not memories or len(memories) < 2:
print("未能加载足够的记忆文件,测试结束")
return
# 显示已加载的记忆
print(f"已加载 {len(memories)} 条记忆:")
for idx, mem in enumerate(memories):
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 来源: {mem.from_source}")
if mem.summary:
print(f" 主题: {mem.summary.get('brief', '无主题')}")
# 进行多次衰减测试
print("\n开始多记忆衰减测试:")
for decay_round in range(5):
# 执行衰减
await working_memory.decay_all_memories(decay_factor=0.5)
# 获取并显示所有记忆
all_memories = working_memory.memory_manager.get_all_items()
print(f"\n{decay_round+1} 次衰减后,剩余记忆数量: {len(all_memories)}")
for idx, mem in enumerate(all_memories):
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 压缩次数: {mem.compress_count}")
if mem.summary:
print(f" 概要: {mem.summary.get('brief', '无概要')[:30]}...")
# 如果所有记忆都被移除,退出循环
if not all_memories:
print("所有记忆已被移除,测试结束。")
break
# 等待一下
await asyncio.sleep(0.5)
print("\n多记忆衰减测试结束。")
finally:
await working_memory.shutdown()
async def main():
"""运行所有测试"""
# 测试手动衰减直到移除
await test_manual_decay_until_removal()
# 测试自动衰减
await test_auto_decay()
# 测试衰减和检索的平衡
await test_decay_and_retrieval_balance()
# 测试多条记忆同时衰减
await test_multi_memories_decay()
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,121 +0,0 @@
import asyncio
import os
import unittest
from typing import List, Dict, Any
from pathlib import Path
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
class TestWorkingMemory(unittest.TestCase):
"""工作记忆测试类"""
def setUp(self):
"""测试前准备"""
self.chat_id = "test_chat_123"
self.working_memory = WorkingMemory(chat_id=self.chat_id, max_memories_per_chat=10, auto_decay_interval=60)
self.test_dir = Path(__file__).parent
def tearDown(self):
"""测试后清理"""
loop = asyncio.get_event_loop()
loop.run_until_complete(self.working_memory.shutdown())
def test_init(self):
"""测试初始化"""
self.assertEqual(self.working_memory.max_memories_per_chat, 10)
self.assertEqual(self.working_memory.auto_decay_interval, 60)
def test_add_memory_from_files(self):
"""从文件添加记忆"""
loop = asyncio.get_event_loop()
test_files = self._get_test_files()
# 添加记忆
memories = []
for file_path in test_files:
content = self._read_file_content(file_path)
file_name = os.path.basename(file_path)
source = f"test_file_{file_name}"
tags = ["测试", f"文件_{file_name}"]
memory = loop.run_until_complete(
self.working_memory.add_memory(
content=content,
from_source=source,
tags=tags
)
)
memories.append(memory)
# 验证记忆数量
all_items = self.working_memory.memory_manager.get_all_items()
self.assertEqual(len(all_items), len(test_files))
# 验证每个记忆的内容和标签
for i, memory in enumerate(memories):
file_name = os.path.basename(test_files[i])
retrieved_memory = loop.run_until_complete(
self.working_memory.retrieve_memory(memory.id)
)
self.assertIsNotNone(retrieved_memory)
self.assertTrue(retrieved_memory.has_tag("测试"))
self.assertTrue(retrieved_memory.has_tag(f"文件_{file_name}"))
self.assertEqual(retrieved_memory.from_source, f"test_file_{file_name}")
# 验证检索后强度增加
self.assertGreater(retrieved_memory.memory_strength, 10.0) # 原始强度为10.0检索后增加1.5倍
self.assertEqual(retrieved_memory.retrieval_count, 1)
def test_decay_memories(self):
"""测试记忆衰减"""
loop = asyncio.get_event_loop()
test_files = self._get_test_files()[:1] # 只使用一个文件测试衰减
# 添加记忆
for file_path in test_files:
content = self._read_file_content(file_path)
loop.run_until_complete(
self.working_memory.add_memory(
content=content,
from_source="decay_test",
tags=["衰减测试"]
)
)
# 获取添加后的记忆项
all_items_before = self.working_memory.memory_manager.get_all_items()
self.assertEqual(len(all_items_before), 1)
# 记录原始强度
original_strength = all_items_before[0].memory_strength
# 执行衰减
loop.run_until_complete(
self.working_memory.decay_all_memories(decay_factor=0.5)
)
# 获取衰减后的记忆项
all_items_after = self.working_memory.memory_manager.get_all_items()
# 验证强度衰减
self.assertEqual(len(all_items_after), 1)
self.assertLess(all_items_after[0].memory_strength, original_strength)
def _get_test_files(self) -> List[str]:
"""获取测试文件列表"""
test_dir = self.test_dir
return [
os.path.join(test_dir, f)
for f in os.listdir(test_dir)
if f.endswith(".txt")
]
def _read_file_content(self, file_path: str) -> str:
"""读取文件内容"""
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
if __name__ == "__main__":
unittest.main()
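The tests above lean on numbers that only surface in comments and assertions: an initial strength of 10.0, roughly a 1.5x boost on retrieval, halving per decay_all_memories(decay_factor=0.5) pass, and removal once strength drops below about 1. A back-of-the-envelope trace under those assumptions (illustrative only; the real arithmetic lives in MemoryItem, which this commit does not show) explains why the decay tests expect removal after a few rounds:

# Illustrative sketch: constants are taken from the test comments above, not from MemoryItem.
strength = 10.0
for round_no in range(1, 10):
    strength *= 0.5              # one decay_all_memories(decay_factor=0.5) pass
    print(f"round {round_no}: strength={strength:.4f}")
    if strength < 1:             # assumed removal threshold
        print("memory would be removed here")   # reached on round 4: 10 -> 5 -> 2.5 -> 1.25 -> 0.625
        break
# A retrieval in between (strength *= 1.5, per the unittest comment) raises the value,
# but with these numbers it is not enough to delay removal past round 4 (15 / 2**4 = 0.9375).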

View File

@@ -1,7 +1,6 @@
from typing import Dict, List, Any, Optional
from typing import List, Any, Optional
import asyncio
import random
from datetime import datetime
from src.common.logger_manager import get_logger
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
@@ -9,39 +8,40 @@ logger = get_logger(__name__)
# 问题是我不知道这个manager是不是需要和其他manager统一管理因为这个manager是从属于每一个聊天流都有自己的定时任务
class WorkingMemory:
"""
工作记忆,负责协调和运作记忆
从属于特定的流用chat_id来标识
"""
def __init__(self,chat_id:str , max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
"""
初始化工作记忆管理器
Args:
max_memories_per_chat: 每个聊天的最大记忆数量
auto_decay_interval: 自动衰减记忆的时间间隔(秒)
"""
self.memory_manager = MemoryManager(chat_id)
# 记忆容量上限
self.max_memories_per_chat = max_memories_per_chat
# 自动衰减间隔
self.auto_decay_interval = auto_decay_interval
# 衰减任务
self.decay_task = None
# 启动自动衰减任务
self._start_auto_decay()
def _start_auto_decay(self):
"""启动自动衰减任务"""
if self.decay_task is None:
self.decay_task = asyncio.create_task(self._auto_decay_loop())
async def _auto_decay_loop(self):
"""自动衰减循环"""
while True:
@@ -50,43 +50,39 @@ class WorkingMemory:
await self.decay_all_memories()
except Exception as e:
print(f"自动衰减记忆时出错: {str(e)}")
async def add_memory(self,
content: Any,
from_source: str = "",
tags: Optional[List[str]] = None):
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
"""
添加一段记忆到指定聊天
Args:
content: 记忆内容
from_source: 数据来源
tags: 数据标签列表
Returns:
包含记忆信息的字典
"""
memory = await self.memory_manager.push_with_summary(content, from_source, tags)
if len(self.memory_manager.get_all_items()) > self.max_memories_per_chat:
self.remove_earliest_memory()
return memory
def remove_earliest_memory(self):
"""
删除最早的记忆
"""
return self.memory_manager.delete_earliest_memory()
async def retrieve_memory(self, memory_id: str) -> Optional[MemoryItem]:
"""
检索记忆
Args:
memory_id: 记忆ID
Returns:
检索到的记忆项如果不存在则返回None
"""
@@ -97,19 +93,18 @@ class WorkingMemory:
return memory_item
return None
async def decay_all_memories(self, decay_factor: float = 0.5):
"""
对所有聊天的所有记忆进行衰减
衰减会对记忆进行refine压缩强度会变为原先的0.5倍
Args:
decay_factor: 衰减因子(0-1之间)
"""
logger.debug(f"开始对所有记忆进行衰减,衰减因子: {decay_factor}")
all_memories = self.memory_manager.get_all_items()
for memory_item in all_memories:
# 如果压缩完小于1会被删除
memory_id = memory_item.id
@@ -119,45 +114,47 @@ class WorkingMemory:
continue
# 计算衰减量
if memory_item.memory_strength < 5:
await self.memory_manager.refine_memory(memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩")
await self.memory_manager.refine_memory(
memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
)
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
"""合并记忆
Args:
memory_id1: 第一条记忆的ID
memory_id2: 第二条记忆的ID
"""
return await self.memory_manager.merge_memories(memory_id1 = memory_id1, memory_id2 = memory_id2,reason = "两段记忆有重复的内容")
return await self.memory_manager.merge_memories(
memory_id1=memory_id1, memory_id2=memory_id2, reason="两段记忆有重复的内容"
)
# 暂时没用,先留着
async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
"""
模拟记忆模糊过程,随机选择一部分记忆进行精简
Args:
chat_id: 聊天ID
blur_rate: 模糊比率(0-1之间),表示有多少比例的记忆会被精简
"""
memory = self.get_memory(chat_id)
# 获取所有字符串类型且有总结的记忆
all_summarized_memories = []
for type_items in memory._memory.values():
for item in type_items:
if isinstance(item.data, str) and hasattr(item, 'summary') and item.summary:
if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
all_summarized_memories.append(item)
if not all_summarized_memories:
return
# 计算要模糊的记忆数量
blur_count = max(1, int(len(all_summarized_memories) * blur_rate))
# 随机选择要模糊的记忆
memories_to_blur = random.sample(all_summarized_memories, min(blur_count, len(all_summarized_memories)))
# 对选中的记忆进行精简
for memory_item in memories_to_blur:
try:
@@ -168,16 +165,14 @@ class WorkingMemory:
requirement = "保留核心要点,适度精简细节"
else:
requirement = "只保留最关键的1-2个要点大幅精简内容"
# 进行精简
await memory.refine_memory(memory_item.id, requirement)
print(f"已模糊记忆 {memory_item.id},强度: {memory_item.memory_strength}, 要求: {requirement}")
except Exception as e:
print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
async def shutdown(self) -> None:
"""关闭管理器,停止所有任务"""
if self.decay_task and not self.decay_task.done():
@@ -185,13 +180,13 @@ class WorkingMemory:
try:
await self.decay_task
except asyncio.CancelledError:
pass
pass
def get_all_memories(self) -> List[MemoryItem]:
"""
获取所有记忆项目
Returns:
List[MemoryItem]: 当前工作记忆中的所有记忆项目列表
"""
return self.memory_manager.get_all_items()
return self.memory_manager.get_all_items()
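For orientation after this refactor, a minimal usage sketch of the WorkingMemory surface as it appears in this file: one instance per chat stream, async add/retrieve/decay, and an explicit shutdown to cancel the auto-decay task. Strength values depend on MemoryItem internals not shown here, and add_memory goes through the LLM-backed push_with_summary, so treat this as a sketch rather than a runnable test:

import asyncio
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory

async def demo():
    # Each chat stream owns its WorkingMemory, with its own auto-decay timer.
    wm = WorkingMemory(chat_id="demo_chat", max_memories_per_chat=10, auto_decay_interval=60)
    try:
        memory = await wm.add_memory(content="今天讨论了项目排期", from_source="demo", tags=["示例"])
        retrieved = await wm.retrieve_memory(memory.id)   # retrieval strengthens the item
        if retrieved:
            print("after retrieval:", retrieved.memory_strength, retrieved.retrieval_count)
        await wm.decay_all_memories(decay_factor=0.5)     # same path the auto-decay loop takes
        for item in wm.get_all_memories():
            print(item.id, item.memory_strength, item.retrieval_count)
    finally:
        await wm.shutdown()                               # cancels the auto-decay task

asyncio.run(demo())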

View File

@@ -17,14 +17,14 @@ class HFCloopObservation:
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
self.action_manager = ActionManager()
self.action_manager = ActionManager()
def get_observe_info(self):
return self.observe_info
def add_loop_info(self, loop_info: CycleDetail):
self.history_loop.append(loop_info)
def set_action_manager(self, action_manager: ActionManager):
self.action_manager = action_manager
@@ -75,16 +75,15 @@ class HFCloopObservation:
if start_time is not None and end_time is not None:
time_diff = int(end_time - start_time)
if time_diff > 60:
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff/60}分钟\n"
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n"
else:
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}\n"
else:
cycle_info_block += "\n你还没看过消息\n"
using_actions = self.action_manager.get_using_actions()
for action_name, action_info in using_actions.items():
action_description = action_info["description"]
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
self.observe_info = cycle_info_block

View File

@@ -5,6 +5,7 @@ from src.common.logger_manager import get_logger
logger = get_logger("observation")
# 所有观察的基类
class Observation:
def __init__(self, observe_id):

View File

@@ -29,4 +29,4 @@ class StructureObservation:
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos
self.structured_info = observed_structured_infos

View File

@@ -16,9 +16,9 @@ class WorkingMemoryObservation:
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp()
self.working_memory = working_memory
self.retrieved_working_memory = []
def get_observe_info(self):
@@ -26,7 +26,7 @@ class WorkingMemoryObservation:
def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
self.retrieved_working_memory.append(retrieved_working_memory)
def get_retrieved_working_memory(self):
return self.retrieved_working_memory

View File

@@ -94,7 +94,7 @@ class PersonInfoManager:
return True
else:
return False
def get_person_id_by_person_name(self, person_name: str):
"""根据用户名获取用户ID"""
document = db.person_info.find_one({"person_name": person_name})
@@ -102,7 +102,6 @@ class PersonInfoManager:
return document["person_id"]
else:
return ""
@staticmethod
async def create_person_info(person_id: str, data: dict = None):

View File

@@ -451,7 +451,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 回复<aaa:bbb>
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
def reply_replacer(match):
def reply_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_reply = get_anon_name(platform, bbb)
@@ -462,7 +462,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
# 处理 @<aaa:bbb>
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
def at_replacer(match):
def at_replacer(match, platform=platform):
# aaa = match.group(1)
bbb = match.group(2)
anon_at = get_anon_name(platform, bbb)
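The `platform=platform` default arguments added to reply_replacer and at_replacer above are the usual guard against closing over a loop variable in Python: the default freezes the current value instead of reading whatever `platform` holds when re.sub eventually calls the function. A self-contained sketch of the pitfall (names made up for illustration, not part of this codebase):

import re

def build_replacers(platforms):
    replacers = []
    for platform in platforms:
        # The default argument binds the current platform to this replacer.
        def replacer(match, platform=platform):
            return f"{platform}:{match.group(1)}"
        replacers.append(replacer)
    return replacers

qq_rep, tg_rep = build_replacers(["qq", "telegram"])
print(re.sub(r"<(\d+)>", qq_rep, "<123>"))  # qq:123
print(re.sub(r"<(\d+)>", tg_rep, "<123>"))  # telegram:123 (without the default, both would print telegram:123)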

View File

@@ -1 +1 @@
"""插件系统包"""
"""插件系统包"""

View File

@@ -1,4 +1,5 @@
"""测试插件包"""
"""
这是一个测试插件
"""
"""

View File

@@ -1,6 +1,7 @@
"""测试插件动作模块"""
# 导入所有动作模块以确保装饰器被执行
from . import test_action # noqa
# from . import online_action # noqa
from . import mute_action # noqa
from . import test_action # noqa
# from . import online_action # noqa
from . import mute_action # noqa

View File

@@ -4,12 +4,15 @@ from typing import Tuple
logger = get_logger("mute_action")
@register_action
class MuteAction(PluginAction):
"""测试动作处理类"""
action_name = "mute_action"
action_description = "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
action_description = (
"如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
)
action_parameters = {
"target": "禁言对象,输入你要禁言的对象的名字,必填,",
"duration": "禁言时长,输入你要禁言的时长,单位为秒,必填",
@@ -27,22 +30,22 @@ class MuteAction(PluginAction):
async def process(self) -> Tuple[bool, str]:
"""处理测试动作"""
logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}")
# 发送测试消息
target = self.action_data.get("target")
duration = self.action_data.get("duration")
reason = self.action_data.get("reason")
platform, user_id = await self.get_user_id_by_person_name(target)
await self.send_message_by_expressor(f"我要禁言{target}{platform},时长{duration}秒,理由{reason},表达情绪")
try:
await self.send_message(f"[command]mute,{user_id},{duration}")
except Exception as e:
logger.error(f"{self.log_prefix} 执行mute动作时出错: {e}")
await self.send_message_by_expressor(f"执行mute动作时出错: {e}")
return False, "执行mute动作时出错"
return True, "测试动作执行成功"
return True, "测试动作执行成功"

View File

@@ -4,15 +4,14 @@ from typing import Tuple
logger = get_logger("check_online_action")
@register_action
class CheckOnlineAction(PluginAction):
"""测试动作处理类"""
action_name = "check_online_action"
action_description = "这是一个检查在线状态的动作当有人要求你检查Maibot麦麦 机器人)在线状态时使用"
action_parameters = {
"mode": "查看模式"
}
action_parameters = {"mode": "查看模式"}
action_require = [
"当有人要求你检查Maibot麦麦 机器人)在线状态时使用",
"mode参数为version时查看在线版本状态默认用这种",
@@ -23,22 +22,22 @@ class CheckOnlineAction(PluginAction):
async def process(self) -> Tuple[bool, str]:
"""处理测试动作"""
logger.info(f"{self.log_prefix} 执行online动作: {self.reasoning}")
# 发送测试消息
mode = self.action_data.get("mode", "type")
await self.send_message_by_expressor("我看看")
try:
if mode == "type":
await self.send_message(f"#online detail")
await self.send_message("#online detail")
elif mode == "version":
await self.send_message(f"#online")
await self.send_message("#online")
except Exception as e:
logger.error(f"{self.log_prefix} 执行online动作时出错: {e}")
await self.send_message_by_expressor("执行online动作时出错: {e}")
return False, "执行online动作时出错"
return True, "测试动作执行成功"
return True, "测试动作执行成功"

View File

@@ -4,15 +4,14 @@ from typing import Tuple
logger = get_logger("test_action")
@register_action
class TestAction(PluginAction):
"""测试动作处理类"""
action_name = "test_action"
action_description = "这是一个测试动作,当有人要求你测试插件系统时使用"
action_parameters = {
"test_param": "测试参数(可选)"
}
action_parameters = {"test_param": "测试参数(可选)"}
action_require = [
"测试情况下使用",
"想测试插件动作加载时使用",
@@ -22,17 +21,17 @@ class TestAction(PluginAction):
async def process(self) -> Tuple[bool, str]:
"""处理测试动作"""
logger.info(f"{self.log_prefix} 执行测试动作: {self.reasoning}")
# 获取聊天类型
chat_type = self.get_chat_type()
logger.info(f"{self.log_prefix} 当前聊天类型: {chat_type}")
# 获取最近消息
recent_messages = self.get_recent_messages(3)
logger.info(f"{self.log_prefix} 最近3条消息: {recent_messages}")
# 发送测试消息
test_param = self.action_data.get("test_param", "默认参数")
await self.send_message_by_expressor(f"测试动作执行成功,参数: {test_param}")
return True, "测试动作执行成功"
return True, "测试动作执行成功"