fix:优化记忆提取,修复破损的tool信息
This commit is contained in:
@@ -61,10 +61,10 @@ class ExpressionLearner:
|
|||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# TODO: API-Adapter修改标记
|
# TODO: API-Adapter修改标记
|
||||||
self.express_learn_model: LLMRequest = LLMRequest(
|
self.express_learn_model: LLMRequest = LLMRequest(
|
||||||
model=global_config.model.normal,
|
model=global_config.model.focus_expressor,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
request_type="response_heartflow",
|
request_type="learn_expression",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
async def get_expression_by_chat_id(self, chat_id: str) -> Tuple[List[Dict[str, str]], List[Dict[str, str]]]:
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from src.chat.focus_chat.info_processors.working_memory_processor import Working
|
|||||||
from src.chat.focus_chat.info_processors.action_processor import ActionProcessor
|
from src.chat.focus_chat.info_processors.action_processor import ActionProcessor
|
||||||
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservation
|
||||||
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
from src.chat.heart_flow.observation.working_observation import WorkingMemoryObservation
|
||||||
|
from src.chat.heart_flow.observation.structure_observation import StructureObservation
|
||||||
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
|
from src.chat.focus_chat.info_processors.tool_processor import ToolProcessor
|
||||||
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
from src.chat.focus_chat.expressors.default_expressor import DefaultExpressor
|
||||||
from src.chat.focus_chat.memory_activator import MemoryActivator
|
from src.chat.focus_chat.memory_activator import MemoryActivator
|
||||||
@@ -97,6 +98,7 @@ class HeartFChatting:
|
|||||||
self.log_prefix: str = str(chat_id) # Initial default, will be updated
|
self.log_prefix: str = str(chat_id) # Initial default, will be updated
|
||||||
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
|
self.hfcloop_observation = HFCloopObservation(observe_id=self.stream_id)
|
||||||
self.chatting_observation = observations[0]
|
self.chatting_observation = observations[0]
|
||||||
|
self.structure_observation = StructureObservation(observe_id=self.stream_id)
|
||||||
|
|
||||||
self.memory_activator = MemoryActivator()
|
self.memory_activator = MemoryActivator()
|
||||||
self.working_memory = WorkingMemory(chat_id=self.stream_id)
|
self.working_memory = WorkingMemory(chat_id=self.stream_id)
|
||||||
@@ -415,11 +417,13 @@ class HeartFChatting:
|
|||||||
await self.chatting_observation.observe()
|
await self.chatting_observation.observe()
|
||||||
await self.working_observation.observe()
|
await self.working_observation.observe()
|
||||||
await self.hfcloop_observation.observe()
|
await self.hfcloop_observation.observe()
|
||||||
|
await self.structure_observation.observe()
|
||||||
observations: List[Observation] = []
|
observations: List[Observation] = []
|
||||||
observations.append(self.chatting_observation)
|
observations.append(self.chatting_observation)
|
||||||
observations.append(self.working_observation)
|
observations.append(self.working_observation)
|
||||||
observations.append(self.hfcloop_observation)
|
observations.append(self.hfcloop_observation)
|
||||||
|
observations.append(self.structure_observation)
|
||||||
|
|
||||||
loop_observation_info = {
|
loop_observation_info = {
|
||||||
"observations": observations,
|
"observations": observations,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -76,7 +76,11 @@ class StructuredInfo:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
info_str = ""
|
info_str = ""
|
||||||
|
# print(f"self.data: {self.data}")
|
||||||
|
|
||||||
for key, value in self.data.items():
|
for key, value in self.data.items():
|
||||||
|
|
||||||
|
# print(f"key: {key}, value: {value}")
|
||||||
info_str += f"信息类型:{key},信息内容:{value}\n"
|
info_str += f"信息类型:{key},信息内容:{value}\n"
|
||||||
|
|
||||||
return info_str
|
return info_str
|
||||||
|
|||||||
@@ -75,10 +75,12 @@ class ToolProcessor(BaseProcessor):
|
|||||||
result, used_tools, prompt = await self.execute_tools(observation, running_memorys)
|
result, used_tools, prompt = await self.execute_tools(observation, running_memorys)
|
||||||
|
|
||||||
# 更新WorkingObservation中的结构化信息
|
# 更新WorkingObservation中的结构化信息
|
||||||
|
logger.debug(f"工具调用结果: {result}")
|
||||||
|
|
||||||
for observation in observations:
|
for observation in observations:
|
||||||
if isinstance(observation, StructureObservation):
|
if isinstance(observation, StructureObservation):
|
||||||
for structured_info in result:
|
for structured_info in result:
|
||||||
logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
|
# logger.debug(f"{self.log_prefix} 更新WorkingObservation中的结构化信息: {structured_info}")
|
||||||
observation.add_structured_info(structured_info)
|
observation.add_structured_info(structured_info)
|
||||||
|
|
||||||
working_infos = observation.get_observe_info()
|
working_infos = observation.get_observe_info()
|
||||||
@@ -87,7 +89,12 @@ class ToolProcessor(BaseProcessor):
|
|||||||
structured_info = StructuredInfo()
|
structured_info = StructuredInfo()
|
||||||
if working_infos:
|
if working_infos:
|
||||||
for working_info in working_infos:
|
for working_info in working_infos:
|
||||||
structured_info.set_info(working_info.get("type"), working_info.get("content"))
|
# print(f"working_info: {working_info}")
|
||||||
|
# print(f"working_info.get('type'): {working_info.get('type')}")
|
||||||
|
# print(f"working_info.get('content'): {working_info.get('content')}")
|
||||||
|
structured_info.set_info(key=working_info.get('type'), value=working_info.get('content'))
|
||||||
|
# info = structured_info.get_processed_info()
|
||||||
|
# print(f"info: {info}")
|
||||||
|
|
||||||
return [structured_info]
|
return [structured_info]
|
||||||
|
|
||||||
@@ -155,7 +162,7 @@ class ToolProcessor(BaseProcessor):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# 调用LLM,专注于工具使用
|
# 调用LLM,专注于工具使用
|
||||||
# logger.debug(f"开始执行工具调用{prompt}")
|
logger.debug(f"开始执行工具调用{prompt}")
|
||||||
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
|
response, _, tool_calls = await self.llm_model.generate_response_tool_async(prompt=prompt, tools=tools)
|
||||||
|
|
||||||
logger.debug(f"获取到工具原始输出:\n{tool_calls}")
|
logger.debug(f"获取到工具原始输出:\n{tool_calls}")
|
||||||
|
|||||||
@@ -4,24 +4,58 @@ from src.chat.heart_flow.observation.hfcloop_observation import HFCloopObservati
|
|||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.common.logger_manager import get_logger
|
from src.common.logger_manager import get_logger
|
||||||
from src.chat.utils.prompt_builder import Prompt
|
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from src.chat.memory_system.Hippocampus import HippocampusManager
|
from src.chat.memory_system.Hippocampus import HippocampusManager
|
||||||
from typing import List, Dict
|
from typing import List, Dict
|
||||||
import difflib
|
import difflib
|
||||||
|
import json
|
||||||
|
from json_repair import repair_json
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("memory_activator")
|
logger = get_logger("memory_activator")
|
||||||
|
|
||||||
|
|
||||||
|
def get_keywords_from_json(json_str):
    """
    Extract the keyword list from a (possibly malformed) JSON string.

    The text is first passed through ``repair_json`` so that slightly broken
    model output can still be parsed. The parsed value is expected to be a
    JSON object of the form ``{"keywords": ["kw1", "kw2", ...]}``.

    Args:
        json_str: JSON-formatted string to read the keywords from.

    Returns:
        List[str]: The non-empty string keywords found under ``"keywords"``.
        An empty list is returned when the text cannot be parsed, when the
        parsed value is not a JSON object, or when ``"keywords"`` is missing,
        ``null`` or not a list-like value. The return value is always a list,
        so callers can safely iterate it or pass it to ``list()``.
    """
    try:
        # Repair the JSON text before parsing it.
        fixed_json = repair_json(json_str)

        # repair_json may hand back either a string (which still has to be
        # parsed) or an already-decoded Python object (used as-is).
        result = json.loads(fixed_json) if isinstance(fixed_json, str) else fixed_json
    except Exception as e:
        # Best-effort boundary around model output: any failure while
        # repairing/parsing means "no keywords", never a crash in the caller.
        logger.error(f"解析关键词JSON失败: {e}")
        return []

    if not isinstance(result, dict):
        # Previously this case only worked by accident (AttributeError on
        # ``.get`` swallowed by the broad except); handle it explicitly.
        logger.error(f"解析关键词JSON失败: 期望JSON对象,实际得到 {type(result).__name__}")
        return []

    keywords = result.get("keywords")

    # A single keyword given as a bare string must not be split into
    # characters by a later ``list(...)`` call.
    if isinstance(keywords, str):
        keywords = [keywords]

    # ``{"keywords": null}`` or any other non-sequence value yields no keywords.
    if not isinstance(keywords, (list, tuple)):
        return []

    # Keep only non-blank strings so the result can be joined with
    # ``", ".join(...)`` and used as graph node names without type errors.
    return [kw for kw in keywords if isinstance(kw, str) and kw.strip()]
|
||||||
|
|
||||||
|
|
||||||
def init_prompt():
|
def init_prompt():
|
||||||
# --- Group Chat Prompt ---
|
# --- Group Chat Prompt ---
|
||||||
memory_activator_prompt = """
|
memory_activator_prompt = """
|
||||||
你是一个记忆分析器,你需要根据以下信息来进行会议
|
你是一个记忆分析器,你需要根据以下信息来进行回忆
|
||||||
以下是一场聊天中的信息,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
|
以下是一场聊天中的信息,请根据这些信息,总结出几个关键词作为记忆回忆的触发词
|
||||||
|
|
||||||
{obs_info_text}
|
{obs_info_text}
|
||||||
|
|
||||||
|
历史关键词(请避免重复提取这些关键词):
|
||||||
|
{cached_keywords}
|
||||||
|
|
||||||
请输出一个json格式,包含以下字段:
|
请输出一个json格式,包含以下字段:
|
||||||
{{
|
{{
|
||||||
"keywords": ["关键词1", "关键词2", "关键词3",......]
|
"keywords": ["关键词1", "关键词2", "关键词3",......]
|
||||||
@@ -39,6 +73,7 @@ class MemoryActivator:
|
|||||||
model=global_config.model.memory_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
|
model=global_config.model.memory_summary, temperature=0.7, max_tokens=50, request_type="chat_observation"
|
||||||
)
|
)
|
||||||
self.running_memory = []
|
self.running_memory = []
|
||||||
|
self.cached_keywords = set() # 用于缓存历史关键词
|
||||||
|
|
||||||
async def activate_memory(self, observations) -> List[Dict]:
|
async def activate_memory(self, observations) -> List[Dict]:
|
||||||
"""
|
"""
|
||||||
@@ -61,31 +96,47 @@ class MemoryActivator:
|
|||||||
elif isinstance(observation, HFCloopObservation):
|
elif isinstance(observation, HFCloopObservation):
|
||||||
obs_info_text += observation.get_observe_info()
|
obs_info_text += observation.get_observe_info()
|
||||||
|
|
||||||
logger.debug(f"回忆待检索内容:obs_info_text: {obs_info_text}")
|
# logger.debug(f"回忆待检索内容:obs_info_text: {obs_info_text}")
|
||||||
|
|
||||||
# prompt = await global_prompt_manager.format_prompt(
|
# 将缓存的关键词转换为字符串,用于prompt
|
||||||
# "memory_activator_prompt",
|
cached_keywords_str = ", ".join(self.cached_keywords) if self.cached_keywords else "暂无历史关键词"
|
||||||
# obs_info_text=obs_info_text,
|
|
||||||
# )
|
prompt = await global_prompt_manager.format_prompt(
|
||||||
|
"memory_activator_prompt",
|
||||||
# logger.debug(f"prompt: {prompt}")
|
obs_info_text=obs_info_text,
|
||||||
|
cached_keywords=cached_keywords_str,
|
||||||
# response = await self.summary_model.generate_response(prompt)
|
|
||||||
|
|
||||||
# logger.debug(f"response: {response}")
|
|
||||||
|
|
||||||
# # 只取response的第一个元素(字符串)
|
|
||||||
# response_str = response[0]
|
|
||||||
# keywords = list(get_keywords_from_json(response_str))
|
|
||||||
|
|
||||||
# #调用记忆系统获取相关记忆
|
|
||||||
# related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
|
|
||||||
# valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
|
||||||
# )
|
|
||||||
related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
|
||||||
text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=True
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.debug(f"prompt: {prompt}")
|
||||||
|
|
||||||
|
response = await self.summary_model.generate_response(prompt)
|
||||||
|
|
||||||
|
logger.debug(f"response: {response}")
|
||||||
|
|
||||||
|
# 只取response的第一个元素(字符串)
|
||||||
|
response_str = response[0]
|
||||||
|
keywords = list(get_keywords_from_json(response_str))
|
||||||
|
|
||||||
|
# 更新关键词缓存
|
||||||
|
if keywords:
|
||||||
|
# 限制缓存大小,最多保留10个关键词
|
||||||
|
if len(self.cached_keywords) > 10:
|
||||||
|
# 转换为列表,移除最早的关键词
|
||||||
|
cached_list = list(self.cached_keywords)
|
||||||
|
self.cached_keywords = set(cached_list[-8:])
|
||||||
|
|
||||||
|
# 添加新的关键词到缓存
|
||||||
|
self.cached_keywords.update(keywords)
|
||||||
|
logger.debug(f"更新关键词缓存: {self.cached_keywords}")
|
||||||
|
|
||||||
|
#调用记忆系统获取相关记忆
|
||||||
|
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
|
||||||
|
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
|
||||||
|
)
|
||||||
|
# related_memory = await HippocampusManager.get_instance().get_memory_from_text(
|
||||||
|
# text=obs_info_text, max_memory_num=5, max_memory_length=2, max_depth=3, fast_retrieval=False
|
||||||
|
# )
|
||||||
|
|
||||||
# logger.debug(f"获取到的记忆: {related_memory}")
|
# logger.debug(f"获取到的记忆: {related_memory}")
|
||||||
|
|
||||||
# 激活时,所有已有记忆的duration+1,达到3则移除
|
# 激活时,所有已有记忆的duration+1,达到3则移除
|
||||||
|
|||||||
@@ -36,9 +36,8 @@ def init_prompt():
|
|||||||
{mind_info_block}
|
{mind_info_block}
|
||||||
{cycle_info_block}
|
{cycle_info_block}
|
||||||
|
|
||||||
{action_available_block}
|
|
||||||
|
|
||||||
请综合分析聊天内容和你看到的新消息,参考聊天规划,选择合适的action:
|
请综合分析聊天内容和你看到的新消息,参考聊天规划,选择合适的action:
|
||||||
|
注意,除了下面动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界,现在请你选择合适的action:
|
||||||
|
|
||||||
{action_options_text}
|
{action_options_text}
|
||||||
|
|
||||||
@@ -126,13 +125,6 @@ class ActionPlanner:
|
|||||||
action = "no_reply"
|
action = "no_reply"
|
||||||
reasoning = f"之前选择的动作{action}已被移除,原因: {reason}"
|
reasoning = f"之前选择的动作{action}已被移除,原因: {reason}"
|
||||||
|
|
||||||
using_actions = self.action_manager.get_using_actions()
|
|
||||||
action_available_block = ""
|
|
||||||
for action_name, action_info in using_actions.items():
|
|
||||||
action_description = action_info["description"]
|
|
||||||
action_available_block += f"\n你在聊天中可以使用{action_name},这个动作的描述是{action_description}\n"
|
|
||||||
action_available_block += "注意,除了上述动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"
|
|
||||||
|
|
||||||
# 继续处理其他信息
|
# 继续处理其他信息
|
||||||
for info in all_plan_info:
|
for info in all_plan_info:
|
||||||
if isinstance(info, ObsInfo):
|
if isinstance(info, ObsInfo):
|
||||||
@@ -147,7 +139,8 @@ class ActionPlanner:
|
|||||||
elif isinstance(info, SelfInfo):
|
elif isinstance(info, SelfInfo):
|
||||||
self_info = info.get_processed_info()
|
self_info = info.get_processed_info()
|
||||||
elif isinstance(info, StructuredInfo):
|
elif isinstance(info, StructuredInfo):
|
||||||
_structured_info = info.get_data()
|
structured_info = info.get_processed_info()
|
||||||
|
# print(f"structured_info: {structured_info}")
|
||||||
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
|
elif not isinstance(info, ActionInfo): # 跳过已处理的ActionInfo
|
||||||
extra_info.append(info.get_processed_info())
|
extra_info.append(info.get_processed_info())
|
||||||
|
|
||||||
@@ -178,11 +171,10 @@ class ActionPlanner:
|
|||||||
chat_target_info=None,
|
chat_target_info=None,
|
||||||
observed_messages_str=observed_messages_str, # <-- Pass local variable
|
observed_messages_str=observed_messages_str, # <-- Pass local variable
|
||||||
current_mind=current_mind, # <-- Pass argument
|
current_mind=current_mind, # <-- Pass argument
|
||||||
# structured_info=structured_info, # <-- Pass SubMind info
|
structured_info=structured_info, # <-- Pass SubMind info
|
||||||
current_available_actions=current_available_actions, # <-- Pass determined actions
|
current_available_actions=current_available_actions, # <-- Pass determined actions
|
||||||
cycle_info=cycle_info, # <-- Pass cycle info
|
cycle_info=cycle_info, # <-- Pass cycle info
|
||||||
extra_info=extra_info,
|
extra_info=extra_info,
|
||||||
action_available_block=action_available_block,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# --- 调用 LLM (普通文本生成) ---
|
# --- 调用 LLM (普通文本生成) ---
|
||||||
@@ -268,7 +260,7 @@ class ActionPlanner:
|
|||||||
chat_target_info: Optional[dict], # Now passed as argument
|
chat_target_info: Optional[dict], # Now passed as argument
|
||||||
observed_messages_str: str,
|
observed_messages_str: str,
|
||||||
current_mind: Optional[str],
|
current_mind: Optional[str],
|
||||||
action_available_block: str,
|
structured_info: Optional[str],
|
||||||
current_available_actions: Dict[str, ActionInfo],
|
current_available_actions: Dict[str, ActionInfo],
|
||||||
cycle_info: Optional[str],
|
cycle_info: Optional[str],
|
||||||
extra_info: list[str],
|
extra_info: list[str],
|
||||||
@@ -326,7 +318,8 @@ class ActionPlanner:
|
|||||||
action_options_block += using_action_prompt
|
action_options_block += using_action_prompt
|
||||||
|
|
||||||
extra_info_block = "\n".join(extra_info)
|
extra_info_block = "\n".join(extra_info)
|
||||||
if extra_info:
|
extra_info_block += f"\n{structured_info}"
|
||||||
|
if extra_info or structured_info:
|
||||||
extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
|
extra_info_block = f"以下是一些额外的信息,现在请你阅读以下内容,进行决策\n{extra_info_block}\n以上是一些额外的信息,现在请你阅读以下内容,进行决策"
|
||||||
else:
|
else:
|
||||||
extra_info_block = ""
|
extra_info_block = ""
|
||||||
@@ -343,7 +336,7 @@ class ActionPlanner:
|
|||||||
mind_info_block=mind_info_block,
|
mind_info_block=mind_info_block,
|
||||||
cycle_info_block=cycle_info,
|
cycle_info_block=cycle_info,
|
||||||
action_options_text=action_options_block,
|
action_options_text=action_options_block,
|
||||||
action_available_block=action_available_block,
|
# action_available_block=action_available_block,
|
||||||
extra_info_block=extra_info_block,
|
extra_info_block=extra_info_block,
|
||||||
moderation_prompt=moderation_prompt_block,
|
moderation_prompt=moderation_prompt_block,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -526,12 +526,12 @@ class Hippocampus:
|
|||||||
if not keywords:
|
if not keywords:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# logger.info(f"提取的关键词: {', '.join(keywords)}")
|
logger.info(f"提取的关键词: {', '.join(keywords)}")
|
||||||
|
|
||||||
# 过滤掉不存在于记忆图中的关键词
|
# 过滤掉不存在于记忆图中的关键词
|
||||||
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
valid_keywords = [keyword for keyword in keywords if keyword in self.memory_graph.G]
|
||||||
if not valid_keywords:
|
if not valid_keywords:
|
||||||
# logger.info("没有找到有效的关键词节点")
|
logger.info("没有找到有效的关键词节点")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
logger.debug(f"有效的关键词: {', '.join(valid_keywords)}")
|
||||||
|
|||||||
@@ -33,10 +33,10 @@ def init_prompt() -> None:
|
|||||||
class PersonalityExpression:
|
class PersonalityExpression:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.express_learn_model: LLMRequest = LLMRequest(
|
self.express_learn_model: LLMRequest = LLMRequest(
|
||||||
model=global_config.model.normal,
|
model=global_config.model.focus_expressor,
|
||||||
temperature=0.1,
|
temperature=0.1,
|
||||||
max_tokens=256,
|
max_tokens=256,
|
||||||
request_type="response_heartflow",
|
request_type="learn_expression",
|
||||||
)
|
)
|
||||||
self.meta_file_path = os.path.join("data", "expression", "personality", "expression_style_meta.json")
|
self.meta_file_path = os.path.join("data", "expression", "personality", "expression_style_meta.json")
|
||||||
self.expressions_file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
self.expressions_file_path = os.path.join("data", "expression", "personality", "expressions.json")
|
||||||
|
|||||||
@@ -255,7 +255,8 @@ provider = "SILICONFLOW"
|
|||||||
pri_in = 2
|
pri_in = 2
|
||||||
pri_out = 8
|
pri_out = 8
|
||||||
|
|
||||||
#表达器模型,用于生成表达方式
|
#表达器模型,用于表达麦麦的想法,生成最终回复,对语言风格影响极大
|
||||||
|
#也用于表达方式学习
|
||||||
[model.focus_expressor]
|
[model.focus_expressor]
|
||||||
name = "Pro/deepseek-ai/DeepSeek-V3"
|
name = "Pro/deepseek-ai/DeepSeek-V3"
|
||||||
provider = "SILICONFLOW"
|
provider = "SILICONFLOW"
|
||||||
|
|||||||
Reference in New Issue
Block a user