This commit is contained in:
DrSmoothl
2025-04-11 10:55:56 +08:00
10 changed files with 234 additions and 253 deletions

View File

@@ -45,7 +45,19 @@ class ActionPlanner:
# 构建对话目标 # 构建对话目标
if conversation_info.goal_list: if conversation_info.goal_list:
goal, reasoning = conversation_info.goal_list[-1] last_goal = conversation_info.goal_list[-1]
print(last_goal)
# 处理字典或元组格式
if isinstance(last_goal, tuple) and len(last_goal) == 2:
goal, reasoning = last_goal
elif isinstance(last_goal, dict) and 'goal' in last_goal and 'reasoning' in last_goal:
# 处理字典格式
goal = last_goal.get('goal', "目前没有明确对话目标")
reasoning = last_goal.get('reasoning', "目前没有明确对话目标,最好思考一个对话目标")
else:
# 处理未知格式
goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标"
else: else:
goal = "目前没有明确对话目标" goal = "目前没有明确对话目标"
reasoning = "目前没有明确对话目标,最好思考一个对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标"
@@ -54,14 +66,14 @@ class ActionPlanner:
chat_history_list = observation_info.chat_history chat_history_list = observation_info.chat_history
chat_history_text = "" chat_history_text = ""
for msg in chat_history_list: for msg in chat_history_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
if observation_info.new_messages_count > 0: if observation_info.new_messages_count > 0:
new_messages_list = observation_info.unprocessed_messages new_messages_list = observation_info.unprocessed_messages
chat_history_text += f"{observation_info.new_messages_count}条新消息:\n" chat_history_text += f"{observation_info.new_messages_count}条新消息:\n"
for msg in new_messages_list: for msg in new_messages_list:
chat_history_text += f"{msg}\n" chat_history_text += f"{msg.get('detailed_plain_text', '')}\n"
observation_info.clear_unprocessed_messages() observation_info.clear_unprocessed_messages()

View File

@@ -1,5 +1,6 @@
import time import time
import asyncio import asyncio
import traceback
from typing import Optional, Dict, Any, List, Tuple from typing import Optional, Dict, Any, List, Tuple
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
@@ -17,45 +18,39 @@ class ChatObserver:
_instances: Dict[str, "ChatObserver"] = {} _instances: Dict[str, "ChatObserver"] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver": def get_instance(cls, stream_id: str) -> "ChatObserver":
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现
Returns: Returns:
ChatObserver: 观察器实例 ChatObserver: 观察器实例
""" """
if stream_id not in cls._instances: if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id, message_storage) cls._instances[stream_id] = cls(stream_id)
return cls._instances[stream_id] return cls._instances[stream_id]
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): def __init__(self, stream_id: str):
"""初始化观察器 """初始化观察器
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现
""" """
if stream_id in self._instances: if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.") raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage() self.message_storage = MongoDBMessageStorage()
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 # self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 # self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间 # self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID self.last_message_read: Optional[Dict[str, Any]] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 self.last_message_time: float = time.time()
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
self.last_message_id: Optional[str] = None # 最后一条消息的ID
self.message_count: int = 0 # 消息计数
# 运行状态 # 运行状态
self._running: bool = False self._running: bool = False
@@ -72,7 +67,7 @@ class ChatObserver:
self.is_cold_chat_state: bool = False self.is_cold_chat_state: bool = False
self.update_event = asyncio.Event() self.update_event = asyncio.Event()
self.update_interval = 5 # 更新间隔(秒) self.update_interval = 2 # 更新间隔(秒)
self.message_cache = [] self.message_cache = []
self.update_running = False self.update_running = False
@@ -98,21 +93,17 @@ class ChatObserver:
Args: Args:
message: 消息数据 message: 消息数据
""" """
self.message_history.append(message) try:
self.last_message_id = message["message_id"]
self.last_message_time = message["time"] # 更新最后消息时间
self.message_count += 1
# 更新说话时间
user_info = UserInfo.from_dict(message.get("user_info", {}))
if user_info.user_id == global_config.BOT_QQ:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
# 发送新消息通知 # 发送新消息通知
notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message) # logger.info(f"发送新ccchandleer消息通知: {message}")
notification = create_new_message_notification(sender="chat_observer", target="observation_info", message=message)
# logger.info(f"发送新消ddddd息通知: {notification}")
# print(self.notification_manager)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
except Exception as e:
logger.error(f"添加消息到历史记录时出错: {e}")
print(traceback.format_exc())
# 检查并更新冷场状态 # 检查并更新冷场状态
await self._check_cold_chat() await self._check_cold_chat()
@@ -140,12 +131,6 @@ class ChatObserver:
notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold)
await self.notification_manager.send_notification(notification) await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read)
for message in messages:
await self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool: def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息 """判断是否在指定时间点后有新消息
@@ -156,9 +141,6 @@ class ChatObserver:
Returns: Returns:
bool: 是否有新消息 bool: 是否有新消息
""" """
if time_point is None:
logger.warning("time_point 为 None返回 False")
return False
if self.last_message_time is None: if self.last_message_time is None:
logger.debug("没有最后消息时间,返回 False") logger.debug("没有最后消息时间,返回 False")
@@ -210,10 +192,13 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 新消息列表 List[Dict[str, Any]]: 新消息列表
""" """
new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_time)
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]
self.last_message_time = new_messages[-1]["time"]
print(f"获取数据库中找到的新消息: {new_messages}")
return new_messages return new_messages
@@ -231,26 +216,32 @@ class ChatObserver:
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
logger.debug(f"获取指定时间点111之前的消息: {new_messages}")
return new_messages return new_messages
"""主要观察循环""" """主要观察循环"""
async def _update_loop(self): async def _update_loop(self):
"""更新循环""" """更新循环"""
try: # try:
start_time = time.time() # start_time = time.time()
messages = await self._fetch_new_messages_before(start_time) # messages = await self._fetch_new_messages_before(start_time)
for message in messages: # for message in messages:
await self._add_message_to_history(message) # await self._add_message_to_history(message)
except Exception as e: # logger.debug(f"缓冲消息: {messages}")
logger.error(f"缓冲消息出错: {e}") # except Exception as e:
# logger.error(f"缓冲消息出错: {e}")
while self._running: while self._running:
try: try:
# 等待事件或超时1秒 # 等待事件或超时1秒
try: try:
# print("等待事件")
await asyncio.wait_for(self._update_event.wait(), timeout=1) await asyncio.wait_for(self._update_event.wait(), timeout=1)
except asyncio.TimeoutError: except asyncio.TimeoutError:
# print("超时")
pass # 超时后也执行一次检查 pass # 超时后也执行一次检查
self._update_event.clear() # 重置触发事件 self._update_event.clear() # 重置触发事件
@@ -269,6 +260,7 @@ class ChatObserver:
except Exception as e: except Exception as e:
logger.error(f"更新循环出错: {e}") logger.error(f"更新循环出错: {e}")
logger.error(traceback.format_exc())
self._update_complete.set() # 即使出错也要设置完成事件 self._update_complete.set() # 即使出错也要设置完成事件
def trigger_update(self): def trigger_update(self):
@@ -355,51 +347,6 @@ class ChatObserver:
return time_info return time_info
def start_periodic_update(self):
"""启动观察器的定期更新"""
if not self.update_running:
self.update_running = True
asyncio.create_task(self._periodic_update())
async def _periodic_update(self):
"""定期更新消息历史"""
try:
while self.update_running:
await self._update_message_history()
await asyncio.sleep(self.update_interval)
except Exception as e:
logger.error(f"定期更新消息历史时出错: {str(e)}")
async def _update_message_history(self) -> bool:
"""更新消息历史
Returns:
bool: 是否有新消息
"""
try:
messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50)
if not messages:
return False
# 检查是否有新消息
has_new_messages = False
if messages and (
not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]
):
has_new_messages = True
self.message_cache = messages
if has_new_messages:
self.update_event.set()
self.update_event.clear()
return True
return False
except Exception as e:
logger.error(f"更新消息历史时出错: {str(e)}")
return False
def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]:
"""获取缓存的消息历史 """获取缓存的消息历史
@@ -421,3 +368,6 @@ class ChatObserver:
if not self.message_cache: if not self.message_cache:
return None return None
return self.message_cache[0] return self.message_cache[0]
def __str__(self):
return f"ChatObserver for {self.stream_id}"

View File

@@ -98,11 +98,17 @@ class NotificationManager:
notification_type: 要处理的通知类型 notification_type: 要处理的通知类型
handler: 处理器实例 handler: 处理器实例
""" """
print(1145145511114445551111444)
if target not in self._handlers: if target not in self._handlers:
print("没11有target")
self._handlers[target] = {} self._handlers[target] = {}
if notification_type not in self._handlers[target]: if notification_type not in self._handlers[target]:
print("没11有notification_type")
self._handlers[target][notification_type] = [] self._handlers[target][notification_type] = []
print(self._handlers[target][notification_type])
print(f"注册1111111111111111111111处理器: {target} {notification_type} {handler}")
self._handlers[target][notification_type].append(handler) self._handlers[target][notification_type].append(handler)
print(self._handlers[target][notification_type])
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器 """注销通知处理器
@@ -126,6 +132,7 @@ class NotificationManager:
async def send_notification(self, notification: Notification): async def send_notification(self, notification: Notification):
"""发送通知""" """发送通知"""
self._notification_history.append(notification) self._notification_history.append(notification)
# print("kaishichul-----------------------------------i")
# 如果是状态通知,更新活跃状态 # 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification): if isinstance(notification, StateNotification):
@@ -134,11 +141,15 @@ class NotificationManager:
else: else:
self._active_states.discard(notification.type) self._active_states.discard(notification.type)
# 调用目标接收者的处理器 # 调用目标接收者的处理器
target = notification.target target = notification.target
if target in self._handlers: if target in self._handlers:
handlers = self._handlers[target].get(notification.type, []) handlers = self._handlers[target].get(notification.type, [])
# print(1111111)
print(handlers)
for handler in handlers: for handler in handlers:
print(f"调用处理器: {handler}")
await handler.handle_notification(notification) await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]: def get_active_states(self) -> Set[NotificationType]:
@@ -171,6 +182,13 @@ class NotificationManager:
return history return history
def __str__(self):
str = ""
for target, handlers in self._handlers.items():
for notification_type, handler_list in handlers.items():
str += f"NotificationManager for {target} {notification_type} {handler_list}"
return str
# 一些常用的通知创建函数 # 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
@@ -182,8 +200,9 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str,
target=target, target=target,
data={ data={
"message_id": message.get("message_id"), "message_id": message.get("message_id"),
"content": message.get("content"), "processed_plain_text": message.get("processed_plain_text"),
"sender": message.get("sender"), "detailed_plain_text": message.get("detailed_plain_text"),
"user_info": message.get("user_info"),
"time": message.get("time"), "time": message.get("time"),
}, },
) )
@@ -276,3 +295,5 @@ class ChatStateManager:
current_time = datetime.now().timestamp() current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -60,9 +60,10 @@ class Conversation:
self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer = ChatObserver.get_instance(self.stream_id)
self.chat_observer.start() self.chat_observer.start()
self.observation_info = ObservationInfo() self.observation_info = ObservationInfo()
self.observation_info.bind_to_chat_observer(self.stream_id) self.observation_info.bind_to_chat_observer(self.chat_observer)
# print(self.chat_observer.get_cached_messages(limit=)
# 对话信息
self.conversation_info = ConversationInfo() self.conversation_info = ConversationInfo()
except Exception as e: except Exception as e:
logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(f"初始化对话实例:注册信息组件失败: {e}")
@@ -140,6 +141,7 @@ class Conversation:
if action == "direct_reply": if action == "direct_reply":
self.state = ConversationState.GENERATING self.state = ConversationState.GENERATING
self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info)
print(f"生成回复: {self.generated_reply}")
# # 检查回复是否合适 # # 检查回复是否合适
# is_suitable, reason, need_replan = await self.reply_generator.check_reply( # is_suitable, reason, need_replan = await self.reply_generator.check_reply(
@@ -148,6 +150,7 @@ class Conversation:
# ) # )
if self._check_new_messages_after_planning(): if self._check_new_messages_after_planning():
logger.info("333333发现新消息重新考虑行动")
return None return None
await self._send_reply() await self._send_reply()
@@ -212,15 +215,9 @@ class Conversation:
logger.warning("没有生成回复") logger.warning("没有生成回复")
return return
messages = self.chat_observer.get_cached_messages(limit=1)
if not messages:
logger.warning("没有最近的消息可以回复")
return
latest_message = self._convert_to_message(messages[0])
try: try:
await self.direct_sender.send_message( await self.direct_sender.send_message(
chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message chat_stream=self.chat_stream, content=self.generated_reply
) )
self.chat_observer.trigger_update() # 触发立即更新 self.chat_observer.trigger_update() # 触发立即更新
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():

View File

@@ -1,18 +1,18 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional from typing import List, Dict, Any, Optional
from src.common.database import db from src.common.database import db
import time
class MessageStorage(ABC): class MessageStorage(ABC):
"""消息存储接口""" """消息存储接口"""
@abstractmethod @abstractmethod
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message: Dict[str, Any]) -> List[Dict[str, Any]]:
"""获取指定消息ID之后的所有消息 """获取指定消息ID之后的所有消息
Args: Args:
chat_id: 聊天ID chat_id: 聊天ID
message_id: 消息ID如果为None则获取所有消息 message: 消息
Returns: Returns:
List[Dict[str, Any]]: 消息列表 List[Dict[str, Any]]: 消息列表
@@ -53,14 +53,11 @@ class MongoDBMessageStorage(MessageStorage):
def __init__(self): def __init__(self):
self.db = db self.db = db
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: async def get_messages_after(self, chat_id: str, message_time: float) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id} query = {"chat_id": chat_id}
print(f"storage_check_message: {message_time}")
if message_id: query["time"] = {"$gt": message_time}
# 获取ID大于message_id的消息
last_message = self.db.messages.find_one({"message_id": message_id})
if last_message:
query["time"] = {"$gt": last_message["time"]}
return list(self.db.messages.find(query).sort("time", 1)) return list(self.db.messages.find(query).sort("time", 1))

View File

@@ -1,71 +0,0 @@
from typing import TYPE_CHECKING
from src.common.logger import get_module_logger
from .chat_states import NotificationHandler, Notification, NotificationType
if TYPE_CHECKING:
from .conversation import Conversation
logger = get_module_logger("notification_handler")
class PFCNotificationHandler(NotificationHandler):
"""PFC通知处理器"""
def __init__(self, conversation: "Conversation"):
"""初始化PFC通知处理器
Args:
conversation: 对话实例
"""
self.conversation = conversation
async def handle_notification(self, notification: Notification):
"""处理通知
Args:
notification: 通知对象
"""
logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}")
# 根据通知类型执行不同的处理
if notification.type == NotificationType.NEW_MESSAGE:
# 新消息通知
await self._handle_new_message(notification)
elif notification.type == NotificationType.COLD_CHAT:
# 冷聊天通知
await self._handle_cold_chat(notification)
elif notification.type == NotificationType.COMMAND:
# 命令通知
await self._handle_command(notification)
else:
logger.warning(f"未知的通知类型: {notification.type.name}")
async def _handle_new_message(self, notification: Notification):
"""处理新消息通知
Args:
notification: 通知对象
"""
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.last_message_time = notification.data.get("time", 0)
observation_info.add_unprocessed_message(notification.data)
# 手动触发观察器更新
self.conversation.chat_observer.trigger_update()
async def _handle_cold_chat(self, notification: Notification):
"""处理冷聊天通知
Args:
notification: 通知对象
"""
# 获取冷聊天信息
cold_duration = notification.data.get("duration", 0)
# 更新决策信息
observation_info = self.conversation.observation_info
observation_info.conversation_cold_duration = cold_duration
logger.info(f"对话已冷: {cold_duration}")

View File

@@ -6,7 +6,7 @@ import time
from dataclasses import dataclass, field from dataclasses import dataclass, field
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from .chat_observer import ChatObserver from .chat_observer import ChatObserver
from .chat_states import NotificationHandler from .chat_states import NotificationHandler, NotificationType
logger = get_module_logger("observation_info") logger = get_module_logger("observation_info")
@@ -22,63 +22,70 @@ class ObservationInfoHandler(NotificationHandler):
""" """
self.observation_info = observation_info self.observation_info = observation_info
async def handle_notification(self, notification: Dict[str, Any]): async def handle_notification(self, notification):
"""处理通知 # 获取通知类型和数据
notification_type = notification.type
data = notification.data
Args: if notification_type == NotificationType.NEW_MESSAGE:
notification: 通知数据
"""
notification_type = notification.get("type")
data = notification.get("data", {})
if notification_type == "NEW_MESSAGE":
# 处理新消息通知 # 处理新消息通知
logger.debug(f"收到新消息通知data: {data}") logger.debug(f"收到新消息通知data: {data}")
message = data.get("message", {}) message_id = data.get("message_id")
self.observation_info.update_from_message(message) processed_plain_text = data.get("processed_plain_text")
# self.observation_info.has_unread_messages = True detailed_plain_text = data.get("detailed_plain_text")
# self.observation_info.new_unread_message.append(message.get("processed_plain_text", "")) user_info = data.get("user_info")
time_value = data.get("time")
elif notification_type == "COLD_CHAT": message = {
"message_id": message_id,
"processed_plain_text": processed_plain_text,
"detailed_plain_text": detailed_plain_text,
"user_info": user_info,
"time": time_value
}
self.observation_info.update_from_message(message)
elif notification_type == NotificationType.COLD_CHAT:
# 处理冷场通知 # 处理冷场通知
is_cold = data.get("is_cold", False) is_cold = data.get("is_cold", False)
self.observation_info.update_cold_chat_status(is_cold, time.time()) self.observation_info.update_cold_chat_status(is_cold, time.time())
elif notification_type == "ACTIVE_CHAT": elif notification_type == NotificationType.ACTIVE_CHAT:
# 处理活跃通知 # 处理活跃通知
is_active = data.get("is_active", False) is_active = data.get("is_active", False)
self.observation_info.is_cold = not is_active self.observation_info.is_cold = not is_active
elif notification_type == "BOT_SPEAKING": elif notification_type == NotificationType.BOT_SPEAKING:
# 处理机器人说话通知 # 处理机器人说话通知
self.observation_info.is_typing = False self.observation_info.is_typing = False
self.observation_info.last_bot_speak_time = time.time() self.observation_info.last_bot_speak_time = time.time()
elif notification_type == "USER_SPEAKING": elif notification_type == NotificationType.USER_SPEAKING:
# 处理用户说话通知 # 处理用户说话通知
self.observation_info.is_typing = False self.observation_info.is_typing = False
self.observation_info.last_user_speak_time = time.time() self.observation_info.last_user_speak_time = time.time()
elif notification_type == "MESSAGE_DELETED": elif notification_type == NotificationType.MESSAGE_DELETED:
# 处理消息删除通知 # 处理消息删除通知
message_id = data.get("message_id") message_id = data.get("message_id")
self.observation_info.unprocessed_messages = [ self.observation_info.unprocessed_messages = [
msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id
] ]
elif notification_type == "USER_JOINED": elif notification_type == NotificationType.USER_JOINED:
# 处理用户加入通知 # 处理用户加入通知
user_id = data.get("user_id") user_id = data.get("user_id")
if user_id: if user_id:
self.observation_info.active_users.add(user_id) self.observation_info.active_users.add(user_id)
elif notification_type == "USER_LEFT": elif notification_type == NotificationType.USER_LEFT:
# 处理用户离开通知 # 处理用户离开通知
user_id = data.get("user_id") user_id = data.get("user_id")
if user_id: if user_id:
self.observation_info.active_users.discard(user_id) self.observation_info.active_users.discard(user_id)
elif notification_type == "ERROR": elif notification_type == NotificationType.ERROR:
# 处理错误通知 # 处理错误通知
error_msg = data.get("error", "") error_msg = data.get("error", "")
logger.error(f"收到错误通知: {error_msg}") logger.error(f"收到错误通知: {error_msg}")
@@ -100,6 +107,7 @@ class ObservationInfo:
last_message_content: str = "" last_message_content: str = ""
last_message_sender: Optional[str] = None last_message_sender: Optional[str] = None
bot_id: Optional[str] = None bot_id: Optional[str] = None
chat_history_count: int = 0
new_messages_count: int = 0 new_messages_count: int = 0
cold_chat_duration: float = 0.0 cold_chat_duration: float = 0.0
@@ -117,28 +125,29 @@ class ObservationInfo:
self.chat_observer = None self.chat_observer = None
self.handler = ObservationInfoHandler(self) self.handler = ObservationInfoHandler(self)
def bind_to_chat_observer(self, stream_id: str): def bind_to_chat_observer(self, chat_observer: ChatObserver):
"""绑定到指定的chat_observer """绑定到指定的chat_observer
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
""" """
self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer = chat_observer
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
) )
self.chat_observer.notification_manager.register_handler( self.chat_observer.notification_manager.register_handler(
target="observation_info", notification_type="COLD_CHAT", handler=self.handler target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
) )
print("1919810------------------------绑定-----------------------------")
def unbind_from_chat_observer(self): def unbind_from_chat_observer(self):
"""解除与chat_observer的绑定""" """解除与chat_observer的绑定"""
if self.chat_observer: if self.chat_observer:
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler target="observation_info", notification_type=NotificationType.NEW_MESSAGE, handler=self.handler
) )
self.chat_observer.notification_manager.unregister_handler( self.chat_observer.notification_manager.unregister_handler(
target="observation_info", notification_type="COLD_CHAT", handler=self.handler target="observation_info", notification_type=NotificationType.COLD_CHAT, handler=self.handler
) )
self.chat_observer = None self.chat_observer = None
@@ -148,8 +157,11 @@ class ObservationInfo:
Args: Args:
message: 消息数据 message: 消息数据
""" """
print("1919810-----------------------------------------------------")
logger.debug(f"更新信息from_message: {message}") logger.debug(f"更新信息from_message: {message}")
self.last_message_time = message["time"] self.last_message_time = message["time"]
self.last_message_id = message["message_id"]
self.last_message_content = message.get("processed_plain_text", "") self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {})) user_info = UserInfo.from_dict(message.get("user_info", {}))
@@ -169,7 +181,6 @@ class ObservationInfo:
def update_changed(self): def update_changed(self):
"""更新changed状态""" """更新changed状态"""
self.changed = True self.changed = True
# self.meta_plan_trigger = True
def update_cold_chat_status(self, is_cold: bool, current_time: float): def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态 """更新冷场状态
@@ -216,24 +227,10 @@ class ObservationInfo:
"""清空未处理消息列表""" """清空未处理消息列表"""
# 将未处理消息添加到历史记录中 # 将未处理消息添加到历史记录中
for message in self.unprocessed_messages: for message in self.unprocessed_messages:
if "processed_plain_text" in message: self.chat_history.append(message)
self.chat_history.append(message["processed_plain_text"])
# 清空未处理消息列表 # 清空未处理消息列表
self.has_unread_messages = False self.has_unread_messages = False
self.unprocessed_messages.clear() self.unprocessed_messages.clear()
self.chat_history_count = len(self.chat_history)
self.new_messages_count = 0 self.new_messages_count = 0
def add_unprocessed_message(self, message: Dict[str, Any]):
"""添加未处理的消息
Args:
message: 消息数据
"""
# 防止重复添加同一消息
message_id = message.get("message_id")
if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages):
self.unprocessed_messages.append(message)
self.new_messages_count += 1
# 同时更新其他消息相关信息
self.update_from_message(message)

View File

@@ -99,19 +99,21 @@ class GoalAnalyzer:
3. 添加新目标 3. 添加新目标
4. 删除不再相关的目标 4. 删除不再相关的目标
请以JSON格式输出当前的所有对话目标包含以下字段 请以JSON数组格式输出当前的所有对话目标,每个目标包含以下字段:
1. goal: 对话目标(简短的一句话) 1. goal: 对话目标(简短的一句话)
2. reasoning: 对话原因,为什么设定这个目标(简要解释) 2. reasoning: 对话原因,为什么设定这个目标(简要解释)
输出格式示例: 输出格式示例:
{{ [
"goal": "回答用户关于Python编程的具体问题", {{
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" "goal": "回答用户关于Python编程的具体问题",
}}, "reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
{{ }},
"goal": "回答用户关于python安装的具体问题", {{
"reasoning": "用户提出了关于Python的技术问题,需要专业且准确的解答" "goal": "回答用户关于python安装的具体问题",
}}""" "reasoning": "用户提出了关于Python的技术问题需要专业且准确的解答"
}}
]"""
logger.debug(f"发送到LLM的提示词: {prompt}") logger.debug(f"发送到LLM的提示词: {prompt}")
try: try:
@@ -120,13 +122,37 @@ class GoalAnalyzer:
except Exception as e: except Exception as e:
logger.error(f"分析对话目标时出错: {str(e)}") logger.error(f"分析对话目标时出错: {str(e)}")
content = "" content = ""
# 使用简化函数提取JSON内容
success, result = get_items_from_json(
content, "goal", "reasoning", required_types={"goal": str, "reasoning": str}
)
# TODO
conversation_info.goal_list.append(result) # 使用改进后的get_items_from_json函数处理JSON数组
success, result = get_items_from_json(
content, "goal", "reasoning",
required_types={"goal": str, "reasoning": str},
allow_array=True
)
if success:
# 判断结果是单个字典还是字典列表
if isinstance(result, list):
# 清空现有目标列表并添加新目标
conversation_info.goal_list = []
for item in result:
goal = item.get("goal", "")
reasoning = item.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning))
# 返回第一个目标作为当前主要目标(如果有)
if result:
first_goal = result[0]
return (first_goal.get("goal", ""), "", first_goal.get("reasoning", ""))
else:
# 单个目标的情况
goal = result.get("goal", "")
reasoning = result.get("reasoning", "")
conversation_info.goal_list.append((goal, reasoning))
return (goal, "", reasoning)
# 如果解析失败,返回默认值
return ("", "", "")
async def _update_goals(self, new_goal: str, method: str, reasoning: str): async def _update_goals(self, new_goal: str, method: str, reasoning: str):
"""更新目标列表 """更新目标列表

View File

@@ -1,6 +1,6 @@
import json import json
import re import re
from typing import Dict, Any, Optional, Tuple from typing import Dict, Any, Optional, Tuple, List, Union
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
logger = get_module_logger("pfc_utils") logger = get_module_logger("pfc_utils")
@@ -11,7 +11,8 @@ def get_items_from_json(
*items: str, *items: str,
default_values: Optional[Dict[str, Any]] = None, default_values: Optional[Dict[str, Any]] = None,
required_types: Optional[Dict[str, type]] = None, required_types: Optional[Dict[str, type]] = None,
) -> Tuple[bool, Dict[str, Any]]: allow_array: bool = True,
) -> Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]:
"""从文本中提取JSON内容并获取指定字段 """从文本中提取JSON内容并获取指定字段
Args: Args:
@@ -19,9 +20,10 @@ def get_items_from_json(
*items: 要提取的字段名 *items: 要提取的字段名
default_values: 字段的默认值,格式为 {字段名: 默认值} default_values: 字段的默认值,格式为 {字段名: 默认值}
required_types: 字段的必需类型,格式为 {字段名: 类型} required_types: 字段的必需类型,格式为 {字段名: 类型}
allow_array: 是否允许解析JSON数组
Returns: Returns:
Tuple[bool, Dict[str, Any]]: (是否成功, 提取的字段字典) Tuple[bool, Union[Dict[str, Any], List[Dict[str, Any]]]]: (是否成功, 提取的字段字典或字典列表)
""" """
content = content.strip() content = content.strip()
result = {} result = {}
@@ -30,7 +32,57 @@ def get_items_from_json(
if default_values: if default_values:
result.update(default_values) result.update(default_values)
# 尝试解析JSON # 首先尝试解析JSON数组
if allow_array:
try:
# 尝试找到文本中的JSON数组
array_pattern = r"\[[\s\S]*\]"
array_match = re.search(array_pattern, content)
if array_match:
array_content = array_match.group()
json_array = json.loads(array_content)
# 确认是数组类型
if isinstance(json_array, list):
# 验证数组中的每个项目是否包含所有必需字段
valid_items = []
for item in json_array:
if not isinstance(item, dict):
continue
# 检查是否有所有必需字段
if all(field in item for field in items):
# 验证字段类型
if required_types:
type_valid = True
for field, expected_type in required_types.items():
if field in item and not isinstance(item[field], expected_type):
type_valid = False
break
if not type_valid:
continue
# 验证字符串字段不为空
string_valid = True
for field in items:
if isinstance(item[field], str) and not item[field].strip():
string_valid = False
break
if not string_valid:
continue
valid_items.append(item)
if valid_items:
return True, valid_items
except json.JSONDecodeError:
logger.debug("JSON数组解析失败尝试解析单个JSON对象")
except Exception as e:
logger.debug(f"尝试解析JSON数组时出错: {str(e)}")
# 尝试解析JSON对象
try: try:
json_data = json.loads(content) json_data = json.loads(content)
except json.JSONDecodeError: except json.JSONDecodeError:

View File

@@ -26,11 +26,11 @@ logger = get_module_logger("llm_generator", config=llm_config)
class ResponseGenerator: class ResponseGenerator:
def __init__(self): def __init__(self):
self.model_normal = LLM_request( self.model_normal = LLM_request(
model=global_config.llm_normal, temperature=0.3, max_tokens=256, request_type="response_heartflow" model=global_config.llm_normal, temperature=0.15, max_tokens=256, request_type="response_heartflow"
) )
self.model_sum = LLM_request( self.model_sum = LLM_request(
model=global_config.llm_summary_by_topic, temperature=0.7, max_tokens=2000, request_type="relation" model=global_config.llm_summary_by_topic, temperature=0.6, max_tokens=2000, request_type="relation"
) )
self.current_model_type = "r1" # 默认使用 R1 self.current_model_type = "r1" # 默认使用 R1
self.current_model_name = "unknown model" self.current_model_name = "unknown model"