better:尝试重构pfc

This commit is contained in:
SengokuCola
2025-04-08 00:23:08 +08:00
parent 0d7068acab
commit cc190ac2b9
6 changed files with 797 additions and 308 deletions

View File

@@ -5,6 +5,8 @@ from src.common.logger import get_module_logger
from src.common.database import db from src.common.database import db
from ..message.message_base import UserInfo from ..message.message_base import UserInfo
from ..config.config import global_config from ..config.config import global_config
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
from .message_storage import MessageStorage, MongoDBMessageStorage
logger = get_module_logger("chat_observer") logger = get_module_logger("chat_observer")
@@ -15,36 +17,40 @@ class ChatObserver:
_instances: Dict[str, 'ChatObserver'] = {} _instances: Dict[str, 'ChatObserver'] = {}
@classmethod @classmethod
def get_instance(cls, stream_id: str) -> 'ChatObserver': def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
"""获取或创建观察器实例 """获取或创建观察器实例
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现
Returns: Returns:
ChatObserver: 观察器实例 ChatObserver: 观察器实例
""" """
if stream_id not in cls._instances: if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id) cls._instances[stream_id] = cls(stream_id, message_storage)
return cls._instances[stream_id] return cls._instances[stream_id]
def __init__(self, stream_id: str): def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
"""初始化观察器 """初始化观察器
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
message_storage: 消息存储实现如果为None则使用MongoDB实现
""" """
if stream_id in self._instances: if stream_id in self._instances:
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.") raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
self.stream_id = stream_id self.stream_id = stream_id
self.message_storage = message_storage or MongoDBMessageStorage()
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
self.last_check_time: float = time.time() # 上次查看聊天记录时间 self.last_check_time: float = time.time() # 上次查看聊天记录时间
self.last_message_read: Optional[str] = None # 最后读取的消息ID self.last_message_read: Optional[str] = None # 最后读取的消息ID
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
self.waiting_start_time: Optional[float] = None # 等待开始时间 self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
# 消息历史记录 # 消息历史记录
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
@@ -57,7 +63,15 @@ class ChatObserver:
self._update_event = asyncio.Event() # 触发更新的事件 self._update_event = asyncio.Event() # 触发更新的事件
self._update_complete = asyncio.Event() # 更新完成的事件 self._update_complete = asyncio.Event() # 更新完成的事件
def check(self) -> bool: # 通知管理器
self.notification_manager = NotificationManager()
# 冷场检查配置
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
self.last_cold_chat_check: float = time.time()
self.is_cold_chat_state: bool = False
async def check(self) -> bool:
"""检查距离上一次观察之后是否有了新消息 """检查距离上一次观察之后是否有了新消息
Returns: Returns:
@@ -65,13 +79,10 @@ class ChatObserver:
""" """
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
query = { new_message_exists = await self.message_storage.has_new_messages(
"chat_id": self.stream_id, self.stream_id,
"time": {"$gt": self.last_check_time} self.last_check_time
} )
# 只需要查询是否存在,不需要获取具体消息
new_message_exists = db.messages.find_one(query) is not None
if new_message_exists: if new_message_exists:
logger.debug("发现新消息") logger.debug("发现新消息")
@@ -79,27 +90,8 @@ class ChatObserver:
return new_message_exists return new_message_exists
def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: async def _add_message_to_history(self, message: Dict[str, Any]):
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" """添加消息到历史记录并发送通知
messages = self.get_message_history(self.last_check_time)
for message in messages:
self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point}")
return self.last_message_time is None or self.last_message_time > time_point
def _add_message_to_history(self, message: Dict[str, Any]):
"""添加消息到历史记录
Args: Args:
message: 消息数据 message: 消息数据
@@ -116,6 +108,75 @@ class ChatObserver:
else: else:
self.last_user_speak_time = message["time"] self.last_user_speak_time = message["time"]
# 发送新消息通知
notification = create_new_message_notification(
sender="chat_observer",
target="pfc",
message=message
)
await self.notification_manager.send_notification(notification)
# 检查并更新冷场状态
await self._check_cold_chat()
async def _check_cold_chat(self):
"""检查是否处于冷场状态并发送通知"""
current_time = time.time()
# 每10秒检查一次冷场状态
if current_time - self.last_cold_chat_check < 10:
return
self.last_cold_chat_check = current_time
# 判断是否冷场
is_cold = False
if self.last_message_time is None:
is_cold = True
else:
is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold
# 如果冷场状态发生变化,发送通知
if is_cold != self.is_cold_chat_state:
self.is_cold_chat_state = is_cold
notification = create_cold_chat_notification(
sender="chat_observer",
target="pfc",
is_cold=is_cold
)
await self.notification_manager.send_notification(notification)
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
messages = await self.message_storage.get_messages_after(
self.stream_id,
self.last_message_read
)
for message in messages:
await self._add_message_to_history(message)
return messages, self.message_history
def new_message_after(self, time_point: float) -> bool:
"""判断是否在指定时间点后有新消息
Args:
time_point: 时间戳
Returns:
bool: 是否有新消息
"""
if time_point is None:
logger.warning("time_point 为 None返回 False")
return False
if self.last_message_time is None:
logger.debug("没有最后消息时间,返回 False")
return False
has_new = self.last_message_time > time_point
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}")
return has_new
def get_message_history( def get_message_history(
self, self,
start_time: Optional[float] = None, start_time: Optional[float] = None,
@@ -159,15 +220,9 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 新消息列表 List[Dict[str, Any]]: 新消息列表
""" """
query = {"chat_id": self.stream_id} new_messages = await self.message_storage.get_messages_after(
if self.last_message_read: self.stream_id,
# 获取ID大于last_message_read的消息 self.last_message_read
last_message = db.messages.find_one({"message_id": self.last_message_read})
if last_message:
query["time"] = {"$gt": last_message["time"]}
new_messages = list(
db.messages.find(query).sort("time", 1)
) )
if new_messages: if new_messages:
@@ -184,30 +239,24 @@ class ChatObserver:
Returns: Returns:
List[Dict[str, Any]]: 最多5条消息 List[Dict[str, Any]]: 最多5条消息
""" """
query = { new_messages = await self.message_storage.get_messages_before(
"chat_id": self.stream_id, self.stream_id,
"time": {"$lt": time_point} time_point
}
new_messages = list(
db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条
) )
# 将消息按时间正序排列
new_messages.reverse()
if new_messages: if new_messages:
self.last_message_read = new_messages[-1]["message_id"] self.last_message_read = new_messages[-1]["message_id"]
return new_messages return new_messages
'''主要观察循环'''
async def _update_loop(self): async def _update_loop(self):
"""更新循环""" """更新循环"""
try: try:
start_time = time.time() start_time = time.time()
messages = await self._fetch_new_messages_before(start_time) messages = await self._fetch_new_messages_before(start_time)
for message in messages: for message in messages:
self._add_message_to_history(message) await self._add_message_to_history(message)
except Exception as e: except Exception as e:
logger.error(f"缓冲消息出错: {e}") logger.error(f"缓冲消息出错: {e}")
@@ -228,7 +277,7 @@ class ChatObserver:
if new_messages: if new_messages:
# 处理新消息 # 处理新消息
for message in new_messages: for message in new_messages:
self._add_message_to_history(message) await self._add_message_to_history(message)
# 设置完成事件 # 设置完成事件
self._update_complete.set() self._update_complete.set()

View File

@@ -0,0 +1,267 @@
from enum import Enum, auto
from typing import Optional, Dict, Any, List, Set
from dataclasses import dataclass
from datetime import datetime
from abc import ABC, abstractmethod
class ChatState(Enum):
"""聊天状态枚举"""
NORMAL = auto() # 正常状态
NEW_MESSAGE = auto() # 有新消息
COLD_CHAT = auto() # 冷场状态
ACTIVE_CHAT = auto() # 活跃状态
BOT_SPEAKING = auto() # 机器人正在说话
USER_SPEAKING = auto() # 用户正在说话
SILENT = auto() # 沉默状态
ERROR = auto() # 错误状态
class NotificationType(Enum):
"""通知类型枚举"""
NEW_MESSAGE = auto() # 新消息通知
COLD_CHAT = auto() # 冷场通知
ACTIVE_CHAT = auto() # 活跃通知
BOT_SPEAKING = auto() # 机器人说话通知
USER_SPEAKING = auto() # 用户说话通知
MESSAGE_DELETED = auto() # 消息删除通知
USER_JOINED = auto() # 用户加入通知
USER_LEFT = auto() # 用户离开通知
ERROR = auto() # 错误通知
@dataclass
class ChatStateInfo:
"""聊天状态信息"""
state: ChatState
last_message_time: Optional[float] = None
last_message_content: Optional[str] = None
last_speaker: Optional[str] = None
message_count: int = 0
cold_duration: float = 0.0 # 冷场持续时间(秒)
active_duration: float = 0.0 # 活跃持续时间(秒)
@dataclass
class Notification:
"""通知基类"""
type: NotificationType
timestamp: float
sender: str # 发送者标识
target: str # 接收者标识
data: Dict[str, Any]
def to_dict(self) -> Dict[str, Any]:
"""转换为字典格式"""
return {
"type": self.type.name,
"timestamp": self.timestamp,
"data": self.data
}
@dataclass
class StateNotification(Notification):
"""持续状态通知"""
is_active: bool = True
def to_dict(self) -> Dict[str, Any]:
base_dict = super().to_dict()
base_dict["is_active"] = self.is_active
return base_dict
class NotificationHandler(ABC):
"""通知处理器接口"""
@abstractmethod
async def handle_notification(self, notification: Notification):
"""处理通知"""
pass
class NotificationManager:
"""通知管理器"""
def __init__(self):
# 按接收者和通知类型存储处理器
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
self._active_states: Set[NotificationType] = set()
self._notification_history: List[Notification] = []
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注册通知处理器
Args:
target: 接收者标识(例如:"pfc"
notification_type: 要处理的通知类型
handler: 处理器实例
"""
if target not in self._handlers:
self._handlers[target] = {}
if notification_type not in self._handlers[target]:
self._handlers[target][notification_type] = []
self._handlers[target][notification_type].append(handler)
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
"""注销通知处理器
Args:
target: 接收者标识
notification_type: 通知类型
handler: 要注销的处理器实例
"""
if target in self._handlers and notification_type in self._handlers[target]:
handlers = self._handlers[target][notification_type]
if handler in handlers:
handlers.remove(handler)
# 如果该类型的处理器列表为空,删除该类型
if not handlers:
del self._handlers[target][notification_type]
# 如果该目标没有任何处理器,删除该目标
if not self._handlers[target]:
del self._handlers[target]
async def send_notification(self, notification: Notification):
"""发送通知"""
self._notification_history.append(notification)
# 如果是状态通知,更新活跃状态
if isinstance(notification, StateNotification):
if notification.is_active:
self._active_states.add(notification.type)
else:
self._active_states.discard(notification.type)
# 调用目标接收者的处理器
target = notification.target
if target in self._handlers:
handlers = self._handlers[target].get(notification.type, [])
for handler in handlers:
await handler.handle_notification(notification)
def get_active_states(self) -> Set[NotificationType]:
"""获取当前活跃的状态"""
return self._active_states.copy()
def is_state_active(self, state_type: NotificationType) -> bool:
"""检查特定状态是否活跃"""
return state_type in self._active_states
def get_notification_history(self,
sender: Optional[str] = None,
target: Optional[str] = None,
limit: Optional[int] = None) -> List[Notification]:
"""获取通知历史
Args:
sender: 过滤特定发送者的通知
target: 过滤特定接收者的通知
limit: 限制返回数量
"""
history = self._notification_history
if sender:
history = [n for n in history if n.sender == sender]
if target:
history = [n for n in history if n.target == target]
if limit is not None:
history = history[-limit:]
return history
# 一些常用的通知创建函数
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
"""创建新消息通知"""
return Notification(
type=NotificationType.NEW_MESSAGE,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={
"message_id": message.get("message_id"),
"content": message.get("content"),
"sender": message.get("sender"),
"time": message.get("time")
}
)
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
"""创建冷场状态通知"""
return StateNotification(
type=NotificationType.COLD_CHAT,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={"is_cold": is_cold},
is_active=is_cold
)
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
"""创建活跃状态通知"""
return StateNotification(
type=NotificationType.ACTIVE_CHAT,
timestamp=datetime.now().timestamp(),
sender=sender,
target=target,
data={"is_active": is_active},
is_active=is_active
)
class ChatStateManager:
"""聊天状态管理器"""
def __init__(self):
self.current_state = ChatState.NORMAL
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
self.state_history: list[ChatStateInfo] = []
def update_state(self, new_state: ChatState, **kwargs):
"""更新聊天状态
Args:
new_state: 新的状态
**kwargs: 其他状态信息
"""
self.current_state = new_state
self.state_info.state = new_state
# 更新其他状态信息
for key, value in kwargs.items():
if hasattr(self.state_info, key):
setattr(self.state_info, key, value)
# 记录状态历史
self.state_history.append(self.state_info)
def get_current_state_info(self) -> ChatStateInfo:
"""获取当前状态信息"""
return self.state_info
def get_state_history(self) -> list[ChatStateInfo]:
"""获取状态历史"""
return self.state_history
def is_cold_chat(self, threshold: float = 60.0) -> bool:
"""判断是否处于冷场状态
Args:
threshold: 冷场阈值(秒)
Returns:
bool: 是否冷场
"""
if not self.state_info.last_message_time:
return True
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) > threshold
def is_active_chat(self, threshold: float = 5.0) -> bool:
"""判断是否处于活跃状态
Args:
threshold: 活跃阈值(秒)
Returns:
bool: 是否活跃
"""
if not self.state_info.last_message_time:
return False
current_time = datetime.now().timestamp()
return (current_time - self.state_info.last_message_time) <= threshold

View File

@@ -0,0 +1,134 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any, Optional
from src.common.database import db
class MessageStorage(ABC):
"""消息存储接口"""
@abstractmethod
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""获取指定消息ID之后的所有消息
Args:
chat_id: 聊天ID
message_id: 消息ID如果为None则获取所有消息
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
"""获取指定时间点之前的消息
Args:
chat_id: 聊天ID
time_point: 时间戳
limit: 最大消息数量
Returns:
List[Dict[str, Any]]: 消息列表
"""
pass
@abstractmethod
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
"""检查是否有新消息
Args:
chat_id: 聊天ID
after_time: 时间戳
Returns:
bool: 是否有新消息
"""
pass
class MongoDBMessageStorage(MessageStorage):
"""MongoDB消息存储实现"""
def __init__(self):
self.db = db
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
query = {"chat_id": chat_id}
if message_id:
# 获取ID大于message_id的消息
last_message = self.db.messages.find_one({"message_id": message_id})
if last_message:
query["time"] = {"$gt": last_message["time"]}
return list(
self.db.messages.find(query).sort("time", 1)
)
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
query = {
"chat_id": chat_id,
"time": {"$lt": time_point}
}
messages = list(
self.db.messages.find(query).sort("time", -1).limit(limit)
)
# 将消息按时间正序排列
messages.reverse()
return messages
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
query = {
"chat_id": chat_id,
"time": {"$gt": after_time}
}
return self.db.messages.find_one(query) is not None
# # 创建一个内存消息存储实现,用于测试
# class InMemoryMessageStorage(MessageStorage):
# """内存消息存储实现,主要用于测试"""
# def __init__(self):
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = self.messages[chat_id]
# if not message_id:
# return messages
# # 找到message_id的索引
# try:
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
# return messages[index + 1:]
# except StopIteration:
# return []
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
# if chat_id not in self.messages:
# return []
# messages = [
# m for m in self.messages[chat_id]
# if m["time"] < time_point
# ]
# return messages[-limit:]
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
# if chat_id not in self.messages:
# return False
# return any(m["time"] > after_time for m in self.messages[chat_id])
# # 测试辅助方法
# def add_message(self, chat_id: str, message: Dict[str, Any]):
# """添加测试消息"""
# if chat_id not in self.messages:
# self.messages[chat_id] = []
# self.messages[chat_id].append(message)
# self.messages[chat_id].sort(key=lambda m: m["time"])

View File

@@ -2,7 +2,7 @@
#Prefrontal cortex #Prefrontal cortex
import datetime import datetime
import asyncio import asyncio
from typing import List, Optional, Dict, Any, Tuple, Literal from typing import List, Optional, Dict, Any, Tuple, Literal, Set
from enum import Enum from enum import Enum
from src.common.logger import get_module_logger from src.common.logger import get_module_logger
from ..chat.chat_stream import ChatStream from ..chat.chat_stream import ChatStream
@@ -19,7 +19,9 @@ from .pfc_KnowledgeFetcher import KnowledgeFetcher
from .reply_checker import ReplyChecker from .reply_checker import ReplyChecker
from .pfc_utils import get_items_from_json from .pfc_utils import get_items_from_json
from src.individuality.individuality import Individuality from src.individuality.individuality import Individuality
from .chat_states import NotificationHandler, Notification, NotificationType
import time import time
from dataclasses import dataclass, field
logger = get_module_logger("pfc") logger = get_module_logger("pfc")
@@ -42,6 +44,99 @@ class ConversationState(Enum):
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
@dataclass
class DecisionInfo:
"""决策信息类用于收集和管理来自chat_observer的通知信息"""
# 消息相关
last_message_time: Optional[float] = None
last_message_content: Optional[str] = None
last_message_sender: Optional[str] = None
new_messages_count: int = 0
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
# 对话状态
is_cold_chat: bool = False
cold_chat_duration: float = 0.0
last_bot_speak_time: Optional[float] = None
last_user_speak_time: Optional[float] = None
# 对话参与者
active_users: Set[str] = field(default_factory=set)
bot_id: str = field(default="")
def update_from_message(self, message: Dict[str, Any]):
"""从消息更新信息
Args:
message: 消息数据
"""
self.last_message_time = message["time"]
self.last_message_content = message.get("processed_plain_text", "")
user_info = UserInfo.from_dict(message.get("user_info", {}))
self.last_message_sender = user_info.user_id
if user_info.user_id == self.bot_id:
self.last_bot_speak_time = message["time"]
else:
self.last_user_speak_time = message["time"]
self.active_users.add(user_info.user_id)
self.new_messages_count += 1
self.unprocessed_messages.append(message)
def update_cold_chat_status(self, is_cold: bool, current_time: float):
"""更新冷场状态
Args:
is_cold: 是否冷场
current_time: 当前时间
"""
self.is_cold_chat = is_cold
if is_cold and self.last_message_time:
self.cold_chat_duration = current_time - self.last_message_time
def get_active_duration(self) -> float:
"""获取当前活跃时长
Returns:
float: 最后一条消息到现在的时长(秒)
"""
if not self.last_message_time:
return 0.0
return time.time() - self.last_message_time
def get_user_response_time(self) -> Optional[float]:
"""获取用户响应时间
Returns:
Optional[float]: 用户最后发言到现在的时长如果没有用户发言则返回None
"""
if not self.last_user_speak_time:
return None
return time.time() - self.last_user_speak_time
def get_bot_response_time(self) -> Optional[float]:
"""获取机器人响应时间
Returns:
Optional[float]: 机器人最后发言到现在的时长如果没有机器人发言则返回None
"""
if not self.last_bot_speak_time:
return None
return time.time() - self.last_bot_speak_time
def clear_unprocessed_messages(self):
"""清空未处理消息列表"""
self.unprocessed_messages.clear()
self.new_messages_count = 0
# Forward reference for type hints
DecisionInfoType = DecisionInfo
class ActionPlanner: class ActionPlanner:
"""行动规划器""" """行动规划器"""
@@ -62,22 +157,24 @@ class ActionPlanner:
method: str, method: str,
reasoning: str, reasoning: str,
action_history: List[Dict[str, str]] = None, action_history: List[Dict[str, str]] = None,
chat_observer: Optional[ChatObserver] = None, # 添加chat_observer参数 decision_info: DecisionInfoType = None # Use DecisionInfoType here
) -> Tuple[str, str]: ) -> Tuple[str, str]:
"""规划下一步行动 """规划下一步行动
Args: Args:
goal: 对话目标 goal: 对话目标
method: 实现方法
reasoning: 目标原因 reasoning: 目标原因
action_history: 行动历史记录 action_history: 行动历史记录
decision_info: 决策信息
Returns: Returns:
Tuple[str, str]: (行动类型, 行动原因) Tuple[str, str]: (行动类型, 行动原因)
""" """
# 构建提示词 # 构建提示词
# 获取最近20条消息 logger.debug(f"开始规划行动:当前目标: {goal}")
self.chat_observer.waiting_start_time = time.time()
# 获取最近20条消息
messages = self.chat_observer.get_message_history(limit=20) messages = self.chat_observer.get_message_history(limit=20)
chat_history_text = "" chat_history_text = ""
for msg in messages: for msg in messages:
@@ -92,22 +189,42 @@ class ActionPlanner:
# 构建action历史文本 # 构建action历史文本
action_history_text = "" action_history_text = ""
if action_history: if action_history and action_history[-1]['action'] == "direct_reply":
if action_history[-1]['action'] == "direct_reply": action_history_text = "你刚刚发言回复了对方"
action_history_text = "你刚刚发言回复了对方"
# 获取时间信息 # 构建决策信息文本
time_info = self.chat_observer.get_time_info() decision_info_text = ""
if decision_info:
decision_info_text = "当前对话状态:\n"
if decision_info.is_cold_chat:
decision_info_text += f"对话处于冷场状态,已持续{int(decision_info.cold_chat_duration)}\n"
if decision_info.new_messages_count > 0:
decision_info_text += f"{decision_info.new_messages_count}条新消息未处理\n"
user_response_time = decision_info.get_user_response_time()
if user_response_time:
decision_info_text += f"距离用户上次发言已过去{int(user_response_time)}\n"
bot_response_time = decision_info.get_bot_response_time()
if bot_response_time:
decision_info_text += f"距离你上次发言已过去{int(bot_response_time)}\n"
if decision_info.active_users:
decision_info_text += f"当前活跃用户数: {len(decision_info.active_users)}\n"
prompt = f"""{personality_text}。现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
prompt = f"""现在你在参与一场QQ聊天请分析以下内容根据信息决定下一步行动
{personality_text}
当前对话目标:{goal} 当前对话目标:{goal}
实现该对话目标的方式:{method} 实现该对话目标的方式:{method}
产生该对话目标的原因:{reasoning} 产生该对话目标的原因:{reasoning}
{time_info}
{decision_info_text}
{action_history_text}
最近的对话记录: 最近的对话记录:
{chat_history_text} {chat_history_text}
{action_history_text}
请你接下去想想要你要做什么,可以发言,可以等待,可以倾听,可以调取知识。注意不同行动类型的要求,不要重复发言: 请你接下去想想要你要做什么,可以发言,可以等待,可以倾听,可以调取知识。注意不同行动类型的要求,不要重复发言:
行动类型: 行动类型:
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择 fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择
@@ -413,16 +530,23 @@ class Waiter:
Returns: Returns:
bool: 是否超时True表示超时 bool: 是否超时True表示超时
""" """
wait_start_time = self.chat_observer.waiting_start_time # 使用当前时间作为等待开始时间
while not self.chat_observer.new_message_after(wait_start_time): wait_start_time = time.time()
await asyncio.sleep(1) self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间
logger.info("等待中...")
# 检查是否超过60秒 while True:
# 检查是否有新消息
if self.chat_observer.new_message_after(wait_start_time):
logger.info("等待结束,收到新消息")
return False
# 检查是否超时
if time.time() - wait_start_time > 300: if time.time() - wait_start_time > 300:
logger.info("等待超过300秒结束对话") logger.info("等待超过300秒结束对话")
return True return True
logger.info("等待结束")
return False await asyncio.sleep(1)
logger.info("等待中...")
class ReplyGenerator: class ReplyGenerator:
@@ -519,16 +643,16 @@ class ReplyGenerator:
try: try:
content, _ = await self.llm.generate_response_async(prompt) content, _ = await self.llm.generate_response_async(prompt)
logger.info(f"生成的回复: {content}") logger.info(f"生成的回复: {content}")
is_new = self.chat_observer.check() # is_new = self.chat_observer.check()
logger.debug(f"再看一眼聊天记录,{'' if is_new else '没有'}新消息") # logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息")
# 如果有新消息,重新生成回复 # 如果有新消息,重新生成回复
if is_new: # if is_new:
logger.info("检测到新消息,重新生成回复") # logger.info("检测到新消息,重新生成回复")
return await self.generate( # return await self.generate(
goal, chat_history, knowledge_cache, # goal, chat_history, knowledge_cache,
None, retry_count # None, retry_count
) # )
return content return content
@@ -555,12 +679,69 @@ class ReplyGenerator:
return await self.reply_checker.check(reply, goal, retry_count) return await self.reply_checker.check(reply, goal, retry_count)
class PFCNotificationHandler(NotificationHandler):
"""PFC的通知处理器"""
def __init__(self, conversation: 'Conversation'):
self.conversation = conversation
self.logger = get_module_logger("pfc_notification")
self.decision_info = conversation.decision_info
async def handle_notification(self, notification: Notification):
"""处理通知"""
try:
if not notification or not hasattr(notification, 'data') or notification.data is None:
self.logger.error("收到无效的通知notification 或 data 为空")
return
if notification.type == NotificationType.NEW_MESSAGE:
# 处理新消息通知
message = notification.data
if not isinstance(message, dict):
self.logger.error(f"无效的消息格式: {type(message)}")
return
content = message.get('content', '')
self.logger.info(f"收到新消息通知: {content[:30] if content else ''}...")
# 更新决策信息
try:
self.decision_info.update_from_message(message)
except Exception as e:
self.logger.error(f"更新决策信息失败: {e}")
return
# 触发对话系统更新
self.conversation.chat_observer.trigger_update()
elif notification.type == NotificationType.COLD_CHAT:
# 处理冷场通知
try:
is_cold = bool(notification.data.get("is_cold", False))
# 更新决策信息
self.decision_info.update_cold_chat_status(is_cold, time.time())
if is_cold:
self.logger.info("检测到对话冷场")
else:
self.logger.info("对话恢复活跃")
except Exception as e:
self.logger.error(f"处理冷场状态失败: {e}")
return
except Exception as e:
self.logger.error(f"处理通知时出错: {str(e)}")
# 添加更详细的错误信息
self.logger.error(f"通知类型: {getattr(notification, 'type', None)}")
self.logger.error(f"通知数据: {getattr(notification, 'data', None)}")
class Conversation: class Conversation:
# 类级别的实例管理 # 类级别的实例管理
_instances: Dict[str, 'Conversation'] = {} _instances: Dict[str, 'Conversation'] = {}
_instance_lock = asyncio.Lock() # 类级别的全局锁 _instance_lock = asyncio.Lock()
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件 _init_events: Dict[str, asyncio.Event] = {}
_initializing: Dict[str, bool] = {} # 标记是否正在初始化 _initializing: Dict[str, bool] = {}
@classmethod @classmethod
async def get_instance(cls, stream_id: str) -> Optional['Conversation']: async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
@@ -573,102 +754,89 @@ class Conversation:
Optional[Conversation]: 对话实例如果创建或等待失败则返回None Optional[Conversation]: 对话实例如果创建或等待失败则返回None
""" """
try: try:
# 使用全局锁来确保线程安全 # 检查是否已经有实例
async with cls._instance_lock: if stream_id in cls._instances:
# 如果已经在初始化中,等待初始化完成
if stream_id in cls._initializing and cls._initializing[stream_id]:
# 释放锁等待初始化
cls._instance_lock.release()
try:
await asyncio.wait_for(cls._init_events[stream_id].wait(), timeout=5.0)
except asyncio.TimeoutError:
logger.error(f"等待实例 {stream_id} 初始化超时")
return None
finally:
await cls._instance_lock.acquire()
# 如果实例不存在,创建新实例
if stream_id not in cls._instances:
cls._instances[stream_id] = cls(stream_id)
cls._init_events[stream_id] = asyncio.Event()
cls._initializing[stream_id] = True
logger.info(f"创建新的对话实例: {stream_id}")
return cls._instances[stream_id] return cls._instances[stream_id]
async with cls._instance_lock:
# 再次检查,防止在获取锁的过程中其他线程创建了实例
if stream_id in cls._instances:
return cls._instances[stream_id]
# 如果正在初始化,等待初始化完成
if stream_id in cls._initializing and cls._initializing[stream_id]:
event = cls._init_events.get(stream_id)
if event:
try:
# 在等待之前释放锁
cls._instance_lock.release()
await asyncio.wait_for(event.wait(), timeout=10.0) # 增加超时时间到10秒
# 重新获取锁
await cls._instance_lock.acquire()
if stream_id in cls._instances:
return cls._instances[stream_id]
except asyncio.TimeoutError:
logger.error(f"等待实例 {stream_id} 初始化超时")
# 清理超时的初始化状态
cls._initializing[stream_id] = False
if stream_id in cls._init_events:
del cls._init_events[stream_id]
return None
# 创建新实例
logger.info(f"创建新的对话实例: {stream_id}")
cls._initializing[stream_id] = True
cls._init_events[stream_id] = asyncio.Event()
# 在锁保护下创建实例
instance = cls(stream_id)
cls._instances[stream_id] = instance
# 启动实例初始化(在后台运行)
asyncio.create_task(instance._initialize())
return instance
except Exception as e: except Exception as e:
logger.error(f"获取对话实例失败: {e}") logger.error(f"获取对话实例失败: {e}")
return None return None
@classmethod async def _initialize(self):
async def remove_instance(cls, stream_id: str): """初始化实例(在后台运行)"""
"""删除对话实例 try:
logger.info(f"开始初始化对话实例: {self.stream_id}")
self.chat_observer.start() # 启动观察器
await asyncio.sleep(1) # 给观察器一些启动时间
Args: # 获取初始目标
stream_id: 聊天流ID self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
"""
async with cls._instance_lock:
if stream_id in cls._instances:
# 停止相关组件
instance = cls._instances[stream_id]
instance.chat_observer.stop()
# 删除实例
del cls._instances[stream_id]
if stream_id in cls._init_events:
del cls._init_events[stream_id]
if stream_id in cls._initializing:
del cls._initializing[stream_id]
logger.info(f"已删除对话实例 {stream_id}")
def __init__(self, stream_id: str): # 标记初始化完成
"""初始化对话系统""" self.__class__._initializing[self.stream_id] = False
self.stream_id = stream_id if self.stream_id in self.__class__._init_events:
self.state = ConversationState.INIT self.__class__._init_events[self.stream_id].set()
self.current_goal: Optional[str] = None
self.current_method: Optional[str] = None
self.goal_reasoning: Optional[str] = None
self.generated_reply: Optional[str] = None
self.should_continue = True
# 初始化聊天观察器 # 启动对话循环
self.chat_observer = ChatObserver.get_instance(stream_id) asyncio.create_task(self._conversation_loop())
# 添加action历史记录 except Exception as e:
self.action_history: List[Dict[str, str]] = [] logger.error(f"初始化对话实例失败: {e}")
# 清理失败的初始化
# 知识缓存 self.__class__._initializing[self.stream_id] = False
self.knowledge_cache: Dict[str, str] = {} # 确保初始化为字典 if self.stream_id in self.__class__._init_events:
self.__class__._init_events[self.stream_id].set()
# 初始化各个组件 if self.stream_id in self.__class__._instances:
self.goal_analyzer = GoalAnalyzer(self.stream_id) del self.__class__._instances[self.stream_id]
self.action_planner = ActionPlanner(self.stream_id)
self.reply_generator = ReplyGenerator(self.stream_id)
self.knowledge_fetcher = KnowledgeFetcher()
self.direct_sender = DirectMessageSender()
self.waiter = Waiter(self.stream_id)
# 创建聊天流
self.chat_stream = chat_manager.get_stream(self.stream_id)
def _clear_knowledge_cache(self):
"""清空知识缓存"""
self.knowledge_cache.clear() # 使用clear方法清空字典
async def start(self): async def start(self):
"""开始对话流程""" """开始对话流程"""
try: try:
logger.info("对话系统启动") logger.info("对话系统启动")
self.should_continue = True self.should_continue = True
self.chat_observer.start() # 启动观察器
await asyncio.sleep(1)
# 启动对话循环
await self._conversation_loop() await self._conversation_loop()
except Exception as e: except Exception as e:
logger.error(f"启动对话系统失败: {e}") logger.error(f"启动对话系统失败: {e}")
raise raise
finally:
# 标记初始化完成
self._init_events[self.stream_id].set()
self._initializing[self.stream_id] = False
async def _conversation_loop(self): async def _conversation_loop(self):
"""对话循环""" """对话循环"""
@@ -681,17 +849,21 @@ class Conversation:
if not await self.chat_observer.wait_for_update(): if not await self.chat_observer.wait_for_update():
logger.warning("等待消息更新超时") logger.warning("等待消息更新超时")
# 使用决策信息来辅助行动规划
action, reason = await self.action_planner.plan( action, reason = await self.action_planner.plan(
self.current_goal, self.current_goal,
self.current_method, self.current_method,
self.goal_reasoning, self.goal_reasoning,
self.action_history, # 传入action历史 self.action_history,
self.chat_observer # 传入chat_observer self.decision_info # 传入决策信息
) )
# 执行行动 # 执行行动
await self._handle_action(action, reason) await self._handle_action(action, reason)
# 清理已处理的消息
self.decision_info.clear_unprocessed_messages()
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
"""将消息字典转换为Message对象""" """将消息字典转换为Message对象"""
try: try:
@@ -742,87 +914,6 @@ class Conversation:
self.current_goal self.current_goal
) )
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用下一个目标
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
# 检查使用新目标生成的回复是否合适
is_suitable, reason, _ = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
)
if is_suitable:
# 如果新目标的回复合适,调整目标优先级
await self.goal_analyzer._update_goals(
self.current_goal,
self.current_method,
self.goal_reasoning
)
else:
# 如果新目标还是不合适,重新思考目标
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
while self.chat_observer.check():
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用下一个目标
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
is_suitable = True # 假设使用新目标后回复是合适的
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
await self._send_reply() await self._send_reply()
elif action == "fetch_knowledge": elif action == "fetch_knowledge":
@@ -837,59 +928,6 @@ class Conversation:
if knowledge != "未找到相关知识": if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
# 检查回复是否合适
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
self.generated_reply,
self.current_goal
)
if not is_suitable:
logger.warning(f"生成的回复不合适,原因: {reason}")
if need_replan:
# 尝试切换到其他备选目标
alternative_goals = await self.goal_analyzer.get_alternative_goals()
if alternative_goals:
# 有备选目标,尝试使用
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
logger.info(f"切换到备选目标: {self.current_goal}")
# 使用新目标获取知识并生成回复
knowledge, sources = await self.knowledge_fetcher.fetch(
self.current_goal,
[self._convert_to_message(msg) for msg in messages]
)
if knowledge != "未找到相关知识":
self.knowledge_cache[sources] = knowledge
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache
)
else:
# 没有备选目标,重新分析
self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
return
else:
# 重新生成回复
self.generated_reply = await self.reply_generator.generate(
self.current_goal,
self.current_method,
[self._convert_to_message(msg) for msg in messages],
self.knowledge_cache,
self.generated_reply # 将不合适的回复作为previous_reply传入
)
await self._send_reply()
elif action == "rethink_goal": elif action == "rethink_goal":
self.state = ConversationState.RETHINKING self.state = ConversationState.RETHINKING
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal() self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()

View File

@@ -25,9 +25,9 @@ config_config = LogConfig(
logger = get_module_logger("config", config=config_config) logger = get_module_logger("config", config=config_config)
#考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 #考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
is_test = False is_test = True
mai_version_main = "0.6.1" mai_version_main = "0.6.2"
mai_version_fix = "" mai_version_fix = "snapshot-1"
if mai_version_fix: if mai_version_fix:
if is_test: if is_test:
mai_version = f"test-{mai_version_main}-{mai_version_fix}" mai_version = f"test-{mai_version_main}-{mai_version_fix}"
@@ -441,6 +441,7 @@ class BotConfig:
config.emoji_response_penalty = willing_config.get( config.emoji_response_penalty = willing_config.get(
"emoji_response_penalty", config.emoji_response_penalty "emoji_response_penalty", config.emoji_response_penalty
) )
if config.INNER_VERSION in SpecifierSet(">=1.2.5"):
config.mentioned_bot_inevitable_reply = willing_config.get( config.mentioned_bot_inevitable_reply = willing_config.get(
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply "mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
) )

View File

@@ -1,5 +1,5 @@
[inner] [inner]
version = "1.2.4" version = "1.2.5"
#以下是给开发人员阅读的,一般用户不需要阅读 #以下是给开发人员阅读的,一般用户不需要阅读