ref:调整文件位置和命名,结构更清晰
This commit is contained in:
@@ -7,7 +7,6 @@ class ChatState(enum.Enum):
|
||||
NORMAL = "随便水群"
|
||||
FOCUSED = "认真水群"
|
||||
|
||||
|
||||
class ChatStateInfo:
|
||||
def __init__(self):
|
||||
self.chat_status: ChatState = ChatState.NORMAL
|
||||
|
||||
152
src/chat/heart_flow/heartflow_message_processor.py
Normal file
152
src/chat/heart_flow/heartflow_message_processor.py
Normal file
@@ -0,0 +1,152 @@
|
||||
from src.chat.memory_system.Hippocampus import hippocampus_manager
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.message import MessageRecv
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.heart_flow.heartflow import heartflow
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.chat.utils.timer_calculator import Timer
|
||||
from src.common.logger import get_logger
|
||||
import re
|
||||
import math
|
||||
import traceback
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from src.person_info.relationship_manager import get_relationship_manager
|
||||
|
||||
# from ..message_receive.message_buffer import message_buffer
|
||||
|
||||
logger = get_logger("chat")
|
||||
|
||||
|
||||
async def _handle_error(error: Exception, context: str, message: Optional[MessageRecv] = None) -> None:
|
||||
"""统一的错误处理函数
|
||||
|
||||
Args:
|
||||
error: 捕获到的异常
|
||||
context: 错误发生的上下文描述
|
||||
message: 可选的消息对象,用于记录相关消息内容
|
||||
"""
|
||||
logger.error(f"{context}: {error}")
|
||||
logger.error(traceback.format_exc())
|
||||
if message and hasattr(message, "raw_message"):
|
||||
logger.error(f"相关消息原始内容: {message.raw_message}")
|
||||
|
||||
|
||||
async def _process_relationship(message: MessageRecv) -> None:
|
||||
"""处理用户关系逻辑
|
||||
|
||||
Args:
|
||||
message: 消息对象,包含用户信息
|
||||
"""
|
||||
platform = message.message_info.platform
|
||||
user_id = message.message_info.user_info.user_id
|
||||
nickname = message.message_info.user_info.user_nickname
|
||||
cardname = message.message_info.user_info.user_cardname or nickname
|
||||
|
||||
relationship_manager = get_relationship_manager()
|
||||
is_known = await relationship_manager.is_known_some_one(platform, user_id)
|
||||
|
||||
if not is_known:
|
||||
logger.info(f"首次认识用户: {nickname}")
|
||||
await relationship_manager.first_knowing_some_one(platform, user_id, nickname, cardname)
|
||||
|
||||
|
||||
async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
|
||||
"""计算消息的兴趣度
|
||||
|
||||
Args:
|
||||
message: 待处理的消息对象
|
||||
|
||||
Returns:
|
||||
Tuple[float, bool]: (兴趣度, 是否被提及)
|
||||
"""
|
||||
is_mentioned, _ = is_mentioned_bot_in_message(message)
|
||||
interested_rate = 0.0
|
||||
|
||||
if global_config.memory.enable_memory:
|
||||
with Timer("记忆激活"):
|
||||
interested_rate = await hippocampus_manager.get_activate_from_text(
|
||||
message.processed_plain_text,
|
||||
fast_retrieval=True,
|
||||
)
|
||||
logger.debug(f"记忆激活率: {interested_rate:.2f}")
|
||||
|
||||
text_len = len(message.processed_plain_text)
|
||||
# 根据文本长度调整兴趣度,长度越大兴趣度越高,但增长率递减,最低0.01,最高0.05
|
||||
# 采用对数函数实现递减增长
|
||||
|
||||
base_interest = 0.01 + (0.05 - 0.01) * (math.log10(text_len + 1) / math.log10(1000 + 1))
|
||||
base_interest = min(max(base_interest, 0.01), 0.05)
|
||||
|
||||
interested_rate += base_interest
|
||||
|
||||
if is_mentioned:
|
||||
interest_increase_on_mention = 1
|
||||
interested_rate += interest_increase_on_mention
|
||||
|
||||
return interested_rate, is_mentioned
|
||||
|
||||
|
||||
class HeartFCMessageReceiver:
|
||||
"""心流处理器,负责处理接收到的消息并计算兴趣度"""
|
||||
|
||||
def __init__(self):
|
||||
"""初始化心流处理器,创建消息存储实例"""
|
||||
self.storage = MessageStorage()
|
||||
|
||||
async def process_message(self, message: MessageRecv) -> None:
|
||||
"""处理接收到的原始消息数据
|
||||
|
||||
主要流程:
|
||||
1. 消息解析与初始化
|
||||
2. 消息缓冲处理
|
||||
3. 过滤检查
|
||||
4. 兴趣度计算
|
||||
5. 关系处理
|
||||
|
||||
Args:
|
||||
message_data: 原始消息字符串
|
||||
"""
|
||||
try:
|
||||
# 1. 消息解析与初始化
|
||||
groupinfo = message.message_info.group_info
|
||||
userinfo = message.message_info.user_info
|
||||
messageinfo = message.message_info
|
||||
|
||||
chat = await get_chat_manager().get_or_create_stream(
|
||||
platform=messageinfo.platform,
|
||||
user_info=userinfo,
|
||||
group_info=groupinfo,
|
||||
)
|
||||
|
||||
await self.storage.store_message(message, chat)
|
||||
|
||||
subheartflow = await heartflow.get_or_create_subheartflow(chat.stream_id)
|
||||
message.update_chat_stream(chat)
|
||||
|
||||
# 6. 兴趣度计算与更新
|
||||
interested_rate, is_mentioned = await _calculate_interest(message)
|
||||
subheartflow.add_message_to_normal_chat_cache(message, interested_rate, is_mentioned)
|
||||
|
||||
# 7. 日志记录
|
||||
mes_name = chat.group_info.group_name if chat.group_info else "私聊"
|
||||
# current_time = time.strftime("%H:%M:%S", time.localtime(message.message_info.time))
|
||||
current_talk_frequency = global_config.chat.get_current_talk_frequency(chat.stream_id)
|
||||
|
||||
# 如果消息中包含图片标识,则日志展示为图片
|
||||
|
||||
picid_match = re.search(r"\[picid:([^\]]+)\]", message.processed_plain_text)
|
||||
if picid_match:
|
||||
logger.info(f"[{mes_name}]{userinfo.user_nickname}: [图片] [当前回复频率: {current_talk_frequency}]")
|
||||
else:
|
||||
logger.info(
|
||||
f"[{mes_name}]{userinfo.user_nickname}:{message.processed_plain_text}[当前回复频率: {current_talk_frequency}]"
|
||||
)
|
||||
|
||||
# 8. 关系处理
|
||||
if global_config.relationship.enable_relationship:
|
||||
await _process_relationship(message)
|
||||
|
||||
except Exception as e:
|
||||
await _handle_error(e, "消息处理失败", message)
|
||||
@@ -1,46 +0,0 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.planner_actions.action_manager import ActionManager
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
# 特殊的观察,专门用于观察动作
|
||||
# 所有观察的基类
|
||||
class ActionObservation:
|
||||
def __init__(self, observe_id):
|
||||
self.observe_info = ""
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||
self.action_manager: ActionManager = None
|
||||
|
||||
self.all_actions = {}
|
||||
self.all_using_actions = {}
|
||||
|
||||
def get_observe_info(self):
|
||||
return self.observe_info
|
||||
|
||||
def set_action_manager(self, action_manager: ActionManager):
|
||||
self.action_manager = action_manager
|
||||
self.all_actions = self.action_manager.get_registered_actions()
|
||||
|
||||
async def observe(self):
|
||||
action_info_block = ""
|
||||
self.all_using_actions = self.action_manager.get_using_actions()
|
||||
for action_name, action_info in self.all_using_actions.items():
|
||||
action_info_block += f"\n{action_name}: {action_info.get('description', '')}"
|
||||
action_info_block += "\n注意,除了上面动作选项之外,你在群聊里不能做其他任何事情,这是你能力的边界\n"
|
||||
|
||||
self.observe_info = action_info_block
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
"observe_info": self.observe_info,
|
||||
"observe_id": self.observe_id,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
"all_actions": self.all_actions,
|
||||
"all_using_actions": self.all_using_actions,
|
||||
}
|
||||
@@ -1,183 +0,0 @@
|
||||
from datetime import datetime
|
||||
from src.config.config import global_config
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
build_readable_messages,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
num_new_messages_since,
|
||||
get_person_id_list,
|
||||
)
|
||||
from src.chat.utils.prompt_builder import global_prompt_manager, Prompt
|
||||
from src.chat.heart_flow.observation.observation import Observation
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.heart_flow.utils_chat import get_chat_type_and_target_info
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
# 定义提示模板
|
||||
Prompt(
|
||||
"""这是{chat_type_description},请总结以下聊天记录的主题:
|
||||
{chat_logs}
|
||||
请概括这段聊天记录的主题和主要内容
|
||||
主题:简短的概括,包括时间,人物和事件,不要超过20个字
|
||||
内容:具体的信息内容,包括人物、事件和信息,不要超过200个字,不要分点。
|
||||
|
||||
请用json格式返回,格式如下:
|
||||
{{
|
||||
"theme": "主题,例如 2025-06-14 10:00:00 群聊 麦麦 和 网友 讨论了 游戏 的话题",
|
||||
"content": "内容,可以是对聊天记录的概括,也可以是聊天记录的详细内容"
|
||||
}}
|
||||
""",
|
||||
"chat_summary_prompt",
|
||||
)
|
||||
|
||||
|
||||
class ChattingObservation(Observation):
|
||||
def __init__(self, chat_id):
|
||||
super().__init__(chat_id)
|
||||
self.chat_id = chat_id
|
||||
self.platform = "qq"
|
||||
|
||||
self.is_group_chat, self.chat_target_info = get_chat_type_and_target_info(self.chat_id)
|
||||
|
||||
self.talking_message = []
|
||||
self.talking_message_str = ""
|
||||
self.talking_message_str_truncate = ""
|
||||
self.talking_message_str_short = ""
|
||||
self.talking_message_str_truncate_short = ""
|
||||
self.name = global_config.bot.nickname
|
||||
self.nick_name = global_config.bot.alias_names
|
||||
self.max_now_obs_len = global_config.chat.max_context_size
|
||||
self.overlap_len = global_config.focus_chat.compressed_length
|
||||
self.person_list = []
|
||||
self.compressor_prompt = ""
|
||||
self.oldest_messages = []
|
||||
self.oldest_messages_str = ""
|
||||
|
||||
self.last_observe_time = datetime.now().timestamp()
|
||||
initial_messages = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 10)
|
||||
initial_messages_short = get_raw_msg_before_timestamp_with_chat(self.chat_id, self.last_observe_time, 5)
|
||||
self.last_observe_time = initial_messages[-1]["time"] if initial_messages else self.last_observe_time
|
||||
self.talking_message = initial_messages
|
||||
self.talking_message_short = initial_messages_short
|
||||
self.talking_message_str = build_readable_messages(self.talking_message, show_actions=True)
|
||||
self.talking_message_str_truncate = build_readable_messages(
|
||||
self.talking_message, show_actions=True, truncate=True
|
||||
)
|
||||
self.talking_message_str_short = build_readable_messages(self.talking_message_short, show_actions=True)
|
||||
self.talking_message_str_truncate_short = build_readable_messages(
|
||||
self.talking_message_short, show_actions=True, truncate=True
|
||||
)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
"chat_id": self.chat_id,
|
||||
"platform": self.platform,
|
||||
"is_group_chat": self.is_group_chat,
|
||||
"chat_target_info": self.chat_target_info,
|
||||
"talking_message_str": self.talking_message_str,
|
||||
"talking_message_str_truncate": self.talking_message_str_truncate,
|
||||
"talking_message_str_short": self.talking_message_str_short,
|
||||
"talking_message_str_truncate_short": self.talking_message_str_truncate_short,
|
||||
"name": self.name,
|
||||
"nick_name": self.nick_name,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
}
|
||||
|
||||
def get_observe_info(self, ids=None):
|
||||
return self.talking_message_str
|
||||
|
||||
async def observe(self):
|
||||
# 自上一次观察的新消息
|
||||
new_messages_list = get_raw_msg_by_timestamp_with_chat(
|
||||
chat_id=self.chat_id,
|
||||
timestamp_start=self.last_observe_time,
|
||||
timestamp_end=datetime.now().timestamp(),
|
||||
limit=self.max_now_obs_len,
|
||||
limit_mode="latest",
|
||||
)
|
||||
|
||||
# print(f"new_messages_list: {new_messages_list}")
|
||||
|
||||
last_obs_time_mark = self.last_observe_time
|
||||
if new_messages_list:
|
||||
self.last_observe_time = new_messages_list[-1]["time"]
|
||||
self.talking_message.extend(new_messages_list)
|
||||
|
||||
if len(self.talking_message) > self.max_now_obs_len:
|
||||
# 计算需要移除的消息数量,保留最新的 max_now_obs_len 条
|
||||
messages_to_remove_count = len(self.talking_message) - self.max_now_obs_len
|
||||
oldest_messages = self.talking_message[:messages_to_remove_count]
|
||||
self.talking_message = self.talking_message[messages_to_remove_count:]
|
||||
|
||||
# 构建压缩提示
|
||||
oldest_messages_str = build_readable_messages(
|
||||
messages=oldest_messages, timestamp_mode="normal_no_YMD", read_mark=0, show_actions=True
|
||||
)
|
||||
|
||||
# 根据聊天类型选择提示模板
|
||||
prompt_template_name = "chat_summary_prompt"
|
||||
if self.is_group_chat:
|
||||
chat_type_description = "qq群聊的聊天记录"
|
||||
else:
|
||||
chat_target_name = "对方"
|
||||
if self.chat_target_info:
|
||||
chat_target_name = (
|
||||
self.chat_target_info.get("person_name")
|
||||
or self.chat_target_info.get("user_nickname")
|
||||
or chat_target_name
|
||||
)
|
||||
chat_type_description = f"你和{chat_target_name}的私聊记录"
|
||||
|
||||
prompt = await global_prompt_manager.format_prompt(
|
||||
prompt_template_name,
|
||||
chat_type_description=chat_type_description,
|
||||
chat_logs=oldest_messages_str,
|
||||
)
|
||||
|
||||
self.compressor_prompt = prompt
|
||||
|
||||
# 构建当前消息
|
||||
self.talking_message_str = build_readable_messages(
|
||||
messages=self.talking_message,
|
||||
timestamp_mode="lite",
|
||||
read_mark=last_obs_time_mark,
|
||||
show_actions=True,
|
||||
)
|
||||
self.talking_message_str_truncate = build_readable_messages(
|
||||
messages=self.talking_message,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
# 构建简短版本 - 使用最新一半的消息
|
||||
half_count = len(self.talking_message) // 2
|
||||
recent_messages = self.talking_message[-half_count:] if half_count > 0 else self.talking_message
|
||||
|
||||
self.talking_message_str_short = build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="lite",
|
||||
read_mark=last_obs_time_mark,
|
||||
show_actions=True,
|
||||
)
|
||||
self.talking_message_str_truncate_short = build_readable_messages(
|
||||
messages=recent_messages,
|
||||
timestamp_mode="normal_no_YMD",
|
||||
read_mark=last_obs_time_mark,
|
||||
truncate=True,
|
||||
show_actions=True,
|
||||
)
|
||||
|
||||
self.person_list = await get_person_id_list(self.talking_message)
|
||||
|
||||
# logger.debug(
|
||||
# f"Chat {self.chat_id} - 现在聊天内容:{self.talking_message_str}"
|
||||
# )
|
||||
|
||||
async def has_new_messages_since(self, timestamp: float) -> bool:
|
||||
"""检查指定时间戳之后是否有新消息"""
|
||||
count = num_new_messages_since(chat_id=self.chat_id, timestamp_start=timestamp)
|
||||
return count > 0
|
||||
@@ -1,128 +0,0 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.heartFC_Cycleinfo import CycleDetail
|
||||
from typing import List
|
||||
# Import the new utility function
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
# 所有观察的基类
|
||||
class HFCloopObservation:
|
||||
def __init__(self, observe_id):
|
||||
self.observe_info = ""
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||
self.history_loop: List[CycleDetail] = []
|
||||
|
||||
def get_observe_info(self):
|
||||
return self.observe_info
|
||||
|
||||
def add_loop_info(self, loop_info: CycleDetail):
|
||||
self.history_loop.append(loop_info)
|
||||
|
||||
async def observe(self):
|
||||
recent_active_cycles: List[CycleDetail] = []
|
||||
for cycle in reversed(self.history_loop):
|
||||
# 只关心实际执行了动作的循环
|
||||
# action_taken = cycle.loop_action_info["action_taken"]
|
||||
# if action_taken:
|
||||
recent_active_cycles.append(cycle)
|
||||
if len(recent_active_cycles) == 5:
|
||||
break
|
||||
|
||||
cycle_info_block = ""
|
||||
action_detailed_str = ""
|
||||
consecutive_text_replies = 0
|
||||
responses_for_prompt = []
|
||||
|
||||
cycle_last_reason = ""
|
||||
|
||||
# 检查这最近的活动循环中有多少是连续的文本回复 (从最近的开始看)
|
||||
for cycle in recent_active_cycles:
|
||||
action_result = cycle.loop_plan_info.get("action_result", {})
|
||||
action_type = action_result.get("action_type", "unknown")
|
||||
action_reasoning = action_result.get("reasoning", "未提供理由")
|
||||
is_taken = cycle.loop_action_info.get("action_taken", False)
|
||||
action_taken_time = cycle.loop_action_info.get("taken_time", 0)
|
||||
action_taken_time_str = (
|
||||
datetime.fromtimestamp(action_taken_time).strftime("%H:%M:%S") if action_taken_time > 0 else "未知时间"
|
||||
)
|
||||
# print(action_type)
|
||||
# print(action_reasoning)
|
||||
# print(is_taken)
|
||||
# print(action_taken_time_str)
|
||||
# print("--------------------------------")
|
||||
if action_reasoning != cycle_last_reason:
|
||||
cycle_last_reason = action_reasoning
|
||||
action_reasoning_str = f"你选择这个action的原因是:{action_reasoning}"
|
||||
else:
|
||||
action_reasoning_str = ""
|
||||
|
||||
if action_type == "reply":
|
||||
consecutive_text_replies += 1
|
||||
response_text = cycle.loop_action_info.get("reply_text", "")
|
||||
responses_for_prompt.append(response_text)
|
||||
|
||||
if is_taken:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择回复(action:{action_type},内容是:'{response_text}')。{action_reasoning_str}\n"
|
||||
else:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择回复(action:{action_type},内容是:'{response_text}'),但是动作失败了。{action_reasoning_str}\n"
|
||||
elif action_type == "no_reply":
|
||||
# action_detailed_str += (
|
||||
# f"{action_taken_time_str}时,你选择不回复(action:{action_type}),{action_reasoning_str}\n"
|
||||
# )
|
||||
pass
|
||||
else:
|
||||
if is_taken:
|
||||
action_detailed_str += (
|
||||
f"{action_taken_time_str}时,你选择执行了(action:{action_type}),{action_reasoning_str}\n"
|
||||
)
|
||||
else:
|
||||
action_detailed_str += f"{action_taken_time_str}时,你选择执行了(action:{action_type}),但是动作失败了。{action_reasoning_str}\n"
|
||||
|
||||
if action_detailed_str:
|
||||
cycle_info_block = f"\n你最近做的事:\n{action_detailed_str}\n"
|
||||
else:
|
||||
cycle_info_block = "\n"
|
||||
|
||||
# 根据连续文本回复的数量构建提示信息
|
||||
if consecutive_text_replies >= 3: # 如果最近的三个活动都是文本回复
|
||||
cycle_info_block = f'你已经连续回复了三条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}",第三近: "{responses_for_prompt[2]}")。你回复的有点多了,请注意'
|
||||
elif consecutive_text_replies == 2: # 如果最近的两个活动是文本回复
|
||||
cycle_info_block = f'你已经连续回复了两条消息(最近: "{responses_for_prompt[0]}",第二近: "{responses_for_prompt[1]}"),请注意'
|
||||
|
||||
# 包装提示块,增加可读性,即使没有连续回复也给个标记
|
||||
# if cycle_info_block:
|
||||
# cycle_info_block = f"\n你最近的回复\n{cycle_info_block}\n"
|
||||
# else:
|
||||
# cycle_info_block = "\n"
|
||||
|
||||
# 获取history_loop中最新添加的
|
||||
if self.history_loop:
|
||||
last_loop = self.history_loop[0]
|
||||
start_time = last_loop.start_time
|
||||
end_time = last_loop.end_time
|
||||
if start_time is not None and end_time is not None:
|
||||
time_diff = int(end_time - start_time)
|
||||
if time_diff > 60:
|
||||
cycle_info_block += f"距离你上一次阅读消息并思考和规划,已经过去了{int(time_diff / 60)}分钟\n"
|
||||
else:
|
||||
cycle_info_block += f"距离你上一次阅读消息并思考和规划,已经过去了{time_diff}秒\n"
|
||||
else:
|
||||
cycle_info_block += "你还没看过消息\n"
|
||||
|
||||
self.observe_info = cycle_info_block
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
# 只序列化基本信息,避免循环引用
|
||||
return {
|
||||
"observe_info": self.observe_info,
|
||||
"observe_id": self.observe_id,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
# 不序列化history_loop,避免循环引用
|
||||
"history_loop_count": len(self.history_loop),
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
# 所有观察的基类
|
||||
class Observation:
|
||||
def __init__(self, observe_id):
|
||||
self.observe_info = ""
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp() # 初始化为当前时间
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""将观察对象转换为可序列化的字典"""
|
||||
return {
|
||||
"observe_info": self.observe_info,
|
||||
"observe_id": self.observe_id,
|
||||
"last_observe_time": self.last_observe_time,
|
||||
}
|
||||
|
||||
async def observe(self):
|
||||
pass
|
||||
@@ -1,34 +0,0 @@
|
||||
# 定义了来自外部世界的信息
|
||||
# 外部世界可以是某个聊天 不同平台的聊天 也可以是任意媒体
|
||||
from datetime import datetime
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.focus_chat.working_memory.working_memory import WorkingMemory
|
||||
from src.chat.focus_chat.working_memory.memory_item import MemoryItem
|
||||
from typing import List
|
||||
# Import the new utility function
|
||||
|
||||
logger = get_logger("observation")
|
||||
|
||||
|
||||
# 所有观察的基类
|
||||
class WorkingMemoryObservation:
|
||||
def __init__(self, observe_id):
|
||||
self.observe_info = ""
|
||||
self.observe_id = observe_id
|
||||
self.last_observe_time = datetime.now().timestamp()
|
||||
|
||||
self.working_memory = WorkingMemory(chat_id=observe_id)
|
||||
|
||||
self.retrieved_working_memory = []
|
||||
|
||||
def get_observe_info(self):
|
||||
return self.working_memory
|
||||
|
||||
def add_retrieved_working_memory(self, retrieved_working_memory: List[MemoryItem]):
|
||||
self.retrieved_working_memory.append(retrieved_working_memory)
|
||||
|
||||
def get_retrieved_working_memory(self):
|
||||
return self.retrieved_working_memory
|
||||
|
||||
async def observe(self):
|
||||
pass
|
||||
@@ -1,5 +1,3 @@
|
||||
from .observation.observation import Observation
|
||||
from src.chat.heart_flow.observation.chatting_observation import ChattingObservation
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, List, Dict, Tuple
|
||||
@@ -10,7 +8,7 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.focus_chat.heartFC_chat import HeartFChatting
|
||||
from src.chat.normal_chat.normal_chat import NormalChat
|
||||
from src.chat.heart_flow.chat_state_info import ChatState, ChatStateInfo
|
||||
from .utils_chat import get_chat_type_and_target_info
|
||||
from src.chat.utils.utils import get_chat_type_and_target_info
|
||||
from src.config.config import global_config
|
||||
from rich.traceback import install
|
||||
|
||||
@@ -314,24 +312,6 @@ class SubHeartflow:
|
||||
f"{log_prefix} 尝试将状态从 {current_state.value} 变为 {new_state.value},但未成功或未执行更改。"
|
||||
)
|
||||
|
||||
def add_observation(self, observation: Observation):
|
||||
for existing_obs in self.observations:
|
||||
if existing_obs.observe_id == observation.observe_id:
|
||||
return
|
||||
self.observations.append(observation)
|
||||
|
||||
def remove_observation(self, observation: Observation):
|
||||
if observation in self.observations:
|
||||
self.observations.remove(observation)
|
||||
|
||||
def get_all_observations(self) -> list[Observation]:
|
||||
return self.observations
|
||||
|
||||
def _get_primary_observation(self) -> Optional[ChattingObservation]:
|
||||
if self.observations and isinstance(self.observations[0], ChattingObservation):
|
||||
return self.observations[0]
|
||||
logger.warning(f"SubHeartflow {self.subheartflow_id} 没有找到有效的 ChattingObservation")
|
||||
return None
|
||||
|
||||
def get_normal_chat_last_speak_time(self) -> float:
|
||||
if self.normal_chat_instance:
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
from typing import Optional, Tuple, Dict
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
|
||||
|
||||
logger = get_logger("heartflow_utils")
|
||||
|
||||
|
||||
def get_chat_type_and_target_info(chat_id: str) -> Tuple[bool, Optional[Dict]]:
|
||||
"""
|
||||
获取聊天类型(是否群聊)和私聊对象信息。
|
||||
|
||||
Args:
|
||||
chat_id: 聊天流ID
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Optional[Dict]]:
|
||||
- bool: 是否为群聊 (True 是群聊, False 是私聊或未知)
|
||||
- Optional[Dict]: 如果是私聊,包含对方信息的字典;否则为 None。
|
||||
字典包含: platform, user_id, user_nickname, person_id, person_name
|
||||
"""
|
||||
is_group_chat = False # Default to private/unknown
|
||||
chat_target_info = None
|
||||
|
||||
try:
|
||||
chat_stream = get_chat_manager().get_stream(chat_id)
|
||||
|
||||
if chat_stream:
|
||||
if chat_stream.group_info:
|
||||
is_group_chat = True
|
||||
chat_target_info = None # Explicitly None for group chat
|
||||
elif chat_stream.user_info: # It's a private chat
|
||||
is_group_chat = False
|
||||
user_info = chat_stream.user_info
|
||||
platform = chat_stream.platform
|
||||
user_id = user_info.user_id
|
||||
|
||||
# Initialize target_info with basic info
|
||||
target_info = {
|
||||
"platform": platform,
|
||||
"user_id": user_id,
|
||||
"user_nickname": user_info.user_nickname,
|
||||
"person_id": None,
|
||||
"person_name": None,
|
||||
}
|
||||
|
||||
# Try to fetch person info
|
||||
try:
|
||||
# Assume get_person_id is sync (as per original code), keep using to_thread
|
||||
person_id = PersonInfoManager.get_person_id(platform, user_id)
|
||||
person_name = None
|
||||
if person_id:
|
||||
# get_value is async, so await it directly
|
||||
person_info_manager = get_person_info_manager()
|
||||
person_name = person_info_manager.get_value_sync(person_id, "person_name")
|
||||
|
||||
target_info["person_id"] = person_id
|
||||
target_info["person_name"] = person_name
|
||||
except Exception as person_e:
|
||||
logger.warning(
|
||||
f"获取 person_id 或 person_name 时出错 for {platform}:{user_id} in utils: {person_e}"
|
||||
)
|
||||
|
||||
chat_target_info = target_info
|
||||
else:
|
||||
logger.warning(f"无法获取 chat_stream for {chat_id} in utils")
|
||||
# Keep defaults: is_group_chat=False, chat_target_info=None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取聊天类型和目标信息时出错 for {chat_id}: {e}", exc_info=True)
|
||||
# Keep defaults on error
|
||||
|
||||
return is_group_chat, chat_target_info
|
||||
Reference in New Issue
Block a user