fix:ruff
This commit is contained in:
@@ -10,7 +10,6 @@ from src.config.config import global_config
|
|||||||
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
from src.chat.utils.utils_image import image_path_to_base64 # Local import needed after move
|
||||||
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
from src.chat.utils.timer_calculator import Timer # <--- Import Timer
|
||||||
from src.chat.emoji_system.emoji_manager import emoji_manager
|
from src.chat.emoji_system.emoji_manager import emoji_manager
|
||||||
from src.chat.focus_chat.heartflow_prompt_builder import prompt_builder,Prompt
|
|
||||||
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
from src.chat.focus_chat.heartFC_sender import HeartFCSender
|
||||||
from src.chat.utils.utils import process_llm_response
|
from src.chat.utils.utils import process_llm_response
|
||||||
from src.chat.utils.info_catcher import info_catcher_manager
|
from src.chat.utils.info_catcher import info_catcher_manager
|
||||||
@@ -18,25 +17,16 @@ from src.manager.mood_manager import mood_manager
|
|||||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
from src.chat.focus_chat.hfc_utils import parse_thinking_id_to_timestamp
|
||||||
from src.config.config import global_config
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
|
||||||
from src.chat.utils.utils import get_embedding
|
|
||||||
import time
|
import time
|
||||||
from typing import Union, Optional
|
|
||||||
from src.common.database import db
|
|
||||||
from src.chat.utils.utils import get_recent_group_speaker
|
|
||||||
from src.manager.mood_manager import mood_manager
|
|
||||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
|
||||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
||||||
import random
|
import random
|
||||||
|
|
||||||
logger = get_logger("expressor")
|
logger = get_logger("expressor")
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
@@ -377,10 +367,8 @@ class DefaultExpressor:
|
|||||||
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
|
moderation_prompt=await global_prompt_manager.get_prompt_async("moderation_prompt"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
# --- 发送器 (Sender) --- #
|
# --- 发送器 (Sender) --- #
|
||||||
|
|
||||||
async def send_response_messages(
|
async def send_response_messages(
|
||||||
@@ -402,7 +390,7 @@ class DefaultExpressor:
|
|||||||
if thinking_id:
|
if thinking_id:
|
||||||
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
|
thinking_start_time = await self.heart_fc_sender.get_thinking_start_time(chat_id, thinking_id)
|
||||||
else:
|
else:
|
||||||
thinking_id = "ds"+ str(round(time.time(),2))
|
thinking_id = "ds" + str(round(time.time(), 2))
|
||||||
thinking_start_time = time.time()
|
thinking_start_time = time.time()
|
||||||
|
|
||||||
if thinking_start_time is None:
|
if thinking_start_time is None:
|
||||||
@@ -514,7 +502,6 @@ class DefaultExpressor:
|
|||||||
return bot_message
|
return bot_message
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def weighted_sample_no_replacement(items, weights, k) -> list:
|
def weighted_sample_no_replacement(items, weights, k) -> list:
|
||||||
"""
|
"""
|
||||||
加权且不放回地随机抽取k个元素。
|
加权且不放回地随机抽取k个元素。
|
||||||
@@ -548,4 +535,5 @@ def weighted_sample_no_replacement(items, weights, k) -> list:
|
|||||||
break
|
break
|
||||||
return selected
|
return selected
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
@@ -17,7 +17,6 @@ from src.chat.focus_chat.info_processors.mind_processor import MindProcessor
|
|||||||
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
|
from src.chat.focus_chat.info_processors.working_memory_processor import WorkingMemoryProcessor
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
|
||||||
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
|
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
|
||||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||||
from src.chat.focus_chat.memory_activator import MemoryActivator
|
from src.chat.focus_chat.memory_activator import MemoryActivator
|
||||||
@@ -26,6 +25,7 @@ from src.chat.focus_chat.info_processors.self_processor import SelfProcessor
|
|||||||
from src.chat.focus_chat.planners.planner import ActionPlanner
|
from src.chat.focus_chat.planners.planner import ActionPlanner
|
||||||
from src.chat.focus_chat.planners.action_manager import ActionManager
|
from src.chat.focus_chat.planners.action_manager import ActionManager
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|
||||||
@@ -88,7 +88,9 @@ class HeartFChatting:
|
|||||||
|
|
||||||
self.memory_activator = MemoryActivator()
|
self.memory_activator = MemoryActivator()
|
||||||
self.working_memory = WorkingMemory(chat_id=self.stream_id)
|
self.working_memory = WorkingMemory(chat_id=self.stream_id)
|
||||||
self.working_observation = WorkingMemoryObservation(observe_id=self.stream_id, working_memory=self.working_memory)
|
self.working_observation = WorkingMemoryObservation(
|
||||||
|
observe_id=self.stream_id, working_memory=self.working_memory
|
||||||
|
)
|
||||||
|
|
||||||
self.expressor = DefaultExpressor(chat_id=self.stream_id)
|
self.expressor = DefaultExpressor(chat_id=self.stream_id)
|
||||||
self.action_manager = ActionManager()
|
self.action_manager = ActionManager()
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ from src.chat.utils.utils import get_recent_group_speaker
|
|||||||
from src.manager.mood_manager import mood_manager
|
from src.manager.mood_manager import mood_manager
|
||||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||||
from src.chat.knowledge.knowledge_lib import qa_manager
|
from src.chat.knowledge.knowledge_lib import qa_manager
|
||||||
from src.chat.focus_chat.expressors.exprssion_learner import expression_learner
|
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
||||||
@@ -20,7 +19,6 @@ logger = get_logger("prompt")
|
|||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
|
|
||||||
Prompt(
|
Prompt(
|
||||||
"""
|
"""
|
||||||
你有以下信息可供参考:
|
你有以下信息可供参考:
|
||||||
@@ -521,5 +519,6 @@ class PromptBuilder:
|
|||||||
# 返回所有找到的内容,用换行分隔
|
# 返回所有找到的内容,用换行分隔
|
||||||
return "\n".join(str(result["content"]) for result in results)
|
return "\n".join(str(result["content"]) for result in results)
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
prompt_builder = PromptBuilder()
|
prompt_builder = PromptBuilder()
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ class InfoBase:
|
|||||||
|
|
||||||
type: str = "base"
|
type: str = "base"
|
||||||
data: Dict[str, Any] = field(default_factory=dict)
|
data: Dict[str, Any] = field(default_factory=dict)
|
||||||
processed_info:str = ""
|
processed_info: str = ""
|
||||||
|
|
||||||
def get_type(self) -> str:
|
def get_type(self) -> str:
|
||||||
"""获取信息类型
|
"""获取信息类型
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
from typing import Dict, Any
|
from dataclasses import dataclass
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from .info_base import InfoBase
|
from .info_base import InfoBase
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ from .info_base import InfoBase
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class WorkingMemoryInfo(InfoBase):
|
class WorkingMemoryInfo(InfoBase):
|
||||||
|
|
||||||
type: str = "workingmemory"
|
type: str = "workingmemory"
|
||||||
|
|
||||||
processed_info:str = ""
|
processed_info: str = ""
|
||||||
|
|
||||||
def set_talking_message(self, message: str) -> None:
|
def set_talking_message(self, message: str) -> None:
|
||||||
"""设置说话消息
|
"""设置说话消息
|
||||||
|
|||||||
@@ -6,11 +6,9 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
import random
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.json_utils import safe_json_dumps
|
from src.chat.utils.json_utils import safe_json_dumps
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
import difflib
|
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.chat.person_info.relationship_manager import relationship_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||||
@@ -202,7 +200,6 @@ class MindProcessor(BaseProcessor):
|
|||||||
for person in person_list:
|
for person in person_list:
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
||||||
|
|
||||||
|
|
||||||
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
|
template_name = "sub_heartflow_prompt_before" if is_group_chat else "sub_heartflow_prompt_private_before"
|
||||||
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
|
logger.debug(f"{self.log_prefix} 使用{'群聊' if is_group_chat else '私聊'}思考模板")
|
||||||
|
|
||||||
@@ -218,7 +215,6 @@ class MindProcessor(BaseProcessor):
|
|||||||
chat_target_name=chat_target_name,
|
chat_target_name=chat_target_name,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
content = "(不知道该想些什么...)"
|
content = "(不知道该想些什么...)"
|
||||||
try:
|
try:
|
||||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
|
|||||||
@@ -6,14 +6,10 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
import random
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.json_utils import safe_json_dumps
|
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
import difflib
|
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
from src.chat.person_info.relationship_manager import relationship_manager
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -44,7 +40,6 @@ def init_prompt():
|
|||||||
Prompt(indentify_prompt, "indentify_prompt")
|
Prompt(indentify_prompt, "indentify_prompt")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class SelfProcessor(BaseProcessor):
|
class SelfProcessor(BaseProcessor):
|
||||||
log_prefix = "自我认同"
|
log_prefix = "自我认同"
|
||||||
|
|
||||||
@@ -63,7 +58,6 @@ class SelfProcessor(BaseProcessor):
|
|||||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||||
self.log_prefix = f"[{name}] "
|
self.log_prefix = f"[{name}] "
|
||||||
|
|
||||||
|
|
||||||
async def process_info(
|
async def process_info(
|
||||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||||
) -> List[InfoBase]:
|
) -> List[InfoBase]:
|
||||||
@@ -102,14 +96,12 @@ class SelfProcessor(BaseProcessor):
|
|||||||
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
|
tuple: (current_mind, past_mind, prompt) 当前想法、过去的想法列表和使用的prompt
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
memory_str = ""
|
memory_str = ""
|
||||||
if running_memorys:
|
if running_memorys:
|
||||||
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
memory_str = "以下是当前在聊天中,你回忆起的记忆:\n"
|
||||||
for running_memory in running_memorys:
|
for running_memory in running_memorys:
|
||||||
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
|
memory_str += f"{running_memory['topic']}: {running_memory['content']}\n"
|
||||||
|
|
||||||
|
|
||||||
if observations is None:
|
if observations is None:
|
||||||
observations = []
|
observations = []
|
||||||
for observation in observations:
|
for observation in observations:
|
||||||
@@ -127,8 +119,8 @@ class SelfProcessor(BaseProcessor):
|
|||||||
chat_observe_info = observation.get_observe_info()
|
chat_observe_info = observation.get_observe_info()
|
||||||
person_list = observation.person_list
|
person_list = observation.person_list
|
||||||
if isinstance(observation, HFCloopObservation):
|
if isinstance(observation, HFCloopObservation):
|
||||||
hfcloop_observe_info = observation.get_observe_info()
|
# hfcloop_observe_info = observation.get_observe_info()
|
||||||
|
pass
|
||||||
|
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
personality_block = individuality.get_prompt(x_person=2, level=2)
|
personality_block = individuality.get_prompt(x_person=2, level=2)
|
||||||
@@ -137,7 +129,6 @@ class SelfProcessor(BaseProcessor):
|
|||||||
for person in person_list:
|
for person in person_list:
|
||||||
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
relation_prompt += await relationship_manager.build_relationship_info(person, is_id=True)
|
||||||
|
|
||||||
|
|
||||||
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
|
prompt = (await global_prompt_manager.get_prompt_async("indentify_prompt")).format(
|
||||||
bot_name=individuality.name,
|
bot_name=individuality.name,
|
||||||
prompt_personality=personality_block,
|
prompt_personality=personality_block,
|
||||||
@@ -147,7 +138,6 @@ class SelfProcessor(BaseProcessor):
|
|||||||
chat_observe_info=chat_observe_info,
|
chat_observe_info=chat_observe_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
content = ""
|
content = ""
|
||||||
try:
|
try:
|
||||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
@@ -159,7 +149,7 @@ class SelfProcessor(BaseProcessor):
|
|||||||
logger.error(traceback.format_exc())
|
logger.error(traceback.format_exc())
|
||||||
content = "自我识别过程中出现错误"
|
content = "自我识别过程中出现错误"
|
||||||
|
|
||||||
if content == 'None':
|
if content == "None":
|
||||||
content = ""
|
content = ""
|
||||||
# 记录初步思考结果
|
# 记录初步思考结果
|
||||||
logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
|
logger.debug(f"{self.log_prefix} 自我识别prompt: \n{prompt}\n")
|
||||||
@@ -168,5 +158,4 @@ class SelfProcessor(BaseProcessor):
|
|||||||
return content
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
init_prompt()
|
init_prompt()
|
||||||
|
|||||||
@@ -134,7 +134,7 @@ class ToolProcessor(BaseProcessor):
|
|||||||
|
|
||||||
# 获取个性信息
|
# 获取个性信息
|
||||||
individuality = Individuality.get_instance()
|
individuality = Individuality.get_instance()
|
||||||
prompt_personality = individuality.get_prompt(x_person=2, level=2)
|
# prompt_personality = individuality.get_prompt(x_person=2, level=2)
|
||||||
|
|
||||||
# 获取时间信息
|
# 获取时间信息
|
||||||
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
time_now = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
|
||||||
|
|||||||
@@ -5,17 +5,11 @@ from src.config.config import global_config
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.individuality.individuality import Individuality
|
|
||||||
import random
|
|
||||||
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from src.chat.utils.json_utils import safe_json_dumps
|
|
||||||
from src.chat.message_receive.chat_stream import chat_manager
|
from src.chat.message_receive.chat_stream import chat_manager
|
||||||
import difflib
|
|
||||||
from src.chat.person_info.relationship_manager import relationship_manager
|
|
||||||
from .base_processor import BaseProcessor
|
from .base_processor import BaseProcessor
|
||||||
from src.chat.focus_chat.info.mind_info import MindInfo
|
from src.chat.focus_chat.info.mind_info import MindInfo
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
|
||||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
@@ -76,8 +70,6 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
name = chat_manager.get_stream_name(self.subheartflow_id)
|
name = chat_manager.get_stream_name(self.subheartflow_id)
|
||||||
self.log_prefix = f"[{name}] "
|
self.log_prefix = f"[{name}] "
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def process_info(
|
async def process_info(
|
||||||
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
self, observations: Optional[List[Observation]] = None, running_memorys: Optional[List[Dict]] = None, *infos
|
||||||
) -> List[InfoBase]:
|
) -> List[InfoBase]:
|
||||||
@@ -95,7 +87,7 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
for observation in observations:
|
for observation in observations:
|
||||||
if isinstance(observation, WorkingMemoryObservation):
|
if isinstance(observation, WorkingMemoryObservation):
|
||||||
working_memory = observation.get_observe_info()
|
working_memory = observation.get_observe_info()
|
||||||
working_memory_obs = observation
|
# working_memory_obs = observation
|
||||||
if isinstance(observation, ChattingObservation):
|
if isinstance(observation, ChattingObservation):
|
||||||
chat_info = observation.get_observe_info()
|
chat_info = observation.get_observe_info()
|
||||||
# chat_info_truncate = observation.talking_message_str_truncate
|
# chat_info_truncate = observation.talking_message_str_truncate
|
||||||
@@ -112,11 +104,11 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
all_memory = working_memory.get_all_memories()
|
all_memory = working_memory.get_all_memories()
|
||||||
memory_prompts = []
|
memory_prompts = []
|
||||||
for memory in all_memory:
|
for memory in all_memory:
|
||||||
memory_content = memory.data
|
# memory_content = memory.data
|
||||||
memory_summary = memory.summary
|
memory_summary = memory.summary
|
||||||
memory_id = memory.id
|
memory_id = memory.id
|
||||||
memory_brief = memory_summary.get("brief")
|
memory_brief = memory_summary.get("brief")
|
||||||
memory_detailed = memory_summary.get("detailed")
|
# memory_detailed = memory_summary.get("detailed")
|
||||||
memory_keypoints = memory_summary.get("keypoints")
|
memory_keypoints = memory_summary.get("keypoints")
|
||||||
memory_events = memory_summary.get("events")
|
memory_events = memory_summary.get("events")
|
||||||
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
memory_single_prompt = f"记忆id:{memory_id},记忆摘要:{memory_brief}\n"
|
||||||
@@ -129,16 +121,14 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
bot_name=global_config.BOT_NICKNAME,
|
bot_name=global_config.BOT_NICKNAME,
|
||||||
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
time_now=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
|
||||||
chat_observe_info=chat_info,
|
chat_observe_info=chat_info,
|
||||||
memory_str=memory_choose_str
|
memory_str=memory_choose_str,
|
||||||
)
|
)
|
||||||
|
|
||||||
# 调用LLM处理记忆
|
# 调用LLM处理记忆
|
||||||
content = ""
|
content = ""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
|
logger.debug(f"{self.log_prefix} 处理工作记忆的prompt: {prompt}")
|
||||||
|
|
||||||
|
|
||||||
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
content, _ = await self.llm_model.generate_response_async(prompt=prompt)
|
||||||
if not content:
|
if not content:
|
||||||
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
logger.warning(f"{self.log_prefix} LLM返回空结果,处理工作记忆失败。")
|
||||||
@@ -171,11 +161,11 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
for memory_id in selected_memory_ids:
|
for memory_id in selected_memory_ids:
|
||||||
memory = await working_memory.retrieve_memory(memory_id)
|
memory = await working_memory.retrieve_memory(memory_id)
|
||||||
if memory:
|
if memory:
|
||||||
memory_content = memory.data
|
# memory_content = memory.data
|
||||||
memory_summary = memory.summary
|
memory_summary = memory.summary
|
||||||
memory_id = memory.id
|
memory_id = memory.id
|
||||||
memory_brief = memory_summary.get("brief")
|
memory_brief = memory_summary.get("brief")
|
||||||
memory_detailed = memory_summary.get("detailed")
|
# memory_detailed = memory_summary.get("detailed")
|
||||||
memory_keypoints = memory_summary.get("keypoints")
|
memory_keypoints = memory_summary.get("keypoints")
|
||||||
memory_events = memory_summary.get("events")
|
memory_events = memory_summary.get("events")
|
||||||
for keypoint in memory_keypoints:
|
for keypoint in memory_keypoints:
|
||||||
@@ -185,7 +175,6 @@ class WorkingMemoryProcessor(BaseProcessor):
|
|||||||
# memory_str += f"记忆摘要:{memory_detailed}\n"
|
# memory_str += f"记忆摘要:{memory_detailed}\n"
|
||||||
# memory_str += f"记忆主题:{memory_brief}\n"
|
# memory_str += f"记忆主题:{memory_brief}\n"
|
||||||
|
|
||||||
|
|
||||||
working_memory_info = WorkingMemoryInfo()
|
working_memory_info = WorkingMemoryInfo()
|
||||||
if memory_str:
|
if memory_str:
|
||||||
working_memory_info.add_working_memory(memory_str)
|
working_memory_info.add_working_memory(memory_str)
|
||||||
|
|||||||
@@ -80,7 +80,7 @@ class ActionManager:
|
|||||||
# logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
|
# logger.info(f"所有注册动作: {list(self._registered_actions.keys())}")
|
||||||
# logger.info(f"默认动作: {list(self._default_actions.keys())}")
|
# logger.info(f"默认动作: {list(self._default_actions.keys())}")
|
||||||
# for action_name, action_info in self._default_actions.items():
|
# for action_name, action_info in self._default_actions.items():
|
||||||
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
# logger.info(f"动作名称: {action_name}, 动作信息: {action_info}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"加载已注册动作失败: {e}")
|
logger.error(f"加载已注册动作失败: {e}")
|
||||||
@@ -92,7 +92,7 @@ class ActionManager:
|
|||||||
try:
|
try:
|
||||||
# 检查插件目录是否存在
|
# 检查插件目录是否存在
|
||||||
plugin_path = "src.plugins"
|
plugin_path = "src.plugins"
|
||||||
plugin_dir = plugin_path.replace('.', os.path.sep)
|
plugin_dir = plugin_path.replace(".", os.path.sep)
|
||||||
if not os.path.exists(plugin_dir):
|
if not os.path.exists(plugin_dir):
|
||||||
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
|
logger.info(f"插件目录 {plugin_dir} 不存在,跳过插件动作加载")
|
||||||
return
|
return
|
||||||
@@ -105,7 +105,9 @@ class ActionManager:
|
|||||||
return
|
return
|
||||||
|
|
||||||
# 遍历插件包中的所有子包
|
# 遍历插件包中的所有子包
|
||||||
for _, plugin_name, is_pkg in pkgutil.iter_modules(plugins_package.__path__, plugins_package.__name__ + '.'):
|
for _, plugin_name, is_pkg in pkgutil.iter_modules(
|
||||||
|
plugins_package.__path__, plugins_package.__name__ + "."
|
||||||
|
):
|
||||||
if not is_pkg:
|
if not is_pkg:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -316,4 +318,3 @@ class ActionManager:
|
|||||||
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
Optional[Type[BaseAction]]: 动作处理器类,如果不存在则返回None
|
||||||
"""
|
"""
|
||||||
return _ACTION_REGISTRY.get(action_name)
|
return _ACTION_REGISTRY.get(action_name)
|
||||||
|
|
||||||
|
|||||||
@@ -94,8 +94,7 @@ class NoReplyAction(BaseAction):
|
|||||||
# 等待新消息、超时或关闭信号,并获取结果
|
# 等待新消息、超时或关闭信号,并获取结果
|
||||||
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
|
await self._wait_for_new_message(observation, self.thinking_id, self.log_prefix)
|
||||||
# 从计时器获取实际等待时间
|
# 从计时器获取实际等待时间
|
||||||
current_waiting = self.cycle_timers.get("等待新消息", 0.0)
|
_current_waiting = self.cycle_timers.get("等待新消息", 0.0)
|
||||||
|
|
||||||
|
|
||||||
return True, "" # 不回复动作没有回复文本
|
return True, "" # 不回复动作没有回复文本
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import traceback
|
import traceback
|
||||||
from typing import Tuple, Dict, List, Any, Optional
|
from typing import Tuple, Dict, List, Any, Optional
|
||||||
from src.chat.focus_chat.planners.actions.base_action import BaseAction, register_action
|
from src.chat.focus_chat.planners.actions.base_action import BaseAction
|
||||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||||
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
from src.chat.focus_chat.hfc_utils import create_empty_anchor_message
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
@@ -9,6 +9,7 @@ from abc import abstractmethod
|
|||||||
|
|
||||||
logger = get_logger("plugin_action")
|
logger = get_logger("plugin_action")
|
||||||
|
|
||||||
|
|
||||||
class PluginAction(BaseAction):
|
class PluginAction(BaseAction):
|
||||||
"""插件动作基类
|
"""插件动作基类
|
||||||
|
|
||||||
@@ -61,18 +62,13 @@ class PluginAction(BaseAction):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# 构造简化的动作数据
|
# 构造简化的动作数据
|
||||||
reply_data = {
|
reply_data = {"text": text, "target": target or "", "emojis": []}
|
||||||
"text": text,
|
|
||||||
"target": target or "",
|
|
||||||
"emojis": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取锚定消息(如果有)
|
# 获取锚定消息(如果有)
|
||||||
observations = self._services.get("observations", [])
|
observations = self._services.get("observations", [])
|
||||||
|
|
||||||
chatting_observation: ChattingObservation = next(
|
chatting_observation: ChattingObservation = next(
|
||||||
obs for obs in observations
|
obs for obs in observations if isinstance(obs, ChattingObservation)
|
||||||
if isinstance(obs, ChattingObservation)
|
|
||||||
)
|
)
|
||||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||||
|
|
||||||
@@ -101,7 +97,6 @@ class PluginAction(BaseAction):
|
|||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
|
async def send_message_by_expressor(self, text: str, target: Optional[str] = None) -> bool:
|
||||||
"""发送消息的简化方法
|
"""发送消息的简化方法
|
||||||
|
|
||||||
@@ -121,18 +116,13 @@ class PluginAction(BaseAction):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
# 构造简化的动作数据
|
# 构造简化的动作数据
|
||||||
reply_data = {
|
reply_data = {"text": text, "target": target or "", "emojis": []}
|
||||||
"text": text,
|
|
||||||
"target": target or "",
|
|
||||||
"emojis": []
|
|
||||||
}
|
|
||||||
|
|
||||||
# 获取锚定消息(如果有)
|
# 获取锚定消息(如果有)
|
||||||
observations = self._services.get("observations", [])
|
observations = self._services.get("observations", [])
|
||||||
|
|
||||||
chatting_observation: ChattingObservation = next(
|
chatting_observation: ChattingObservation = next(
|
||||||
obs for obs in observations
|
obs for obs in observations if isinstance(obs, ChattingObservation)
|
||||||
if isinstance(obs, ChattingObservation)
|
|
||||||
)
|
)
|
||||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||||
|
|
||||||
@@ -151,7 +141,7 @@ class PluginAction(BaseAction):
|
|||||||
action_data=reply_data,
|
action_data=reply_data,
|
||||||
anchor_message=anchor_message,
|
anchor_message=anchor_message,
|
||||||
reasoning=self.reasoning,
|
reasoning=self.reasoning,
|
||||||
thinking_id=self.thinking_id
|
thinking_id=self.thinking_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
return success
|
return success
|
||||||
@@ -191,7 +181,7 @@ class PluginAction(BaseAction):
|
|||||||
simple_msg = {
|
simple_msg = {
|
||||||
"sender": msg.get("sender", "未知"),
|
"sender": msg.get("sender", "未知"),
|
||||||
"content": msg.get("content", ""),
|
"content": msg.get("content", ""),
|
||||||
"timestamp": msg.get("timestamp", 0)
|
"timestamp": msg.get("timestamp", 0),
|
||||||
}
|
}
|
||||||
messages.append(simple_msg)
|
messages.append(simple_msg)
|
||||||
|
|
||||||
|
|||||||
@@ -105,8 +105,7 @@ class ReplyAction(BaseAction):
|
|||||||
|
|
||||||
# 从聊天观察获取锚定消息
|
# 从聊天观察获取锚定消息
|
||||||
chatting_observation: ChattingObservation = next(
|
chatting_observation: ChattingObservation = next(
|
||||||
obs for obs in self.observations
|
obs for obs in self.observations if isinstance(obs, ChattingObservation)
|
||||||
if isinstance(obs, ChattingObservation)
|
|
||||||
)
|
)
|
||||||
if reply_data.get("target"):
|
if reply_data.get("target"):
|
||||||
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
anchor_message = chatting_observation.search_message_by_text(reply_data["target"])
|
||||||
|
|||||||
@@ -109,7 +109,7 @@ class ActionPlanner:
|
|||||||
cycle_info = info.get_observe_info()
|
cycle_info = info.get_observe_info()
|
||||||
elif isinstance(info, StructuredInfo):
|
elif isinstance(info, StructuredInfo):
|
||||||
# logger.debug(f"{self.log_prefix} 结构化信息: {info}")
|
# logger.debug(f"{self.log_prefix} 结构化信息: {info}")
|
||||||
structured_info = info.get_data()
|
_structured_info = info.get_data()
|
||||||
else:
|
else:
|
||||||
logger.debug(f"{self.log_prefix} 其他信息: {info}")
|
logger.debug(f"{self.log_prefix} 其他信息: {info}")
|
||||||
extra_info.append(info.get_processed_info())
|
extra_info.append(info.get_processed_info())
|
||||||
|
|||||||
@@ -1,14 +1,7 @@
|
|||||||
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
|
from typing import Dict, Any, List, Optional, Set, Tuple
|
||||||
import time
|
import time
|
||||||
import uuid
|
|
||||||
import traceback
|
|
||||||
import random
|
import random
|
||||||
import string
|
import string
|
||||||
from json_repair import repair_json
|
|
||||||
from rich.traceback import install
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
from src.chat.models.utils_model import LLMRequest
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryItem:
|
class MemoryItem:
|
||||||
@@ -25,7 +18,7 @@ class MemoryItem:
|
|||||||
"""
|
"""
|
||||||
# 生成可读ID:时间戳_随机字符串
|
# 生成可读ID:时间戳_随机字符串
|
||||||
timestamp = int(time.time())
|
timestamp = int(time.time())
|
||||||
random_str = ''.join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
random_str = "".join(random.choices(string.ascii_lowercase + string.digits, k=2))
|
||||||
self.id = f"{timestamp}_{random_str}"
|
self.id = f"{timestamp}_{random_str}"
|
||||||
self.data = data
|
self.data = data
|
||||||
self.data_type = type(data)
|
self.data_type = type(data)
|
||||||
|
|||||||
@@ -1,6 +1,4 @@
|
|||||||
from typing import Dict, Any, Type, TypeVar, Generic, List, Optional, Callable, Set, Tuple
|
from typing import Dict, Any, Type, TypeVar, List, Optional
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
import traceback
|
import traceback
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -14,7 +12,7 @@ import json # 添加json模块导入
|
|||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
logger = get_logger("working_memory")
|
logger = get_logger("working_memory")
|
||||||
|
|
||||||
T = TypeVar('T')
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class MemoryManager:
|
class MemoryManager:
|
||||||
@@ -35,10 +33,7 @@ class MemoryManager:
|
|||||||
self._id_map: Dict[str, MemoryItem] = {}
|
self._id_map: Dict[str, MemoryItem] = {}
|
||||||
|
|
||||||
self.llm_summarizer = LLMRequest(
|
self.llm_summarizer = LLMRequest(
|
||||||
model=global_config.llm_summary,
|
model=global_config.llm_summary, temperature=0.3, max_tokens=512, request_type="memory_summarization"
|
||||||
temperature=0.3,
|
|
||||||
max_tokens=512,
|
|
||||||
request_type="memory_summarization"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -122,7 +117,6 @@ class MemoryManager:
|
|||||||
"""
|
"""
|
||||||
memory_item = self._id_map.get(memory_id)
|
memory_item = self._id_map.get(memory_id)
|
||||||
if memory_item:
|
if memory_item:
|
||||||
|
|
||||||
# 检查记忆强度,如果小于1则删除
|
# 检查记忆强度,如果小于1则删除
|
||||||
if not memory_item.is_memory_valid():
|
if not memory_item.is_memory_valid():
|
||||||
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
print(f"记忆 {memory_id} 强度过低 ({memory_item.memory_strength}),已自动移除")
|
||||||
@@ -135,16 +129,18 @@ class MemoryManager:
|
|||||||
"""获取所有记忆项"""
|
"""获取所有记忆项"""
|
||||||
return list(self._id_map.values())
|
return list(self._id_map.values())
|
||||||
|
|
||||||
def find_items(self,
|
def find_items(
|
||||||
data_type: Optional[Type] = None,
|
self,
|
||||||
source: Optional[str] = None,
|
data_type: Optional[Type] = None,
|
||||||
tags: Optional[List[str]] = None,
|
source: Optional[str] = None,
|
||||||
start_time: Optional[float] = None,
|
tags: Optional[List[str]] = None,
|
||||||
end_time: Optional[float] = None,
|
start_time: Optional[float] = None,
|
||||||
memory_id: Optional[str] = None,
|
end_time: Optional[float] = None,
|
||||||
limit: Optional[int] = None,
|
memory_id: Optional[str] = None,
|
||||||
newest_first: bool = False,
|
limit: Optional[int] = None,
|
||||||
min_strength: float = 0.0) -> List[MemoryItem]:
|
newest_first: bool = False,
|
||||||
|
min_strength: float = 0.0,
|
||||||
|
) -> List[MemoryItem]:
|
||||||
"""
|
"""
|
||||||
按条件查找记忆项
|
按条件查找记忆项
|
||||||
|
|
||||||
@@ -257,7 +253,7 @@ class MemoryManager:
|
|||||||
"brief": "主题未知的记忆",
|
"brief": "主题未知的记忆",
|
||||||
"detailed": "大致内容未知的记忆",
|
"detailed": "大致内容未知的记忆",
|
||||||
"keypoints": ["未知的概念"],
|
"keypoints": ["未知的概念"],
|
||||||
"events": ["未知的事件"]
|
"events": ["未知的事件"],
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -297,10 +293,7 @@ class MemoryManager:
|
|||||||
json_result["keypoints"] = ["未知的概念"]
|
json_result["keypoints"] = ["未知的概念"]
|
||||||
else:
|
else:
|
||||||
# 确保keypoints中的每个项目都是字符串
|
# 确保keypoints中的每个项目都是字符串
|
||||||
json_result["keypoints"] = [
|
json_result["keypoints"] = [str(point) for point in json_result["keypoints"] if point is not None]
|
||||||
str(point) for point in json_result["keypoints"]
|
|
||||||
if point is not None
|
|
||||||
]
|
|
||||||
if not json_result["keypoints"]:
|
if not json_result["keypoints"]:
|
||||||
json_result["keypoints"] = ["未知的概念"]
|
json_result["keypoints"] = ["未知的概念"]
|
||||||
|
|
||||||
@@ -309,10 +302,7 @@ class MemoryManager:
|
|||||||
json_result["events"] = ["未知的事件"]
|
json_result["events"] = ["未知的事件"]
|
||||||
else:
|
else:
|
||||||
# 确保events中的每个项目都是字符串
|
# 确保events中的每个项目都是字符串
|
||||||
json_result["events"] = [
|
json_result["events"] = [str(event) for event in json_result["events"] if event is not None]
|
||||||
str(event) for event in json_result["events"]
|
|
||||||
if event is not None
|
|
||||||
]
|
|
||||||
if not json_result["events"]:
|
if not json_result["events"]:
|
||||||
json_result["events"] = ["未知的事件"]
|
json_result["events"] = ["未知的事件"]
|
||||||
|
|
||||||
@@ -331,9 +321,7 @@ class MemoryManager:
|
|||||||
logger.error(f"生成总结时出错: {str(e)}")
|
logger.error(f"生成总结时出错: {str(e)}")
|
||||||
return default_summary
|
return default_summary
|
||||||
|
|
||||||
async def refine_memory(self,
|
async def refine_memory(self, memory_id: str, requirements: str = "") -> Dict[str, Any]:
|
||||||
memory_id: str,
|
|
||||||
requirements: str = "") -> Dict[str, Any]:
|
|
||||||
"""
|
"""
|
||||||
对记忆进行精简操作,根据要求修改要点、总结和概括
|
对记忆进行精简操作,根据要求修改要点、总结和概括
|
||||||
|
|
||||||
@@ -402,7 +390,7 @@ class MemoryManager:
|
|||||||
"brief": summary["brief"],
|
"brief": summary["brief"],
|
||||||
"detailed": summary["detailed"],
|
"detailed": summary["detailed"],
|
||||||
"keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
|
"keypoints": summary.get("keypoints", ["未知的概念"])[:1], # 默认只保留第一个关键概念
|
||||||
"events": summary.get("events", ["未知的事件"])[:1] # 默认只保留第一个事件
|
"events": summary.get("events", ["未知的事件"])[:1], # 默认只保留第一个事件
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -498,7 +486,6 @@ class MemoryManager:
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def delete(self, memory_id: str) -> bool:
|
def delete(self, memory_id: str) -> bool:
|
||||||
"""
|
"""
|
||||||
删除指定ID的记忆项
|
删除指定ID的记忆项
|
||||||
@@ -543,7 +530,9 @@ class MemoryManager:
|
|||||||
del self._id_map[item.id]
|
del self._id_map[item.id]
|
||||||
del self._memory[data_type]
|
del self._memory[data_type]
|
||||||
|
|
||||||
async def merge_memories(self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True) -> MemoryItem:
|
async def merge_memories(
|
||||||
|
self, memory_id1: str, memory_id2: str, reason: str, delete_originals: bool = True
|
||||||
|
) -> MemoryItem:
|
||||||
"""
|
"""
|
||||||
合并两个记忆项
|
合并两个记忆项
|
||||||
|
|
||||||
@@ -584,24 +573,24 @@ class MemoryManager:
|
|||||||
prompt += f"记忆1概括:{summary1['detailed']}\n"
|
prompt += f"记忆1概括:{summary1['detailed']}\n"
|
||||||
|
|
||||||
if "keypoints" in summary1:
|
if "keypoints" in summary1:
|
||||||
prompt += f"记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1['keypoints']]) + "\n\n"
|
prompt += "记忆1关键概念:\n" + "\n".join([f"- {point}" for point in summary1["keypoints"]]) + "\n\n"
|
||||||
|
|
||||||
if "events" in summary1:
|
if "events" in summary1:
|
||||||
prompt += f"记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1['events']]) + "\n\n"
|
prompt += "记忆1事件:\n" + "\n".join([f"- {point}" for point in summary1["events"]]) + "\n\n"
|
||||||
elif "key_points" in summary1:
|
elif "key_points" in summary1:
|
||||||
prompt += f"记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1['key_points']]) + "\n\n"
|
prompt += "记忆1要点:\n" + "\n".join([f"- {point}" for point in summary1["key_points"]]) + "\n\n"
|
||||||
|
|
||||||
if summary2:
|
if summary2:
|
||||||
prompt += f"记忆2主题:{summary2['brief']}\n"
|
prompt += f"记忆2主题:{summary2['brief']}\n"
|
||||||
prompt += f"记忆2概括:{summary2['detailed']}\n"
|
prompt += f"记忆2概括:{summary2['detailed']}\n"
|
||||||
|
|
||||||
if "keypoints" in summary2:
|
if "keypoints" in summary2:
|
||||||
prompt += f"记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2['keypoints']]) + "\n\n"
|
prompt += "记忆2关键概念:\n" + "\n".join([f"- {point}" for point in summary2["keypoints"]]) + "\n\n"
|
||||||
|
|
||||||
if "events" in summary2:
|
if "events" in summary2:
|
||||||
prompt += f"记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2['events']]) + "\n\n"
|
prompt += "记忆2事件:\n" + "\n".join([f"- {point}" for point in summary2["events"]]) + "\n\n"
|
||||||
elif "key_points" in summary2:
|
elif "key_points" in summary2:
|
||||||
prompt += f"记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2['key_points']]) + "\n\n"
|
prompt += "记忆2要点:\n" + "\n".join([f"- {point}" for point in summary2["key_points"]]) + "\n\n"
|
||||||
|
|
||||||
# 添加记忆原始内容
|
# 添加记忆原始内容
|
||||||
prompt += f"""
|
prompt += f"""
|
||||||
@@ -637,7 +626,7 @@ class MemoryManager:
|
|||||||
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
|
"brief": f"合并:{summary1['brief']} + {summary2['brief']}",
|
||||||
"detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
|
"detailed": f"合并了两个记忆:{summary1['detailed']} 以及 {summary2['detailed']}",
|
||||||
"keypoints": [],
|
"keypoints": [],
|
||||||
"events": []
|
"events": [],
|
||||||
}
|
}
|
||||||
|
|
||||||
# 合并旧版key_points
|
# 合并旧版key_points
|
||||||
@@ -710,10 +699,7 @@ class MemoryManager:
|
|||||||
merged_data["keypoints"] = default_merged["keypoints"]
|
merged_data["keypoints"] = default_merged["keypoints"]
|
||||||
else:
|
else:
|
||||||
# 确保keypoints中的每个项目都是字符串
|
# 确保keypoints中的每个项目都是字符串
|
||||||
merged_data["keypoints"] = [
|
merged_data["keypoints"] = [str(point) for point in merged_data["keypoints"] if point is not None]
|
||||||
str(point) for point in merged_data["keypoints"]
|
|
||||||
if point is not None
|
|
||||||
]
|
|
||||||
if not merged_data["keypoints"]:
|
if not merged_data["keypoints"]:
|
||||||
merged_data["keypoints"] = ["合并的关键概念"]
|
merged_data["keypoints"] = ["合并的关键概念"]
|
||||||
|
|
||||||
@@ -722,10 +708,7 @@ class MemoryManager:
|
|||||||
merged_data["events"] = default_merged["events"]
|
merged_data["events"] = default_merged["events"]
|
||||||
else:
|
else:
|
||||||
# 确保events中的每个项目都是字符串
|
# 确保events中的每个项目都是字符串
|
||||||
merged_data["events"] = [
|
merged_data["events"] = [str(event) for event in merged_data["events"] if event is not None]
|
||||||
str(event) for event in merged_data["events"]
|
|
||||||
if event is not None
|
|
||||||
]
|
|
||||||
if not merged_data["events"]:
|
if not merged_data["events"]:
|
||||||
merged_data["events"] = ["合并的事件"]
|
merged_data["events"] = ["合并的事件"]
|
||||||
|
|
||||||
@@ -746,14 +729,14 @@ class MemoryManager:
|
|||||||
merged_tags = memory_item1.tags.union(memory_item2.tags)
|
merged_tags = memory_item1.tags.union(memory_item2.tags)
|
||||||
|
|
||||||
# 取两个记忆项中更强的来源
|
# 取两个记忆项中更强的来源
|
||||||
merged_source = memory_item1.from_source if memory_item1.memory_strength >= memory_item2.memory_strength else memory_item2.from_source
|
merged_source = (
|
||||||
|
memory_item1.from_source
|
||||||
|
if memory_item1.memory_strength >= memory_item2.memory_strength
|
||||||
|
else memory_item2.from_source
|
||||||
|
)
|
||||||
|
|
||||||
# 创建新的记忆项
|
# 创建新的记忆项
|
||||||
merged_memory = MemoryItem(
|
merged_memory = MemoryItem(data=merged_data["content"], from_source=merged_source, tags=list(merged_tags))
|
||||||
data=merged_data["content"],
|
|
||||||
from_source=merged_source,
|
|
||||||
tags=list(merged_tags)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 设置合并后的摘要
|
# 设置合并后的摘要
|
||||||
summary = {
|
summary = {
|
||||||
@@ -761,7 +744,7 @@ class MemoryManager:
|
|||||||
"detailed": merged_data["detailed"],
|
"detailed": merged_data["detailed"],
|
||||||
"keypoints": merged_data["keypoints"],
|
"keypoints": merged_data["keypoints"],
|
||||||
"events": merged_data["events"],
|
"events": merged_data["events"],
|
||||||
"key_points": merged_data["key_points"]
|
"key_points": merged_data["key_points"],
|
||||||
}
|
}
|
||||||
merged_memory.set_summary(summary)
|
merged_memory.set_summary(summary)
|
||||||
|
|
||||||
|
|||||||
@@ -1,169 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import os
|
|
||||||
import asyncio
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
|
||||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("memory_loader")
|
|
||||||
|
|
||||||
class MemoryFileLoader:
|
|
||||||
"""从文件加载记忆内容的工具类"""
|
|
||||||
|
|
||||||
def __init__(self, working_memory: WorkingMemory):
|
|
||||||
"""
|
|
||||||
初始化记忆文件加载器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
working_memory: 工作记忆实例
|
|
||||||
"""
|
|
||||||
self.working_memory = working_memory
|
|
||||||
|
|
||||||
async def load_from_directory(self,
|
|
||||||
directory_path: str,
|
|
||||||
file_pattern: str = "*.txt",
|
|
||||||
common_tags: List[str] = None,
|
|
||||||
source_prefix: str = "文件") -> List[MemoryItem]:
|
|
||||||
"""
|
|
||||||
从指定目录加载符合模式的文件作为记忆
|
|
||||||
|
|
||||||
Args:
|
|
||||||
directory_path: 目录路径
|
|
||||||
file_pattern: 文件模式(默认为*.txt)
|
|
||||||
common_tags: 所有记忆共有的标签
|
|
||||||
source_prefix: 来源前缀
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
加载的记忆项列表
|
|
||||||
"""
|
|
||||||
directory = Path(directory_path)
|
|
||||||
if not directory.exists() or not directory.is_dir():
|
|
||||||
logger.error(f"目录不存在或不是有效目录: {directory_path}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 获取文件列表
|
|
||||||
files = list(directory.glob(file_pattern))
|
|
||||||
if not files:
|
|
||||||
logger.warning(f"在目录 {directory_path} 中没有找到符合 {file_pattern} 的文件")
|
|
||||||
return []
|
|
||||||
|
|
||||||
logger.info(f"在目录 {directory_path} 中找到 {len(files)} 个符合条件的文件")
|
|
||||||
|
|
||||||
# 加载文件内容为记忆
|
|
||||||
loaded_memories = []
|
|
||||||
for file_path in files:
|
|
||||||
try:
|
|
||||||
memory_item = await self._load_single_file(
|
|
||||||
file_path=str(file_path),
|
|
||||||
common_tags=common_tags,
|
|
||||||
source_prefix=source_prefix
|
|
||||||
)
|
|
||||||
if memory_item:
|
|
||||||
loaded_memories.append(memory_item)
|
|
||||||
logger.info(f"成功加载记忆: {file_path.name}")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
|
|
||||||
|
|
||||||
logger.info(f"完成加载,共加载了 {len(loaded_memories)} 个记忆")
|
|
||||||
return loaded_memories
|
|
||||||
|
|
||||||
async def _load_single_file(self,
|
|
||||||
file_path: str,
|
|
||||||
common_tags: Optional[List[str]] = None,
|
|
||||||
source_prefix: str = "文件") -> Optional[MemoryItem]:
|
|
||||||
"""
|
|
||||||
加载单个文件作为记忆
|
|
||||||
|
|
||||||
Args:
|
|
||||||
file_path: 文件路径
|
|
||||||
common_tags: 记忆共有的标签
|
|
||||||
source_prefix: 来源前缀
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
记忆项,加载失败则返回None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 读取文件内容
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
if not content.strip():
|
|
||||||
logger.warning(f"文件 {file_path} 内容为空")
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 准备标签和来源
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
tags = list(common_tags) if common_tags else []
|
|
||||||
tags.append(file_name) # 添加文件名作为标签
|
|
||||||
|
|
||||||
source = f"{source_prefix}_{file_name}"
|
|
||||||
|
|
||||||
# 添加到工作记忆
|
|
||||||
memory = await self.working_memory.add_memory(
|
|
||||||
content=content,
|
|
||||||
from_source=source,
|
|
||||||
tags=tags
|
|
||||||
)
|
|
||||||
|
|
||||||
return memory
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"加载文件 {file_path} 失败: {str(e)}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""示例使用"""
|
|
||||||
# 初始化工作记忆
|
|
||||||
chat_id = "demo_chat"
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 初始化加载器
|
|
||||||
loader = MemoryFileLoader(working_memory)
|
|
||||||
|
|
||||||
# 加载当前目录中的txt文件
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
memories = await loader.load_from_directory(
|
|
||||||
directory_path=str(current_dir),
|
|
||||||
file_pattern="*.txt",
|
|
||||||
common_tags=["测试数据", "自动加载"],
|
|
||||||
source_prefix="测试文件"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 显示加载结果
|
|
||||||
print(f"共加载了 {len(memories)} 个记忆")
|
|
||||||
|
|
||||||
# 获取并显示所有记忆的概要
|
|
||||||
all_memories = working_memory.memory_manager.get_all_items()
|
|
||||||
for memory in all_memories:
|
|
||||||
print("\n" + "=" * 40)
|
|
||||||
print(f"记忆ID: {memory.id}")
|
|
||||||
print(f"来源: {memory.from_source}")
|
|
||||||
print(f"标签: {', '.join(memory.tags)}")
|
|
||||||
|
|
||||||
if memory.summary:
|
|
||||||
print(f"\n主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
print(f"概述: {memory.summary.get('detailed', '无概述')}")
|
|
||||||
print("\n要点:")
|
|
||||||
for point in memory.summary.get('key_points', []):
|
|
||||||
print(f"- {point}")
|
|
||||||
else:
|
|
||||||
print("\n无摘要信息")
|
|
||||||
|
|
||||||
print("=" * 40)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 关闭工作记忆
|
|
||||||
await working_memory.shutdown()
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行示例
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 添加项目根目录到系统路径
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
project_root = current_dir.parent.parent.parent.parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
|
||||||
|
|
||||||
async def test_load_memories_from_files():
|
|
||||||
"""测试从文件加载记忆的功能"""
|
|
||||||
print("开始测试从文件加载记忆...")
|
|
||||||
|
|
||||||
# 初始化工作记忆
|
|
||||||
chat_id = "test_memory_load"
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取测试文件列表
|
|
||||||
test_dir = Path(__file__).parent
|
|
||||||
test_files = [
|
|
||||||
os.path.join(test_dir, f)
|
|
||||||
for f in os.listdir(test_dir)
|
|
||||||
if f.endswith(".txt")
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"找到 {len(test_files)} 个测试文件")
|
|
||||||
|
|
||||||
# 从每个文件加载记忆
|
|
||||||
for file_path in test_files:
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
print(f"从文件 {file_name} 加载记忆...")
|
|
||||||
|
|
||||||
# 读取文件内容
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# 添加记忆
|
|
||||||
memory = await working_memory.add_memory(
|
|
||||||
content=content,
|
|
||||||
from_source=f"文件_{file_name}",
|
|
||||||
tags=["测试文件", file_name]
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"已添加记忆: ID={memory.id}")
|
|
||||||
if memory.summary:
|
|
||||||
print(f"记忆概要: {memory.summary.get('brief', '无概要')}")
|
|
||||||
print(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点']))}")
|
|
||||||
print("-" * 50)
|
|
||||||
|
|
||||||
# 获取所有记忆
|
|
||||||
all_memories = working_memory.memory_manager.get_all_items()
|
|
||||||
print(f"\n成功加载 {len(all_memories)} 个记忆")
|
|
||||||
|
|
||||||
# 测试检索记忆
|
|
||||||
if all_memories:
|
|
||||||
print("\n测试检索第一个记忆...")
|
|
||||||
first_memory = all_memories[0]
|
|
||||||
retrieved = await working_memory.retrieve_memory(first_memory.id)
|
|
||||||
|
|
||||||
if retrieved:
|
|
||||||
print(f"成功检索记忆: ID={retrieved.id}")
|
|
||||||
print(f"检索后强度: {retrieved.memory_strength} (初始为10.0)")
|
|
||||||
print(f"检索次数: {retrieved.retrieval_count}")
|
|
||||||
else:
|
|
||||||
print("检索失败")
|
|
||||||
|
|
||||||
# 测试记忆衰减
|
|
||||||
print("\n测试记忆衰减...")
|
|
||||||
for memory in all_memories:
|
|
||||||
print(f"记忆 {memory.id} 衰减前强度: {memory.memory_strength}")
|
|
||||||
|
|
||||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
|
||||||
|
|
||||||
all_memories_after = working_memory.memory_manager.get_all_items()
|
|
||||||
for memory in all_memories_after:
|
|
||||||
print(f"记忆 {memory.id} 衰减后强度: {memory.memory_strength}")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
# 关闭工作记忆
|
|
||||||
await working_memory.shutdown()
|
|
||||||
print("\n测试完成,已关闭工作记忆")
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
# 运行测试
|
|
||||||
asyncio.run(test_load_memories_from_files())
|
|
||||||
@@ -1,197 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import random
|
|
||||||
from pathlib import Path
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# 添加项目根目录到系统路径
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
project_root = current_dir.parent.parent.parent.parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
|
||||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("real_usage_simulation")
|
|
||||||
|
|
||||||
class WorkingMemorySimulator:
|
|
||||||
"""模拟工作记忆的真实使用场景"""
|
|
||||||
|
|
||||||
def __init__(self, chat_id="real_usage_test", cycle_interval=20):
|
|
||||||
"""
|
|
||||||
初始化模拟器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_id: 聊天ID
|
|
||||||
cycle_interval: 循环间隔时间(秒)
|
|
||||||
"""
|
|
||||||
self.chat_id = chat_id
|
|
||||||
self.cycle_interval = cycle_interval
|
|
||||||
self.working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=20, auto_decay_interval=60)
|
|
||||||
self.cycle_count = 0
|
|
||||||
self.running = False
|
|
||||||
|
|
||||||
# 获取测试文件路径
|
|
||||||
self.test_files = self._get_test_files()
|
|
||||||
if not self.test_files:
|
|
||||||
raise FileNotFoundError("找不到测试文件,请确保test目录中有.txt文件")
|
|
||||||
|
|
||||||
# 存储所有添加的记忆ID
|
|
||||||
self.memory_ids = []
|
|
||||||
|
|
||||||
async def start(self, total_cycles=5):
|
|
||||||
"""
|
|
||||||
开始模拟循环
|
|
||||||
|
|
||||||
Args:
|
|
||||||
total_cycles: 总循环次数,设为None表示无限循环
|
|
||||||
"""
|
|
||||||
self.running = True
|
|
||||||
logger.info(f"开始模拟真实使用场景,循环间隔: {self.cycle_interval}秒")
|
|
||||||
|
|
||||||
try:
|
|
||||||
while self.running and (total_cycles is None or self.cycle_count < total_cycles):
|
|
||||||
self.cycle_count += 1
|
|
||||||
logger.info(f"\n===== 开始第 {self.cycle_count} 次循环 =====")
|
|
||||||
|
|
||||||
# 执行一次循环
|
|
||||||
await self._run_one_cycle()
|
|
||||||
|
|
||||||
# 如果还有更多循环,则等待
|
|
||||||
if self.running and (total_cycles is None or self.cycle_count < total_cycles):
|
|
||||||
wait_time = self.cycle_interval
|
|
||||||
logger.info(f"等待 {wait_time} 秒后开始下一循环...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
|
|
||||||
logger.info(f"模拟完成,共执行了 {self.cycle_count} 次循环")
|
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
logger.info("接收到中断信号,停止模拟")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"模拟过程中出错: {str(e)}", exc_info=True)
|
|
||||||
finally:
|
|
||||||
# 关闭工作记忆
|
|
||||||
await self.working_memory.shutdown()
|
|
||||||
|
|
||||||
def stop(self):
|
|
||||||
"""停止模拟循环"""
|
|
||||||
self.running = False
|
|
||||||
logger.info("正在停止模拟...")
|
|
||||||
|
|
||||||
async def _run_one_cycle(self):
|
|
||||||
"""运行一次完整循环:先检索记忆,再添加新记忆"""
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
# 1. 先检索已有记忆(如果有)
|
|
||||||
await self._retrieve_memories()
|
|
||||||
|
|
||||||
# 2. 添加新记忆
|
|
||||||
await self._add_new_memory()
|
|
||||||
|
|
||||||
# 3. 显示工作记忆状态
|
|
||||||
await self._show_memory_status()
|
|
||||||
|
|
||||||
# 计算循环耗时
|
|
||||||
cycle_duration = time.time() - start_time
|
|
||||||
logger.info(f"第 {self.cycle_count} 次循环完成,耗时: {cycle_duration:.2f}秒")
|
|
||||||
|
|
||||||
async def _retrieve_memories(self):
|
|
||||||
"""检索现有记忆"""
|
|
||||||
# 如果有已保存的记忆ID,随机选择1-3个进行检索
|
|
||||||
if self.memory_ids:
|
|
||||||
num_to_retrieve = min(len(self.memory_ids), random.randint(1, 3))
|
|
||||||
retrieval_ids = random.sample(self.memory_ids, num_to_retrieve)
|
|
||||||
|
|
||||||
logger.info(f"正在检索 {num_to_retrieve} 条记忆...")
|
|
||||||
|
|
||||||
for memory_id in retrieval_ids:
|
|
||||||
memory = await self.working_memory.retrieve_memory(memory_id)
|
|
||||||
if memory:
|
|
||||||
logger.info(f"成功检索记忆 ID: {memory_id}")
|
|
||||||
logger.info(f" - 强度: {memory.memory_strength:.2f},检索次数: {memory.retrieval_count}")
|
|
||||||
if memory.summary:
|
|
||||||
logger.info(f" - 主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
else:
|
|
||||||
logger.warning(f"记忆 ID: {memory_id} 不存在或已被移除")
|
|
||||||
# 从ID列表中移除
|
|
||||||
if memory_id in self.memory_ids:
|
|
||||||
self.memory_ids.remove(memory_id)
|
|
||||||
else:
|
|
||||||
logger.info("当前没有可检索的记忆")
|
|
||||||
|
|
||||||
async def _add_new_memory(self):
|
|
||||||
"""添加新记忆"""
|
|
||||||
# 随机选择一个测试文件作为记忆内容
|
|
||||||
file_path = random.choice(self.test_files)
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 读取文件内容
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
content = f.read()
|
|
||||||
|
|
||||||
# 添加时间戳,模拟不同内容
|
|
||||||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
content_with_timestamp = f"[{timestamp}] {content}"
|
|
||||||
|
|
||||||
# 添加记忆
|
|
||||||
logger.info(f"正在添加新记忆,来源: {file_name}")
|
|
||||||
memory = await self.working_memory.add_memory(
|
|
||||||
content=content_with_timestamp,
|
|
||||||
from_source=f"模拟_{file_name}",
|
|
||||||
tags=["模拟测试", f"循环{self.cycle_count}", file_name]
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存记忆ID
|
|
||||||
self.memory_ids.append(memory.id)
|
|
||||||
|
|
||||||
# 显示记忆信息
|
|
||||||
logger.info(f"已添加新记忆 ID: {memory.id}")
|
|
||||||
if memory.summary:
|
|
||||||
logger.info(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
logger.info(f"记忆要点: {', '.join(memory.summary.get('key_points', ['无要点'])[:2])}...")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"添加记忆失败: {str(e)}")
|
|
||||||
|
|
||||||
async def _show_memory_status(self):
|
|
||||||
"""显示当前工作记忆状态"""
|
|
||||||
all_memories = self.working_memory.memory_manager.get_all_items()
|
|
||||||
|
|
||||||
logger.info(f"\n当前工作记忆状态:")
|
|
||||||
logger.info(f"记忆总数: {len(all_memories)}")
|
|
||||||
|
|
||||||
# 按强度排序
|
|
||||||
sorted_memories = sorted(all_memories, key=lambda x: x.memory_strength, reverse=True)
|
|
||||||
|
|
||||||
logger.info("记忆强度排名 (前5项):")
|
|
||||||
for i, memory in enumerate(sorted_memories[:5], 1):
|
|
||||||
logger.info(f"{i}. ID: {memory.id}, 强度: {memory.memory_strength:.2f}, "
|
|
||||||
f"检索次数: {memory.retrieval_count}, "
|
|
||||||
f"主题: {memory.summary.get('brief', '无主题') if memory.summary else '无摘要'}")
|
|
||||||
|
|
||||||
def _get_test_files(self):
|
|
||||||
"""获取测试文件列表"""
|
|
||||||
test_dir = Path(__file__).parent
|
|
||||||
return [
|
|
||||||
os.path.join(test_dir, f)
|
|
||||||
for f in os.listdir(test_dir)
|
|
||||||
if f.endswith(".txt")
|
|
||||||
]
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""主函数"""
|
|
||||||
# 创建模拟器
|
|
||||||
simulator = WorkingMemorySimulator(cycle_interval=20) # 设置20秒的循环间隔
|
|
||||||
|
|
||||||
# 设置运行5个循环
|
|
||||||
await simulator.start(total_cycles=5)
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,323 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# 添加项目根目录到系统路径
|
|
||||||
current_dir = Path(__file__).parent
|
|
||||||
project_root = current_dir.parent.parent.parent.parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
|
||||||
from src.chat.focus_chat.working_memory.test.memory_file_loader import MemoryFileLoader
|
|
||||||
from src.common.logger_manager import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("memory_decay_test")
|
|
||||||
|
|
||||||
async def test_manual_decay_until_removal():
|
|
||||||
"""测试手动衰减直到记忆被自动移除"""
|
|
||||||
print("\n===== 测试手动衰减直到记忆被自动移除 =====")
|
|
||||||
|
|
||||||
# 初始化工作记忆,设置较大的衰减间隔,避免自动衰减影响测试
|
|
||||||
chat_id = "decay_test_manual"
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=3600)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建加载器并加载测试文件
|
|
||||||
loader = MemoryFileLoader(working_memory)
|
|
||||||
test_dir = current_dir
|
|
||||||
|
|
||||||
# 加载第一个测试文件作为记忆
|
|
||||||
memories = await loader.load_from_directory(
|
|
||||||
directory_path=str(test_dir),
|
|
||||||
file_pattern="test1.txt", # 只加载test1.txt
|
|
||||||
common_tags=["测试", "衰减", "自动移除"],
|
|
||||||
source_prefix="衰减测试"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memories:
|
|
||||||
print("未能加载记忆文件,测试结束")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取加载的记忆
|
|
||||||
memory = memories[0]
|
|
||||||
memory_id = memory.id
|
|
||||||
print(f"已加载测试记忆,ID: {memory_id}")
|
|
||||||
print(f"初始强度: {memory.memory_strength}")
|
|
||||||
if memory.summary:
|
|
||||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
|
|
||||||
# 执行多次衰减,直到记忆被移除
|
|
||||||
decay_count = 0
|
|
||||||
decay_factor = 0.5 # 每次衰减为原来的一半
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# 获取当前记忆
|
|
||||||
current_memory = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
|
|
||||||
# 如果记忆已被移除,退出循环
|
|
||||||
if current_memory is None:
|
|
||||||
print(f"记忆已在第 {decay_count} 次衰减后被自动移除!")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 输出当前强度
|
|
||||||
print(f"衰减 {decay_count} 次后强度: {current_memory.memory_strength}")
|
|
||||||
|
|
||||||
# 执行衰减
|
|
||||||
await working_memory.decay_all_memories(decay_factor=decay_factor)
|
|
||||||
decay_count += 1
|
|
||||||
|
|
||||||
# 输出衰减后的详细信息
|
|
||||||
after_memory = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
if after_memory:
|
|
||||||
print(f"第 {decay_count} 次衰减结果: 强度={after_memory.memory_strength},压缩次数={after_memory.compress_count}")
|
|
||||||
if after_memory.summary:
|
|
||||||
print(f"记忆概要: {after_memory.summary.get('brief', '无概要')}")
|
|
||||||
print(f"记忆要点数量: {len(after_memory.summary.get('key_points', []))}")
|
|
||||||
else:
|
|
||||||
print(f"第 {decay_count} 次衰减结果: 记忆已被移除")
|
|
||||||
|
|
||||||
# 防止无限循环
|
|
||||||
if decay_count > 20:
|
|
||||||
print("达到最大衰减次数(20),退出测试。")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 短暂等待
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
# 验证记忆是否真的被移除
|
|
||||||
all_memories = working_memory.memory_manager.get_all_items()
|
|
||||||
print(f"剩余记忆数量: {len(all_memories)}")
|
|
||||||
if len(all_memories) == 0:
|
|
||||||
print("测试通过: 记忆在强度低于阈值后被成功移除。")
|
|
||||||
else:
|
|
||||||
print("测试失败: 记忆应该被移除但仍然存在。")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await working_memory.shutdown()
|
|
||||||
|
|
||||||
async def test_auto_decay():
|
|
||||||
"""测试自动衰减功能"""
|
|
||||||
print("\n===== 测试自动衰减功能 =====")
|
|
||||||
|
|
||||||
# 初始化工作记忆,设置短的衰减间隔,便于测试
|
|
||||||
chat_id = "decay_test_auto"
|
|
||||||
decay_interval = 3 # 3秒
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=decay_interval)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建加载器并加载测试文件
|
|
||||||
loader = MemoryFileLoader(working_memory)
|
|
||||||
test_dir = current_dir
|
|
||||||
|
|
||||||
# 加载第二个测试文件作为记忆
|
|
||||||
memories = await loader.load_from_directory(
|
|
||||||
directory_path=str(test_dir),
|
|
||||||
file_pattern="test1.txt", # 只加载test2.txt
|
|
||||||
common_tags=["测试", "自动衰减"],
|
|
||||||
source_prefix="自动衰减测试"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memories:
|
|
||||||
print("未能加载记忆文件,测试结束")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取加载的记忆
|
|
||||||
memory = memories[0]
|
|
||||||
memory_id = memory.id
|
|
||||||
print(f"已加载测试记忆,ID: {memory_id}")
|
|
||||||
print(f"初始强度: {memory.memory_strength}")
|
|
||||||
if memory.summary:
|
|
||||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
print(f"记忆概要: {memory.summary.get('detailed', '无概要')}")
|
|
||||||
print(f"记忆要点: {memory.summary.get('keypoints', '无要点')}")
|
|
||||||
print(f"记忆事件: {memory.summary.get('events', '无事件')}")
|
|
||||||
# 观察自动衰减
|
|
||||||
print(f"等待自动衰减任务执行 (间隔 {decay_interval} 秒)...")
|
|
||||||
|
|
||||||
for i in range(3): # 观察3次自动衰减
|
|
||||||
# 等待自动衰减发生
|
|
||||||
await asyncio.sleep(decay_interval + 1) # 多等1秒确保任务执行
|
|
||||||
|
|
||||||
# 获取当前记忆
|
|
||||||
current_memory = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
|
|
||||||
# 如果记忆已被移除,退出循环
|
|
||||||
if current_memory is None:
|
|
||||||
print(f"记忆已在第 {i+1} 次自动衰减后被移除!")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 输出当前强度和详细信息
|
|
||||||
print(f"第 {i+1} 次自动衰减后强度: {current_memory.memory_strength}")
|
|
||||||
print(f"自动衰减详细结果: 压缩次数={current_memory.compress_count}, 提取次数={current_memory.retrieval_count}")
|
|
||||||
if current_memory.summary:
|
|
||||||
print(f"记忆概要: {current_memory.summary.get('brief', '无概要')}")
|
|
||||||
|
|
||||||
print(f"\n自动衰减测试结束。")
|
|
||||||
|
|
||||||
# 验证自动衰减是否发生
|
|
||||||
final_memory = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
if final_memory is None:
|
|
||||||
print("记忆已被自动衰减移除。")
|
|
||||||
elif final_memory.memory_strength < memory.memory_strength:
|
|
||||||
print(f"自动衰减有效:初始强度 {memory.memory_strength} -> 最终强度 {final_memory.memory_strength}")
|
|
||||||
print(f"衰减历史记录: {final_memory.history}")
|
|
||||||
else:
|
|
||||||
print("测试失败:记忆强度未减少,自动衰减可能未生效。")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await working_memory.shutdown()
|
|
||||||
|
|
||||||
async def test_decay_and_retrieval_balance():
|
|
||||||
"""测试记忆衰减和检索的平衡"""
|
|
||||||
print("\n===== 测试记忆衰减和检索的平衡 =====")
|
|
||||||
|
|
||||||
# 初始化工作记忆
|
|
||||||
chat_id = "decay_retrieval_balance"
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建加载器并加载测试文件
|
|
||||||
loader = MemoryFileLoader(working_memory)
|
|
||||||
test_dir = current_dir
|
|
||||||
|
|
||||||
# 加载第三个测试文件作为记忆
|
|
||||||
memories = await loader.load_from_directory(
|
|
||||||
directory_path=str(test_dir),
|
|
||||||
file_pattern="test3.txt", # 只加载test3.txt
|
|
||||||
common_tags=["测试", "衰减", "检索"],
|
|
||||||
source_prefix="平衡测试"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memories:
|
|
||||||
print("未能加载记忆文件,测试结束")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取加载的记忆
|
|
||||||
memory = memories[0]
|
|
||||||
memory_id = memory.id
|
|
||||||
print(f"已加载测试记忆,ID: {memory_id}")
|
|
||||||
print(f"初始强度: {memory.memory_strength}")
|
|
||||||
if memory.summary:
|
|
||||||
print(f"记忆主题: {memory.summary.get('brief', '无主题')}")
|
|
||||||
|
|
||||||
# 先衰减几次
|
|
||||||
print("\n开始衰减:")
|
|
||||||
for i in range(3):
|
|
||||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
|
||||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
if current:
|
|
||||||
print(f"衰减 {i+1} 次后强度: {current.memory_strength}")
|
|
||||||
print(f"衰减详细信息: 压缩次数={current.compress_count}, 历史操作数={len(current.history)}")
|
|
||||||
if current.summary:
|
|
||||||
print(f"记忆概要: {current.summary.get('brief', '无概要')}")
|
|
||||||
else:
|
|
||||||
print(f"记忆已在第 {i+1} 次衰减后被移除。")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 如果记忆还存在,则检索几次增强它
|
|
||||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
if current:
|
|
||||||
print("\n开始检索增强:")
|
|
||||||
for i in range(2):
|
|
||||||
retrieved = await working_memory.retrieve_memory(memory_id)
|
|
||||||
print(f"检索 {i+1} 次后强度: {retrieved.memory_strength}")
|
|
||||||
print(f"检索后详细信息: 提取次数={retrieved.retrieval_count}, 历史记录长度={len(retrieved.history)}")
|
|
||||||
|
|
||||||
# 再次衰减几次,测试是否会被移除
|
|
||||||
print("\n再次衰减:")
|
|
||||||
for i in range(5):
|
|
||||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
|
||||||
current = working_memory.memory_manager.get_by_id(memory_id)
|
|
||||||
if current:
|
|
||||||
print(f"最终衰减 {i+1} 次后强度: {current.memory_strength}")
|
|
||||||
print(f"衰减详细结果: 压缩次数={current.compress_count}")
|
|
||||||
else:
|
|
||||||
print(f"记忆已在最终衰减第 {i+1} 次后被移除。")
|
|
||||||
break
|
|
||||||
|
|
||||||
print("\n测试结束。")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await working_memory.shutdown()
|
|
||||||
|
|
||||||
async def test_multi_memories_decay():
|
|
||||||
"""测试多条记忆同时衰减"""
|
|
||||||
print("\n===== 测试多条记忆同时衰减 =====")
|
|
||||||
|
|
||||||
# 初始化工作记忆
|
|
||||||
chat_id = "multi_decay_test"
|
|
||||||
working_memory = WorkingMemory(chat_id=chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建加载器并加载所有测试文件
|
|
||||||
loader = MemoryFileLoader(working_memory)
|
|
||||||
test_dir = current_dir
|
|
||||||
|
|
||||||
# 加载所有测试文件作为记忆
|
|
||||||
memories = await loader.load_from_directory(
|
|
||||||
directory_path=str(test_dir),
|
|
||||||
file_pattern="*.txt",
|
|
||||||
common_tags=["测试", "多记忆衰减"],
|
|
||||||
source_prefix="多记忆测试"
|
|
||||||
)
|
|
||||||
|
|
||||||
if not memories or len(memories) < 2:
|
|
||||||
print("未能加载足够的记忆文件,测试结束")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 显示已加载的记忆
|
|
||||||
print(f"已加载 {len(memories)} 条记忆:")
|
|
||||||
for idx, mem in enumerate(memories):
|
|
||||||
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 来源: {mem.from_source}")
|
|
||||||
if mem.summary:
|
|
||||||
print(f" 主题: {mem.summary.get('brief', '无主题')}")
|
|
||||||
|
|
||||||
# 进行多次衰减测试
|
|
||||||
print("\n开始多记忆衰减测试:")
|
|
||||||
for decay_round in range(5):
|
|
||||||
# 执行衰减
|
|
||||||
await working_memory.decay_all_memories(decay_factor=0.5)
|
|
||||||
|
|
||||||
# 获取并显示所有记忆
|
|
||||||
all_memories = working_memory.memory_manager.get_all_items()
|
|
||||||
print(f"\n第 {decay_round+1} 次衰减后,剩余记忆数量: {len(all_memories)}")
|
|
||||||
|
|
||||||
for idx, mem in enumerate(all_memories):
|
|
||||||
print(f"{idx+1}. ID: {mem.id}, 强度: {mem.memory_strength}, 压缩次数: {mem.compress_count}")
|
|
||||||
if mem.summary:
|
|
||||||
print(f" 概要: {mem.summary.get('brief', '无概要')[:30]}...")
|
|
||||||
|
|
||||||
# 如果所有记忆都被移除,退出循环
|
|
||||||
if not all_memories:
|
|
||||||
print("所有记忆已被移除,测试结束。")
|
|
||||||
break
|
|
||||||
|
|
||||||
# 等待一下
|
|
||||||
await asyncio.sleep(0.5)
|
|
||||||
|
|
||||||
print("\n多记忆衰减测试结束。")
|
|
||||||
|
|
||||||
finally:
|
|
||||||
await working_memory.shutdown()
|
|
||||||
|
|
||||||
async def main():
|
|
||||||
"""运行所有测试"""
|
|
||||||
# 测试手动衰减直到移除
|
|
||||||
await test_manual_decay_until_removal()
|
|
||||||
|
|
||||||
# 测试自动衰减
|
|
||||||
await test_auto_decay()
|
|
||||||
|
|
||||||
# 测试衰减和检索的平衡
|
|
||||||
await test_decay_and_retrieval_balance()
|
|
||||||
|
|
||||||
# 测试多条记忆同时衰减
|
|
||||||
await test_multi_memories_decay()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
asyncio.run(main())
|
|
||||||
@@ -1,121 +0,0 @@
|
|||||||
import asyncio
|
|
||||||
import os
|
|
||||||
import unittest
|
|
||||||
from typing import List, Dict, Any
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
|
||||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
|
||||||
|
|
||||||
class TestWorkingMemory(unittest.TestCase):
|
|
||||||
"""工作记忆测试类"""
|
|
||||||
|
|
||||||
def setUp(self):
|
|
||||||
"""测试前准备"""
|
|
||||||
self.chat_id = "test_chat_123"
|
|
||||||
self.working_memory = WorkingMemory(chat_id=self.chat_id, max_memories_per_chat=10, auto_decay_interval=60)
|
|
||||||
self.test_dir = Path(__file__).parent
|
|
||||||
|
|
||||||
def tearDown(self):
|
|
||||||
"""测试后清理"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
loop.run_until_complete(self.working_memory.shutdown())
|
|
||||||
|
|
||||||
def test_init(self):
|
|
||||||
"""测试初始化"""
|
|
||||||
self.assertEqual(self.working_memory.max_memories_per_chat, 10)
|
|
||||||
self.assertEqual(self.working_memory.auto_decay_interval, 60)
|
|
||||||
|
|
||||||
def test_add_memory_from_files(self):
|
|
||||||
"""从文件添加记忆"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
test_files = self._get_test_files()
|
|
||||||
|
|
||||||
# 添加记忆
|
|
||||||
memories = []
|
|
||||||
for file_path in test_files:
|
|
||||||
content = self._read_file_content(file_path)
|
|
||||||
file_name = os.path.basename(file_path)
|
|
||||||
source = f"test_file_{file_name}"
|
|
||||||
tags = ["测试", f"文件_{file_name}"]
|
|
||||||
|
|
||||||
memory = loop.run_until_complete(
|
|
||||||
self.working_memory.add_memory(
|
|
||||||
content=content,
|
|
||||||
from_source=source,
|
|
||||||
tags=tags
|
|
||||||
)
|
|
||||||
)
|
|
||||||
memories.append(memory)
|
|
||||||
|
|
||||||
# 验证记忆数量
|
|
||||||
all_items = self.working_memory.memory_manager.get_all_items()
|
|
||||||
self.assertEqual(len(all_items), len(test_files))
|
|
||||||
|
|
||||||
# 验证每个记忆的内容和标签
|
|
||||||
for i, memory in enumerate(memories):
|
|
||||||
file_name = os.path.basename(test_files[i])
|
|
||||||
retrieved_memory = loop.run_until_complete(
|
|
||||||
self.working_memory.retrieve_memory(memory.id)
|
|
||||||
)
|
|
||||||
|
|
||||||
self.assertIsNotNone(retrieved_memory)
|
|
||||||
self.assertTrue(retrieved_memory.has_tag("测试"))
|
|
||||||
self.assertTrue(retrieved_memory.has_tag(f"文件_{file_name}"))
|
|
||||||
self.assertEqual(retrieved_memory.from_source, f"test_file_{file_name}")
|
|
||||||
|
|
||||||
# 验证检索后强度增加
|
|
||||||
self.assertGreater(retrieved_memory.memory_strength, 10.0) # 原始强度为10.0,检索后增加1.5倍
|
|
||||||
self.assertEqual(retrieved_memory.retrieval_count, 1)
|
|
||||||
|
|
||||||
def test_decay_memories(self):
|
|
||||||
"""测试记忆衰减"""
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
test_files = self._get_test_files()[:1] # 只使用一个文件测试衰减
|
|
||||||
|
|
||||||
# 添加记忆
|
|
||||||
for file_path in test_files:
|
|
||||||
content = self._read_file_content(file_path)
|
|
||||||
loop.run_until_complete(
|
|
||||||
self.working_memory.add_memory(
|
|
||||||
content=content,
|
|
||||||
from_source="decay_test",
|
|
||||||
tags=["衰减测试"]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取添加后的记忆项
|
|
||||||
all_items_before = self.working_memory.memory_manager.get_all_items()
|
|
||||||
self.assertEqual(len(all_items_before), 1)
|
|
||||||
|
|
||||||
# 记录原始强度
|
|
||||||
original_strength = all_items_before[0].memory_strength
|
|
||||||
|
|
||||||
# 执行衰减
|
|
||||||
loop.run_until_complete(
|
|
||||||
self.working_memory.decay_all_memories(decay_factor=0.5)
|
|
||||||
)
|
|
||||||
|
|
||||||
# 获取衰减后的记忆项
|
|
||||||
all_items_after = self.working_memory.memory_manager.get_all_items()
|
|
||||||
|
|
||||||
# 验证强度衰减
|
|
||||||
self.assertEqual(len(all_items_after), 1)
|
|
||||||
self.assertLess(all_items_after[0].memory_strength, original_strength)
|
|
||||||
|
|
||||||
def _get_test_files(self) -> List[str]:
|
|
||||||
"""获取测试文件列表"""
|
|
||||||
test_dir = self.test_dir
|
|
||||||
return [
|
|
||||||
os.path.join(test_dir, f)
|
|
||||||
for f in os.listdir(test_dir)
|
|
||||||
if f.endswith(".txt")
|
|
||||||
]
|
|
||||||
|
|
||||||
def _read_file_content(self, file_path: str) -> str:
|
|
||||||
"""读取文件内容"""
|
|
||||||
with open(file_path, "r", encoding="utf-8") as f:
|
|
||||||
return f.read()
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
from typing import Dict, List, Any, Optional
|
from typing import List, Any, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
import random
|
import random
|
||||||
from datetime import datetime
|
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
from src.chat.focus_chat.working_memory.memory_manager import MemoryManager, MemoryItem
|
||||||
|
|
||||||
@@ -9,13 +8,14 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
# 问题是我不知道这个manager是不是需要和其他manager统一管理,因为这个manager是从属于每一个聊天流,都有自己的定时任务
|
||||||
|
|
||||||
|
|
||||||
class WorkingMemory:
|
class WorkingMemory:
|
||||||
"""
|
"""
|
||||||
工作记忆,负责协调和运作记忆
|
工作记忆,负责协调和运作记忆
|
||||||
从属于特定的流,用chat_id来标识
|
从属于特定的流,用chat_id来标识
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self,chat_id:str , max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
def __init__(self, chat_id: str, max_memories_per_chat: int = 10, auto_decay_interval: int = 60):
|
||||||
"""
|
"""
|
||||||
初始化工作记忆管理器
|
初始化工作记忆管理器
|
||||||
|
|
||||||
@@ -51,11 +51,7 @@ class WorkingMemory:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"自动衰减记忆时出错: {str(e)}")
|
print(f"自动衰减记忆时出错: {str(e)}")
|
||||||
|
|
||||||
|
async def add_memory(self, content: Any, from_source: str = "", tags: Optional[List[str]] = None):
|
||||||
async def add_memory(self,
|
|
||||||
content: Any,
|
|
||||||
from_source: str = "",
|
|
||||||
tags: Optional[List[str]] = None):
|
|
||||||
"""
|
"""
|
||||||
添加一段记忆到指定聊天
|
添加一段记忆到指定聊天
|
||||||
|
|
||||||
@@ -97,7 +93,6 @@ class WorkingMemory:
|
|||||||
return memory_item
|
return memory_item
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def decay_all_memories(self, decay_factor: float = 0.5):
|
async def decay_all_memories(self, decay_factor: float = 0.5):
|
||||||
"""
|
"""
|
||||||
对所有聊天的所有记忆进行衰减
|
对所有聊天的所有记忆进行衰减
|
||||||
@@ -119,7 +114,9 @@ class WorkingMemory:
|
|||||||
continue
|
continue
|
||||||
# 计算衰减量
|
# 计算衰减量
|
||||||
if memory_item.memory_strength < 5:
|
if memory_item.memory_strength < 5:
|
||||||
await self.memory_manager.refine_memory(memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩")
|
await self.memory_manager.refine_memory(
|
||||||
|
memory_id, f"由于时间过去了{self.auto_decay_interval}秒,记忆变的模糊,所以需要压缩"
|
||||||
|
)
|
||||||
|
|
||||||
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
async def merge_memory(self, memory_id1: str, memory_id2: str) -> MemoryItem:
|
||||||
"""合并记忆
|
"""合并记忆
|
||||||
@@ -127,9 +124,9 @@ class WorkingMemory:
|
|||||||
Args:
|
Args:
|
||||||
memory_str: 记忆内容
|
memory_str: 记忆内容
|
||||||
"""
|
"""
|
||||||
return await self.memory_manager.merge_memories(memory_id1 = memory_id1, memory_id2 = memory_id2,reason = "两端记忆有重复的内容")
|
return await self.memory_manager.merge_memories(
|
||||||
|
memory_id1=memory_id1, memory_id2=memory_id2, reason="两端记忆有重复的内容"
|
||||||
|
)
|
||||||
|
|
||||||
# 暂时没用,先留着
|
# 暂时没用,先留着
|
||||||
async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
|
async def simulate_memory_blur(self, chat_id: str, blur_rate: float = 0.2):
|
||||||
@@ -146,7 +143,7 @@ class WorkingMemory:
|
|||||||
all_summarized_memories = []
|
all_summarized_memories = []
|
||||||
for type_items in memory._memory.values():
|
for type_items in memory._memory.values():
|
||||||
for item in type_items:
|
for item in type_items:
|
||||||
if isinstance(item.data, str) and hasattr(item, 'summary') and item.summary:
|
if isinstance(item.data, str) and hasattr(item, "summary") and item.summary:
|
||||||
all_summarized_memories.append(item)
|
all_summarized_memories.append(item)
|
||||||
|
|
||||||
if not all_summarized_memories:
|
if not all_summarized_memories:
|
||||||
@@ -176,8 +173,6 @@ class WorkingMemory:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
|
print(f"模糊记忆 {memory_item.id} 时出错: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
"""关闭管理器,停止所有任务"""
|
"""关闭管理器,停止所有任务"""
|
||||||
if self.decay_task and not self.decay_task.done():
|
if self.decay_task and not self.decay_task.done():
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ class HFCloopObservation:
|
|||||||
if start_time is not None and end_time is not None:
|
if start_time is not None and end_time is not None:
|
||||||
time_diff = int(end_time - start_time)
|
time_diff = int(end_time - start_time)
|
||||||
if time_diff > 60:
|
if time_diff > 60:
|
||||||
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff/60}分钟\n"
|
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff / 60}分钟\n"
|
||||||
else:
|
else:
|
||||||
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n"
|
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}秒\n"
|
||||||
else:
|
else:
|
||||||
@@ -86,5 +86,4 @@ class HFCloopObservation:
|
|||||||
action_description = action_info["description"]
|
action_description = action_info["description"]
|
||||||
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
|
cycle_info_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
|
||||||
|
|
||||||
|
|
||||||
self.observe_info = cycle_info_block
|
self.observe_info = cycle_info_block
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ from src.common.logger_manager import get_logger
|
|||||||
|
|
||||||
logger = get_logger("observation")
|
logger = get_logger("observation")
|
||||||
|
|
||||||
|
|
||||||
# 所有观察的基类
|
# 所有观察的基类
|
||||||
class Observation:
|
class Observation:
|
||||||
def __init__(self, observe_id):
|
def __init__(self, observe_id):
|
||||||
|
|||||||
@@ -103,7 +103,6 @@ class PersonInfoManager:
|
|||||||
else:
|
else:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def create_person_info(person_id: str, data: dict = None):
|
async def create_person_info(person_id: str, data: dict = None):
|
||||||
"""创建一个项"""
|
"""创建一个项"""
|
||||||
|
|||||||
@@ -451,7 +451,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
# 处理 回复<aaa:bbb>
|
# 处理 回复<aaa:bbb>
|
||||||
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
reply_pattern = r"回复<([^:<>]+):([^:<>]+)>"
|
||||||
|
|
||||||
def reply_replacer(match):
|
def reply_replacer(match, platform=platform):
|
||||||
# aaa = match.group(1)
|
# aaa = match.group(1)
|
||||||
bbb = match.group(2)
|
bbb = match.group(2)
|
||||||
anon_reply = get_anon_name(platform, bbb)
|
anon_reply = get_anon_name(platform, bbb)
|
||||||
@@ -462,7 +462,7 @@ async def build_anonymous_messages(messages: List[Dict[str, Any]]) -> str:
|
|||||||
# 处理 @<aaa:bbb>
|
# 处理 @<aaa:bbb>
|
||||||
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
at_pattern = r"@<([^:<>]+):([^:<>]+)>"
|
||||||
|
|
||||||
def at_replacer(match):
|
def at_replacer(match, platform=platform):
|
||||||
# aaa = match.group(1)
|
# aaa = match.group(1)
|
||||||
bbb = match.group(2)
|
bbb = match.group(2)
|
||||||
anon_at = get_anon_name(platform, bbb)
|
anon_at = get_anon_name(platform, bbb)
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""测试插件包"""
|
"""测试插件包"""
|
||||||
|
|
||||||
"""
|
"""
|
||||||
这是一个测试插件
|
这是一个测试插件
|
||||||
"""
|
"""
|
||||||
@@ -2,5 +2,6 @@
|
|||||||
|
|
||||||
# 导入所有动作模块以确保装饰器被执行
|
# 导入所有动作模块以确保装饰器被执行
|
||||||
from . import test_action # noqa
|
from . import test_action # noqa
|
||||||
|
|
||||||
# from . import online_action # noqa
|
# from . import online_action # noqa
|
||||||
from . import mute_action # noqa
|
from . import mute_action # noqa
|
||||||
@@ -4,12 +4,15 @@ from typing import Tuple
|
|||||||
|
|
||||||
logger = get_logger("mute_action")
|
logger = get_logger("mute_action")
|
||||||
|
|
||||||
|
|
||||||
@register_action
|
@register_action
|
||||||
class MuteAction(PluginAction):
|
class MuteAction(PluginAction):
|
||||||
"""测试动作处理类"""
|
"""测试动作处理类"""
|
||||||
|
|
||||||
action_name = "mute_action"
|
action_name = "mute_action"
|
||||||
action_description = "如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
|
action_description = (
|
||||||
|
"如果某人违反了公序良俗,或者别人戳你太多,,或者某人刷屏,一定要禁言某人,如果你很生气,可以禁言某人"
|
||||||
|
)
|
||||||
action_parameters = {
|
action_parameters = {
|
||||||
"target": "禁言对象,输入你要禁言的对象的名字,必填,",
|
"target": "禁言对象,输入你要禁言的对象的名字,必填,",
|
||||||
"duration": "禁言时长,输入你要禁言的时长,单位为秒,必填",
|
"duration": "禁言时长,输入你要禁言的时长,单位为秒,必填",
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ from typing import Tuple
|
|||||||
|
|
||||||
logger = get_logger("check_online_action")
|
logger = get_logger("check_online_action")
|
||||||
|
|
||||||
|
|
||||||
@register_action
|
@register_action
|
||||||
class CheckOnlineAction(PluginAction):
|
class CheckOnlineAction(PluginAction):
|
||||||
"""测试动作处理类"""
|
"""测试动作处理类"""
|
||||||
|
|
||||||
action_name = "check_online_action"
|
action_name = "check_online_action"
|
||||||
action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用"
|
action_description = "这是一个检查在线状态的动作,当有人要求你检查Maibot(麦麦 机器人)在线状态时使用"
|
||||||
action_parameters = {
|
action_parameters = {"mode": "查看模式"}
|
||||||
"mode": "查看模式"
|
|
||||||
}
|
|
||||||
action_require = [
|
action_require = [
|
||||||
"当有人要求你检查Maibot(麦麦 机器人)在线状态时使用",
|
"当有人要求你检查Maibot(麦麦 机器人)在线状态时使用",
|
||||||
"mode参数为version时查看在线版本状态,默认用这种",
|
"mode参数为version时查看在线版本状态,默认用这种",
|
||||||
@@ -31,9 +30,9 @@ class CheckOnlineAction(PluginAction):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
if mode == "type":
|
if mode == "type":
|
||||||
await self.send_message(f"#online detail")
|
await self.send_message("#online detail")
|
||||||
elif mode == "version":
|
elif mode == "version":
|
||||||
await self.send_message(f"#online")
|
await self.send_message("#online")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"{self.log_prefix} 执行online动作时出错: {e}")
|
logger.error(f"{self.log_prefix} 执行online动作时出错: {e}")
|
||||||
|
|||||||
@@ -4,15 +4,14 @@ from typing import Tuple
|
|||||||
|
|
||||||
logger = get_logger("test_action")
|
logger = get_logger("test_action")
|
||||||
|
|
||||||
|
|
||||||
@register_action
|
@register_action
|
||||||
class TestAction(PluginAction):
|
class TestAction(PluginAction):
|
||||||
"""测试动作处理类"""
|
"""测试动作处理类"""
|
||||||
|
|
||||||
action_name = "test_action"
|
action_name = "test_action"
|
||||||
action_description = "这是一个测试动作,当有人要求你测试插件系统时使用"
|
action_description = "这是一个测试动作,当有人要求你测试插件系统时使用"
|
||||||
action_parameters = {
|
action_parameters = {"test_param": "测试参数(可选)"}
|
||||||
"test_param": "测试参数(可选)"
|
|
||||||
}
|
|
||||||
action_require = [
|
action_require = [
|
||||||
"测试情况下使用",
|
"测试情况下使用",
|
||||||
"想测试插件动作加载时使用",
|
"想测试插件动作加载时使用",
|
||||||
|
|||||||
Reference in New Issue
Block a user