fix: PFC不读取聊天记录

This commit is contained in:
SengokuCola
2025-04-11 10:22:49 +08:00
parent 9aacbd55cb
commit 68a60f7e71
9 changed files with 222 additions and 224 deletions

View File

@@ -45,7 +45,19 @@ class ActionPlanner:
# 构建对话目标
if conversation_info.goal_list:
goal, reasoning = conversation_info.goal_list[-1]
last_goal = conversation_info.goal_list[-1]
print(last_goal)
# 处理字典或元组格式
if isinstance(last_goal, tuple) and len(last_goal) == 2:
goal, reasoning = last_goal
elif isinstance(last_goal, dict) and 'goal' in last_goal and 'reasoning' in last_goal:
# 处理字典格式
goal = last_goal.get('goal', "目前没有明确对话目标")
reasoning = last_goal.get('reasoning', "目前没有明确对话目标,最好思考一个对话目标")
else:
# 处理未知格式
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
else:
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"

View File

@@ -1,5 +1,6 @@
import time
import asyncio
import traceback
from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger
from ..message.message_base import UserInfo
@@ -44,18 +45,14 @@ class ChatObserver:
self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage()
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间
# self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
# self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
# self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
# 运行状态
self._running: bool = False
@@ -72,7 +69,7 @@ class ChatObserver:
self.is_cold_chat_state: bool = False
self.update_event = asyncio.Event()
self.update_interval = 5 # 更新间隔(秒)
self.update_interval = 2 # 更新间隔(秒)
self.message_cache = []
self.update_running = False
@@ -98,21 +95,17 @@ class ChatObserver:
Args:
message: 消息数据
"""
self.message_history.append(message)
self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1
try:
# 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
# 发送新消息通知
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message)
await self.notification_manager.send_notification(notification)
# 发送新消息通知
# logger.info(f"发送新ccchandleer消息通知: {message}")
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message)
# logger.info(f"发送新消ddddd息通知: {notification}")
# print(self.notification_manager)
await self.notification_manager.send_notification(notification)
except Exception as e:
logger.error(f"添加消息到历史记录时出错: {e}")
print(traceback.format_exc())
# 检查并更新冷场状态
await self._check_cold_chat()
@@ -156,9 +149,6 @@ class ChatObserver:
Returns:
bool: 是否有新消息
"""
if time_point is None:
logger.warning("time_point 为 None返回 False")
return False
if self.last_message_time is None:
logger.debug("没有最后消息时间,返回 False")
@@ -214,6 +204,8 @@ class ChatObserver:
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
print(f"获取111111111122222222新消息: {new_messages}")
return new_messages
@@ -230,6 +222,8 @@ class ChatObserver:
if new_messages:
self.last_message_read = new_messages[-1]["message_id"]
logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
return new_messages
@@ -237,20 +231,24 @@ class ChatObserver:
async def _update_loop(self):
"""更新循环"""
try:
start_time = time.time()
messages = await self._fetch_new_messages_before(start_time)
for message in messages:
await self._add_message_to_history(message)
except Exception as e:
logger.error(f"缓冲消息出错: {e}")
# try:
# start_time = time.time()
# messages = await self._fetch_new_messages_before(start_time)
# for message in messages:
# await self._add_message_to_history(message)
# logger.debug(f"缓冲消息: {messages}")
# except Exception as e:
# logger.error(f"缓冲消息出错: {e}")
while self._running:
try:
# 等待事件或超时1秒
try:
# print("等待事件")
await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError:
# print("超时")
pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件
@@ -355,51 +353,6 @@ class ChatObserver:
return time_info
def start_periodic_update(self):
"""启动观察器的定期更新"""
if not self.update_running:
self.update_running = True
asyncio.create_task(self._periodic_update())
async def _periodic_update(self):
"""定期更新消息历史"""
try:
while self.update_running:
await self._update_message_history()
await asyncio.sleep(self.update_interval)
except Exception as e:
logger.error(f"定期更新消息历史时出错: {str(e)}")
async def _update_message_history(self) -> bool:
"""更新消息历史
Returns:
bool: 是否有新消息
"""
try:
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
if not messages:
return False
# 检查是否有新消息
has_new_messages = False
if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True
self.message_cache = messages
if has_new_messages:
self.update_event.set()
self.update_event.clear()
return True
return False
except Exception as e:
logger.error(f"更新消息历史时出错: {str(e)}")
return False
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史
@@ -421,3 +374,6 @@ class ChatObserver:
if not self.message_cache:
return None
return self.message_cache[0]
def __str__(self):
return f"ChatObserver for {self.stream_id}"

View File

@@ -98,11 +98,17 @@ class NotificationManager:
notification_type: 要处理的通知类型
handler: 处理器实例
"""
print(1145145511114445551111444)
if target not in self._handlers:
print("没11有target")
self._handlers[target] = {}
if notification_type not in self._handlers[target]:
print("没11有notification_type")
self._handlers[target][notification_type] = []
print(self._handlers[target][notification_type])
print(f"注册1111111111111111111111处理器: {target} {notification_type} {handler}")
self._handlers[target][notification_type].append(handler)
print(self._handlers[target][notification_type])
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器
@@ -126,6 +132,7 @@ class NotificationManager:
async def send_notification(self, notification: Notification):
"""发送通知"""
self._notification_history.append(notification)
# print("kaishichul-----------------------------------i")
# 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification):
@@ -133,12 +140,16 @@ class NotificationManager:
self._active_states.add(notification.type)
else:
self._active_states.discard(notification.type)
# 调用目标接收者的处理器
target = notification.target
if target in self._handlers:
handlers = self._handlers[target].get(notification.type, [])
# print(1111111)
print(handlers)
for handler in handlers:
print(f"调用处理器: {handler}")
await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]:
@@ -170,6 +181,13 @@ class NotificationManager:
history = history[-limit:]
return history
def __str__(self):
str = ""
for target, handlers in self._handlers.items():
for notification_type, handler_list in handlers.items():
str += f"NotificationManager for {target} {notification_type} {handler_list}"
return str
# 一些常用的通知创建函数
@@ -182,8 +200,9 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
target=target,
data={
"message_id": message.get("message_id"),
"content": message.get("content"),
"sender": message.get("sender"),
"processed_plain_text": message.get("processed_plain_text"),
"detailed_plain_text": message.get("detailed_plain_text"),
"user_info": message.get("user_info"),
"time": message.get("time"),
},
)
@@ -276,3 +295,5 @@ class ChatStateManager:
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -60,9 +60,10 @@ class Conversation:
self.chat_observer = ChatObserver.get_instance(self.stream_id)
self.chat_observer.start()
self.observation_info = ObservationInfo()
self.observation_info.bind_to_chat_observer(self.stream_id)
self.observation_info.bind_to_chat_observer(self.chat_observer)
# print(self.chat_observer.get_cached_messages(limit=)
# 对话信息
self.conversation_info = ConversationInfo()
except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}")
@@ -140,6 +141,7 @@ class Conversation:
if action == "direct_reply":
self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
print(f"生成回复: {self.generated_reply}")
# # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply(
@@ -148,6 +150,7 @@ class Conversation:
# )
if self._check_new_messages_after_planning():
logger.info("333333发现新消息重新考虑行动")
return None
await self._send_reply()
@@ -212,15 +215,9 @@ class Conversation:
logger.warning("没有生成回复")
return
messages = self.chat_observer.get_cached_messages(limit=1)
if not messages:
logger.warning("没有最近的消息可以回复")
return
latest_message = self._convert_to_message(messages[0])
try:
await self.direct_sender.send_message(
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message
chat_stream=self.chat_stream, content=self.generated_reply
)
self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update():

View File

@@ -1,71 +0,0 @@
from typing import TYPE_CHECKING
from src.common.logger import get_module_logger
from .chat_states import NotificationHandler, Notification, NotificationType
if TYPE_CHECKING:
from .conversation import Conversation
logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器"""
def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器
Args:
conversation: 对话实例
"""
self.conversation = conversation
async def handle_notification(self, notification: Notification):
"""处理通知
Args:
notification: 通知对象
"""
logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}")
# 根据通知类型执行不同的处理
if notification.type == NotificationType.NEW_MESSAGE:
# 新消息通知
await self._handle_new_message(notification)
elif notification.type == NotificationType.COLD_CHAT:
# 冷聊天通知
await self._handle_cold_chat(notification)
elif notification.type == NotificationType.COMMAND:
# 命令通知
await self._handle_command(notification)
else:
logger.warning(f"未知的通知类型: {notification.type.name}")
async def _handle_new_message(self, notification: Notification):
"""处理新消息通知
Args:
notification: 通知对象
"""
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.last_message_time = notification.data.get("time", 0)
observation_info.add_unprocessed_message(notification.data)
# 手动触发观察器更新
self.conversation.chat_observer.trigger_update()
async def _handle_cold_chat(self, notification: Notification):
"""处理冷聊天通知
Args:
notification: 通知对象
"""
# 获取冷聊天信息
cold_duration = notification.data.get("duration", 0)
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}")

View File

@@ -6,7 +6,7 @@ import time
from dataclasses import dataclass, field
from src.common.logger import get_module_logger
from .chat_observer import ChatObserver
from .chat_states import NotificationHandler
from .chat_states import NotificationHandler, NotificationType
logger = get_module_logger("observation_info")
@@ -22,63 +22,70 @@ class ObservationInfoHandler(NotificationHandler):
"""
self.observation_info = observation_info
async def handle_notification(self, notification: Dict[str, Any]):
"""处理通知
Args:
notification: 通知数据
"""
notification_type = notification.get("type")
data = notification.get("data", {})
if notification_type == "NEW_MESSAGE":
async def handle_notification(self, notification):
# 获取通知类型和数据
notification_type = notification.type
data = notification.data
if notification_type == NotificationType.NEW_MESSAGE:
# 处理新消息通知
logger.debug(f"收到新消息通知data: {data}")
message = data.get("message", {})
message_id = data.get("message_id")
processed_plain_text = data.get("processed_plain_text")
detailed_plain_text = data.get("detailed_plain_text")
user_info = data.get("user_info")
time_value = data.get("time")
message = {
"message_id": message_id,
"processed_plain_text": processed_plain_text,
"detailed_plain_text": detailed_plain_text,
"user_info": user_info,
"time": time_value
}
self.observation_info.update_from_message(message)
# self.observation_info.has_unread_messages = True
# self.observation_info.new_unread_message.append(message.get("processed_plain_text", ""))
elif notification_type == "COLD_CHAT":
elif notification_type == NotificationType.COLD_CHAT:
# 处理冷场通知
is_cold = data.get("is_cold", False)
self.observation_info.update_cold_chat_status(is_cold, time.time())
elif notification_type == "ACTIVE_CHAT":
elif notification_type == NotificationType.ACTIVE_CHAT:
# 处理活跃通知
is_active = data.get("is_active", False)
self.observation_info.is_cold = not is_active
elif notification_type == "BOT_SPEAKING":
elif notification_type == NotificationType.BOT_SPEAKING:
# 处理机器人说话通知
self.observation_info.is_typing = False
self.observation_info.last_bot_speak_time = time.time()
elif notification_type == "USER_SPEAKING":
elif notification_type == NotificationType.USER_SPEAKING:
# 处理用户说话通知
self.observation_info.is_typing = False
self.observation_info.last_user_speak_time = time.time()
elif notification_type == "MESSAGE_DELETED":
elif notification_type == NotificationType.MESSAGE_DELETED:
# 处理消息删除通知
message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
]
elif notification_type == "USER_JOINED":
elif notification_type == NotificationType.USER_JOINED:
# 处理用户加入通知
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.add(user_id)
elif notification_type == "USER_LEFT":
elif notification_type == NotificationType.USER_LEFT:
# 处理用户离开通知
user_id = data.get("user_id")
if user_id:
self.observation_info.active_users.discard(user_id)
elif notification_type == "ERROR":
elif notification_type == NotificationType.ERROR:
# 处理错误通知
error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}")
@@ -100,6 +107,7 @@ class ObservationInfo:
last_message_content: str = ""
last_message_sender: Optional[str] = None
bot_id: Optional[str] = None
chat_history_count: int = 0
new_messages_count: int = 0
cold_chat_duration: float = 0.0
@@ -117,28 +125,37 @@ class ObservationInfo:
self.chat_observer = None
self.handler = ObservationInfoHandler(self)
def bind_to_chat_observer(self, stream_id: str):
def bind_to_chat_observer(self, chat_observer: ChatObserver):
"""绑定到指定的chat_observer
Args:
stream_id: 聊天流ID
"""
self.chat_observer = ChatObserver.get_instance(stream_id)
self.chat_observer = chat_observer
print(f"1919810----------------------绑定-----------------------------")
print(self.chat_observer)
print(f"1919810--------------------绑定-----------------------------")
print(self.chat_observer.notification_manager)
print(f"1919810-------------------绑定-----------------------------")
self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
)
self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
)
print("1919810------------------------绑定-----------------------------")
print(f"1919810--------------------绑定-----------------------------")
print(self.chat_observer.notification_manager)
print(f"1919810-------------------绑定-----------------------------")
def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定"""
if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler
target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
)
self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type="COLD_CHAT", handler=self.handler
target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
)
self.chat_observer = None
@@ -148,8 +165,11 @@ class ObservationInfo:
Args:
message: 消息数据
"""
print("1919810-----------------------------------------------------")
logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"]
self.last_message_id = message["message_id"]
self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {}))
@@ -169,7 +189,6 @@ class ObservationInfo:
def update_changed(self):
"""更新changed状态"""
self.changed = True
# self.meta_plan_trigger = True
def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态
@@ -223,17 +242,3 @@ class ObservationInfo:
self.unprocessed_messages.clear()
self.new_messages_count = 0
def add_unprocessed_message(self, message: Dict[str, Any]):
"""添加未处理的消息
Args:
message: 消息数据
"""
# 防止重复添加同一消息
message_id = message.get("message_id")
if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages):
self.unprocessed_messages.append(message)
self.new_messages_count += 1
# 同时更新其他消息相关信息
self.update_from_message(message)

View File

@@ -99,19 +99,21 @@ class GoalAnalyzer:
3. 添加新目标
4. 删除不再相关的目标
请以JSON格式输出当前的所有对话目标包含以下字段
请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
1. goal: 对话目标(简短的一句话)
2. reasoning: 对话原因,为什么设定这个目标(简要解释)
输出格式示例:
{{
"goal": "回答用户关于Python编程的具体问题",
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
}},
{{
"goal": "回答用户关于python安装的具体问题",
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答"
}}"""
[
{{
"goal": "回答用户关于Python编程的具体问题",
"reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
}},
{{
"goal": "回答用户关于python安装的具体问题",
"reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
}}
]"""
logger.debug(f"发送到LLM的提示词: {prompt}")
try:
@@ -120,13 +122,37 @@ class GoalAnalyzer:
except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}")
content = ""
# 使用简化函数提取JSON内容
# 使用改进后的get_items_from_json函数处理JSON数组
success, result = get_items_from_json(
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
content, "goal", "reasoning",
required_types={"goal": str, "reasoning": str},
allow_array=True
)
# TODO
conversation_info.goal_list.append(result)
if success:
# 判断结果是单个字典还是字典列表
if isinstance(result, list):
# 清空现有目标列表并添加新目标
conversation_info.goal_list = []
for item in result:
goal = item.get("goal", "")
reasoning = item.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning))
# 返回第一个目标作为当前主要目标(如果有)
if result:
first_goal = result[0]
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", ""))
else:
# 单个目标的情况
goal = result.get("goal", "")
reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning)
# 如果解析失败,返回默认值
return ("", "", "")
async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表

View File

@@ -1,6 +1,6 @@
import json
import re
from typing import Dict, Any, Optional, Tuple
from typing import Dict, Any, Optional, Tuple, List, Union
from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils")
@@ -11,7 +11,8 @@ def get_items_from_json(
*items: str,
default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]:
allow_array: bool = True,
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""从文本中提取JSON内容并获取指定字段
Args:
@@ -19,18 +20,69 @@ def get_items_from_json(
*items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型}
allow_array: 是否允许解析JSON数组
Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典)
Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
"""
content = content.strip()
result = {}
# 设置默认值
if default_values:
result.update(default_values)
# 尝试解析JSON
# 首先尝试解析JSON数组
if allow_array:
try:
# 尝试找到文本中的JSON数组
array_pattern = r"\[[\s\S]*\]"
array_match = re.search(array_pattern, content)
if array_match:
array_content = array_match.group()
json_array = json.loads(array_content)
# 确认是数组类型
if isinstance(json_array, list):
# 验证数组中的每个项目是否包含所有必需字段
valid_items = []
for item in json_array:
if not isinstance(item, dict):
continue
# 检查是否有所有必需字段
if all(field in item for field in items):
# 验证字段类型
if required_types:
type_valid = True
for field, expected_type in required_types.items():
if field in item and not isinstance(item[field], expected_type):
type_valid = False
break
if not type_valid:
continue
# 验证字符串字段不为空
string_valid = True
for field in items:
if isinstance(item[field], str) and not item[field].strip():
string_valid = False
break
if not string_valid:
continue
valid_items.append(item)
if valid_items:
return True, valid_items
except json.JSONDecodeError:
logger.debug("JSON数组解析失败尝试解析单个JSON对象")
except Exception as e:
logger.debug(f"尝试解析JSON数组时出错: {str(e)}")
# 尝试解析JSON对象
try:
json_data = json.loads(content)
except json.JSONDecodeError:

View File

@@ -26,11 +26,11 @@ logger = get_module_logger("llm_generator", config=llm_config)
class ResponseGenerator:
def __init__(self):
self.model_normal = LLM_request(
model=global_config.llm_normal, temperature=0.3, max_tokens=256, request_type="response_heartflow"
model=global_config.llm_normal, temperature=0.15, max_tokens=256, request_type="response_heartflow"
)
self.model_sum = LLM_request(
model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=2000, request_type="relation"
model=global_config.llm_summary_by_topic, temperature=0.6, max_tokens=2000, request_type="relation"
)
self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model"