fix:调整hf位置

This commit is contained in:
SengokuCola
2025-05-13 18:43:36 +08:00
parent 334211cba1
commit 6b42f3c2b0
26 changed files with 47 additions and 47 deletions

View File

@@ -0,0 +1,258 @@
from datetime import datetime
from src.chat.models.utils_model import LLMRequest
from src.config.config import global_config
import traceback
from src.chat.utils.chat_message_builder import (
get_raw_msg_before_timestamp_with_chat,
build_readable_messages,
get_raw_msg_by_timestamp_with_chat,
num_new_messages_since,
get_person_id_list,
)
from src.chat.utils.prompt_builder import global_prompt_manager
from typing import Optional
import difflib
from src.chat.message_receive.message import MessageRecv # 添加 MessageRecv 导入
from src.chat.heart_flow.observation.observation import Observation
from src.common.logger_manager import get_logger
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
from src.chat.utils.prompt_builder import Prompt
logger = get_logger(__name__)
Prompt(
"""这是qq群聊的聊天记录请总结以下聊天记录的主题
{chat_logs}
请用一句话概括,包括人物、事件和主要信息,不要分点。""",
"chat_summary_group_prompt", # Template for group chat
)
Prompt(
"""这是你和{chat_target}的私聊记录,请总结以下聊天记录的主题:
{chat_logs}
请用一句话概括,包括事件,时间,和主要信息,不要分点。""",
"chat_summary_private_prompt", # Template for private chat
)
# --- End Prompt Template Definition ---
# 聊天观察
class ChattingObservation(Observation):
def __init__(self, chat_id):
super().__init__(chat_id)
self.chat_id = chat_id
# --- Initialize attributes (defaults) ---
self.is_group_chat: bool = False
self.chat_target_info: Optional[dict] = None
# --- End Initialization ---
# --- Other attributes initialized in __init__ ---
self.talking_message = []
self.talking_message_str = ""
self.talking_message_str_truncate = ""
self.name = global_config.BOT_NICKNAME
self.nick_name = global_config.BOT_ALIAS_NAMES
self.max_now_obs_len = global_config.observation_context_size
self.overlap_len = global_config.compressed_length
self.mid_memorys = []
self.max_mid_memory_len = global_config.compress_length_limit
self.mid_memory_info = ""
self.person_list = []
self.oldest_messages = []
self.oldest_messages_str = ""
self.compressor_prompt = ""
self.llm_summary = LLMRequest(
model=global_config.llm_observation, temperature=0.7, max_tokens=300, request_type="chat_observation"
)
async def initialize(self):
self.is_group_chat, self.chat_target_info = await get_chat_type_and_target_info(self.chat_id)
logger.debug(f"初始化observation: self.is_group_chat: {self.is_group_chat}")
logger.debug(f"初始化observation: self.chat_target_info: {self.chat_target_info}")
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
self.talking_message = initial_messages
self.talking_message_str = await build_readable_messages(self.talking_message)
# 进行一次观察 返回观察结果observe_info
def get_observe_info(self, ids=None):
mid_memory_str = ""
if ids:
for id in ids:
print(f"id{id}")
try:
for mid_memory in self.mid_memorys:
if mid_memory["id"] == id:
mid_memory_by_id = mid_memory
msg_str = ""
for msg in mid_memory_by_id["messages"]:
msg_str += f"{msg['detailed_plain_text']}"
# time_diff = int((datetime.now().timestamp() - mid_memory_by_id["created_at"]) / 60)
# mid_memory_str += f"距离现在{time_diff}分钟前:\n{msg_str}\n"
mid_memory_str += f"{msg_str}\n"
except Exception as e:
logger.error(f"获取mid_memory_id失败: {e}")
traceback.print_exc()
return self.talking_message_str
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
else:
mid_memory_str = "之前的聊天内容:\n"
for mid_memory in self.mid_memorys:
mid_memory_str += f"{mid_memory['theme']}\n"
return mid_memory_str + "现在群里正在聊:\n" + self.talking_message_str
def serch_message_by_text(self, text: str) -> Optional[MessageRecv]:
"""
根据回复的纯文本
1. 在talking_message中查找最新的最匹配的消息
2. 如果找到,则返回消息
"""
msg_list = []
find_msg = None
reverse_talking_message = list(reversed(self.talking_message))
for message in reverse_talking_message:
if message["processed_plain_text"] == text:
find_msg = message
logger.debug(f"找到的锚定消息find_msg: {find_msg}")
break
else:
similarity = difflib.SequenceMatcher(None, text, message["processed_plain_text"]).ratio()
msg_list.append({"message": message, "similarity": similarity})
logger.debug(f"对锚定消息检查message: {message['processed_plain_text']},similarity: {similarity}")
if not find_msg:
if msg_list:
msg_list.sort(key=lambda x: x["similarity"], reverse=True)
if msg_list[0]["similarity"] >= 0.5: # 只返回相似度大于等于0.5的消息
find_msg = msg_list[0]["message"]
else:
logger.debug("没有找到锚定消息,相似度低")
return None
else:
logger.debug("没有找到锚定消息,没有消息捕获")
return None
# logger.debug(f"找到的锚定消息find_msg: {find_msg}")
group_info = find_msg.get("chat_info", {}).get("group_info")
user_info = find_msg.get("chat_info", {}).get("user_info")
content_format = ""
accept_format = ""
template_items = {}
format_info = {"content_format": content_format, "accept_format": accept_format}
template_info = {
"template_items": template_items,
}
message_info = {
"platform": find_msg.get("platform"),
"message_id": find_msg.get("message_id"),
"time": find_msg.get("time"),
"group_info": group_info,
"user_info": user_info,
"additional_config": find_msg.get("additional_config"),
"format_info": format_info,
"template_info": template_info,
}
message_dict = {
"message_info": message_info,
"raw_message": find_msg.get("processed_plain_text"),
"detailed_plain_text": find_msg.get("processed_plain_text"),
"processed_plain_text": find_msg.get("processed_plain_text"),
}
find_rec_msg = MessageRecv(message_dict)
logger.debug(f"锚定消息处理后find_rec_msg: {find_rec_msg}")
return find_rec_msg
async def observe(self):
# 自上一次观察的新消息
new_messages_list = get_raw_msg_by_timestamp_with_chat(
chat_id=self.chat_id,
timestamp_start=self.last_observe_time,
timestamp_end=datetime.now().timestamp(),
limit=self.max_now_obs_len,
limit_mode="latest",
)
last_obs_time_mark = self.last_observe_time
if new_messages_list:
self.last_observe_time = new_messages_list[-1]["time"]
self.talking_message.extend(new_messages_list)
if len(self.talking_message) > self.max_now_obs_len:
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
oldest_messages = self.talking_message[:messages_to_remove_count]
self.talking_message = self.talking_message[messages_to_remove_count:] # 保留后半部分,即最新的
oldest_messages_str = await build_readable_messages(
messages=oldest_messages, timestamp_mode="normal", read_mark=0
)
# --- Build prompt using template ---
prompt = None # Initialize prompt as None
try:
# 构建 Prompt - 根据 is_group_chat 选择模板
if self.is_group_chat:
prompt_template_name = "chat_summary_group_prompt"
prompt = await global_prompt_manager.format_prompt(
prompt_template_name, chat_logs=oldest_messages_str
)
else:
# For private chat, add chat_target to the prompt variables
prompt_template_name = "chat_summary_private_prompt"
# Determine the target name for the prompt
chat_target_name = "对方" # Default fallback
if self.chat_target_info:
# Prioritize person_name, then nickname
chat_target_name = (
self.chat_target_info.get("person_name")
or self.chat_target_info.get("user_nickname")
or chat_target_name
)
# Format the private chat prompt
prompt = await global_prompt_manager.format_prompt(
prompt_template_name,
# Assuming the private prompt template uses {chat_target}
chat_target=chat_target_name,
chat_logs=oldest_messages_str,
)
except Exception as e:
logger.error(f"构建总结 Prompt 失败 for chat {self.chat_id}: {e}")
# prompt remains None
if prompt: # Check if prompt was built successfully
self.compressor_prompt = prompt
self.oldest_messages = oldest_messages
self.oldest_messages_str = oldest_messages_str
self.talking_message_str = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="lite",
read_mark=last_obs_time_mark,
)
self.talking_message_str_truncate = await build_readable_messages(
messages=self.talking_message,
timestamp_mode="normal",
read_mark=last_obs_time_mark,
truncate=True,
)
self.person_list = await get_person_id_list(self.talking_message)
# print(f"self.11111person_list: {self.person_list}")
logger.trace(
f"Chat {self.chat_id} - 压缩早期记忆:{self.mid_memory_info}\n现在聊天内容:{self.talking_message_str}"
)
async def has_new_messages_since(self, timestamp: float) -> bool:
"""检查指定时间戳之后是否有新消息"""
count = num_new_messages_since(chat_id=self.chat_id, timestamp_start=timestamp)
return count > 0

View File

@@ -0,0 +1,82 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger_manager import get_logger
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
from typing import List
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class HFCloopObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop: List[CycleDetail] = []
def get_observe_info(self):
return self.observe_info
def add_loop_info(self, loop_info: CycleDetail):
# logger.debug(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
# print(f"添加循环信息111111111111111111111111111111111111: {loop_info}")
print(f"action_taken: {loop_info.action_taken}")
print(f"action_type: {loop_info.action_type}")
print(f"response_info: {loop_info.response_info}")
self.history_loop.append(loop_info)
async def observe(self):
recent_active_cycles: List[CycleDetail] = []
for cycle in reversed(self.history_loop):
# 只关心实际执行了动作的循环
if cycle.action_taken:
recent_active_cycles.append(cycle)
# 最多找最近的3个活动循环
if len(recent_active_cycles) == 3:
break
cycle_info_block = ""
consecutive_text_replies = 0
responses_for_prompt = []
# 检查这最近的活动循环中有多少是连续的文本回复 (从最近的开始看)
for cycle in recent_active_cycles:
if cycle.action_type == "reply":
consecutive_text_replies += 1
# 获取回复内容,如果不存在则返回'[空回复]'
response_text = cycle.response_info.get("response_text", "[空回复]")
responses_for_prompt.append(response_text)
else:
break
# 根据连续文本回复的数量构建提示信息
# 注意: responses_for_prompt 列表是从最近到最远排序的
if consecutive_text_replies >= 3: # 如果最近的三个活动都是文本回复
cycle_info_block = f'你已经连续回复了三条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}",第三近: "{responses_for_prompt[2]}")。你回复的有点多了,请注意'
elif consecutive_text_replies == 2: # 如果最近的两个活动是文本回复
cycle_info_block = f'你已经连续回复了两条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}"),请注意'
elif consecutive_text_replies == 1: # 如果最近的一个活动是文本回复
cycle_info_block = f'你刚刚已经回复一条消息(内容: "{responses_for_prompt[0]}"'
# 包装提示块,增加可读性,即使没有连续回复也给个标记
if cycle_info_block:
cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
else:
# 如果最近的活动循环不是文本回复,或者没有活动循环
cycle_info_block = "\n"
# 获取history_loop中最新添加的
if self.history_loop:
last_loop = self.history_loop[-1]
start_time = last_loop.start_time
end_time = last_loop.end_time
if start_time is not None and end_time is not None:
time_diff = int(end_time - start_time)
cycle_info_block += f"\n距离你上一次阅读消息已经过去了{time_diff}分钟\n"
else:
cycle_info_block += "\n无法获取上一次阅读消息的时间\n"
self.observe_info = cycle_info_block

View File

@@ -0,0 +1,55 @@
from src.chat.heart_flow.observation.observation import Observation
from datetime import datetime
from src.common.logger_manager import get_logger
import traceback
# Import the new utility function
from src.chat.memory_system.Hippocampus import HippocampusManager
import jieba
from typing import List
logger = get_logger("memory")
class MemoryObservation(Observation):
def __init__(self, observe_id):
super().__init__(observe_id)
self.observe_info: str = ""
self.context: str = ""
self.running_memory: List[dict] = []
def get_observe_info(self):
for memory in self.running_memory:
self.observe_info += f"{memory['topic']}:{memory['content']}\n"
return self.observe_info
async def observe(self):
# ---------- 2. 获取记忆 ----------
try:
# 从聊天内容中提取关键词
chat_words = set(jieba.cut(self.context))
# 过滤掉停用词和单字词
keywords = [word for word in chat_words if len(word) > 1]
# 去重并限制数量
keywords = list(set(keywords))[:5]
logger.debug(f"取的关键词: {keywords}")
# 调用记忆系统获取相关记忆
related_memory = await HippocampusManager.get_instance().get_memory_from_topic(
valid_keywords=keywords, max_memory_num=3, max_memory_length=2, max_depth=3
)
logger.debug(f"获取到的记忆: {related_memory}")
if related_memory:
for topic, memory in related_memory:
# 将记忆添加到 running_memory
self.running_memory.append(
{"topic": topic, "content": memory, "timestamp": datetime.now().isoformat()}
)
logger.debug(f"添加新记忆: {topic} - {memory}")
except Exception as e:
logger.error(f"观察 记忆时出错: {e}")
logger.error(traceback.format_exc())

View File

@@ -0,0 +1,17 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger_manager import get_logger
logger = get_logger("observation")
# 所有观察的基类
class Observation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
async def observe(self):
pass

View File

@@ -0,0 +1,34 @@
# 定义了来自外部世界的信息
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
from datetime import datetime
from src.common.logger_manager import get_logger
# Import the new utility function
logger = get_logger("observation")
# 所有观察的基类
class WorkingObservation:
def __init__(self, observe_id):
self.observe_info = ""
self.observe_id = observe_id
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
self.history_loop = []
self.structured_info = []
def get_observe_info(self):
return self.structured_info
def add_structured_info(self, structured_info: dict):
self.structured_info.append(structured_info)
async def observe(self):
observed_structured_infos = []
for structured_info in self.structured_info:
if structured_info.get("ttl") > 0:
structured_info["ttl"] -= 1
observed_structured_infos.append(structured_info)
logger.debug(f"观察到结构化信息仍旧在: {structured_info}")
self.structured_info = observed_structured_infos