better:尝试重构pfc
This commit is contained in:
@@ -5,6 +5,8 @@ from src.common.logger import get_module_logger
|
|||||||
from src.common.database import db
|
from src.common.database import db
|
||||||
from ..message.message_base import UserInfo
|
from ..message.message_base import UserInfo
|
||||||
from ..config.config import global_config
|
from ..config.config import global_config
|
||||||
|
from .chat_states import NotificationManager, create_new_message_notification, create_cold_chat_notification
|
||||||
|
from .message_storage import MessageStorage, MongoDBMessageStorage
|
||||||
|
|
||||||
logger = get_module_logger("chat_observer")
|
logger = get_module_logger("chat_observer")
|
||||||
|
|
||||||
@@ -15,36 +17,40 @@ class ChatObserver:
|
|||||||
_instances: Dict[str, 'ChatObserver'] = {}
|
_instances: Dict[str, 'ChatObserver'] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_instance(cls, stream_id: str) -> 'ChatObserver':
|
def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver':
|
||||||
"""获取或创建观察器实例
|
"""获取或创建观察器实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
|
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
ChatObserver: 观察器实例
|
ChatObserver: 观察器实例
|
||||||
"""
|
"""
|
||||||
if stream_id not in cls._instances:
|
if stream_id not in cls._instances:
|
||||||
cls._instances[stream_id] = cls(stream_id)
|
cls._instances[stream_id] = cls(stream_id, message_storage)
|
||||||
return cls._instances[stream_id]
|
return cls._instances[stream_id]
|
||||||
|
|
||||||
def __init__(self, stream_id: str):
|
def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None):
|
||||||
"""初始化观察器
|
"""初始化观察器
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
stream_id: 聊天流ID
|
stream_id: 聊天流ID
|
||||||
|
message_storage: 消息存储实现,如果为None则使用MongoDB实现
|
||||||
"""
|
"""
|
||||||
if stream_id in self._instances:
|
if stream_id in self._instances:
|
||||||
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
raise RuntimeError(f"ChatObserver for {stream_id} already exists. Use get_instance() instead.")
|
||||||
|
|
||||||
self.stream_id = stream_id
|
self.stream_id = stream_id
|
||||||
|
self.message_storage = message_storage or MongoDBMessageStorage()
|
||||||
|
|
||||||
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
self.last_user_speak_time: Optional[float] = None # 对方上次发言时间
|
||||||
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间
|
||||||
self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
self.last_check_time: float = time.time() # 上次查看聊天记录时间
|
||||||
self.last_message_read: Optional[str] = None # 最后读取的消息ID
|
self.last_message_read: Optional[str] = None # 最后读取的消息ID
|
||||||
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
|
self.last_message_time: Optional[float] = None # 最后一条消息的时间戳
|
||||||
|
|
||||||
self.waiting_start_time: Optional[float] = None # 等待开始时间
|
self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间
|
||||||
|
|
||||||
# 消息历史记录
|
# 消息历史记录
|
||||||
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
self.message_history: List[Dict[str, Any]] = [] # 所有消息历史
|
||||||
@@ -56,8 +62,16 @@ class ChatObserver:
|
|||||||
self._task: Optional[asyncio.Task] = None
|
self._task: Optional[asyncio.Task] = None
|
||||||
self._update_event = asyncio.Event() # 触发更新的事件
|
self._update_event = asyncio.Event() # 触发更新的事件
|
||||||
self._update_complete = asyncio.Event() # 更新完成的事件
|
self._update_complete = asyncio.Event() # 更新完成的事件
|
||||||
|
|
||||||
|
# 通知管理器
|
||||||
|
self.notification_manager = NotificationManager()
|
||||||
|
|
||||||
|
# 冷场检查配置
|
||||||
|
self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场
|
||||||
|
self.last_cold_chat_check: float = time.time()
|
||||||
|
self.is_cold_chat_state: bool = False
|
||||||
|
|
||||||
def check(self) -> bool:
|
async def check(self) -> bool:
|
||||||
"""检查距离上一次观察之后是否有了新消息
|
"""检查距离上一次观察之后是否有了新消息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -65,13 +79,10 @@ class ChatObserver:
|
|||||||
"""
|
"""
|
||||||
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}")
|
||||||
|
|
||||||
query = {
|
new_message_exists = await self.message_storage.has_new_messages(
|
||||||
"chat_id": self.stream_id,
|
self.stream_id,
|
||||||
"time": {"$gt": self.last_check_time}
|
self.last_check_time
|
||||||
}
|
)
|
||||||
|
|
||||||
# 只需要查询是否存在,不需要获取具体消息
|
|
||||||
new_message_exists = db.messages.find_one(query) is not None
|
|
||||||
|
|
||||||
if new_message_exists:
|
if new_message_exists:
|
||||||
logger.debug("发现新消息")
|
logger.debug("发现新消息")
|
||||||
@@ -79,27 +90,8 @@ class ChatObserver:
|
|||||||
|
|
||||||
return new_message_exists
|
return new_message_exists
|
||||||
|
|
||||||
def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
async def _add_message_to_history(self, message: Dict[str, Any]):
|
||||||
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
|
"""添加消息到历史记录并发送通知
|
||||||
messages = self.get_message_history(self.last_check_time)
|
|
||||||
for message in messages:
|
|
||||||
self._add_message_to_history(message)
|
|
||||||
return messages, self.message_history
|
|
||||||
|
|
||||||
def new_message_after(self, time_point: float) -> bool:
|
|
||||||
"""判断是否在指定时间点后有新消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
time_point: 时间戳
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否有新消息
|
|
||||||
"""
|
|
||||||
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point}")
|
|
||||||
return self.last_message_time is None or self.last_message_time > time_point
|
|
||||||
|
|
||||||
def _add_message_to_history(self, message: Dict[str, Any]):
|
|
||||||
"""添加消息到历史记录
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
message: 消息数据
|
message: 消息数据
|
||||||
@@ -116,6 +108,75 @@ class ChatObserver:
|
|||||||
else:
|
else:
|
||||||
self.last_user_speak_time = message["time"]
|
self.last_user_speak_time = message["time"]
|
||||||
|
|
||||||
|
# 发送新消息通知
|
||||||
|
notification = create_new_message_notification(
|
||||||
|
sender="chat_observer",
|
||||||
|
target="pfc",
|
||||||
|
message=message
|
||||||
|
)
|
||||||
|
await self.notification_manager.send_notification(notification)
|
||||||
|
|
||||||
|
# 检查并更新冷场状态
|
||||||
|
await self._check_cold_chat()
|
||||||
|
|
||||||
|
async def _check_cold_chat(self):
|
||||||
|
"""检查是否处于冷场状态并发送通知"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# 每10秒检查一次冷场状态
|
||||||
|
if current_time - self.last_cold_chat_check < 10:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.last_cold_chat_check = current_time
|
||||||
|
|
||||||
|
# 判断是否冷场
|
||||||
|
is_cold = False
|
||||||
|
if self.last_message_time is None:
|
||||||
|
is_cold = True
|
||||||
|
else:
|
||||||
|
is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold
|
||||||
|
|
||||||
|
# 如果冷场状态发生变化,发送通知
|
||||||
|
if is_cold != self.is_cold_chat_state:
|
||||||
|
self.is_cold_chat_state = is_cold
|
||||||
|
notification = create_cold_chat_notification(
|
||||||
|
sender="chat_observer",
|
||||||
|
target="pfc",
|
||||||
|
is_cold=is_cold
|
||||||
|
)
|
||||||
|
await self.notification_manager.send_notification(notification)
|
||||||
|
|
||||||
|
async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||||
|
"""获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象"""
|
||||||
|
messages = await self.message_storage.get_messages_after(
|
||||||
|
self.stream_id,
|
||||||
|
self.last_message_read
|
||||||
|
)
|
||||||
|
for message in messages:
|
||||||
|
await self._add_message_to_history(message)
|
||||||
|
return messages, self.message_history
|
||||||
|
|
||||||
|
def new_message_after(self, time_point: float) -> bool:
|
||||||
|
"""判断是否在指定时间点后有新消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
time_point: 时间戳
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否有新消息
|
||||||
|
"""
|
||||||
|
if time_point is None:
|
||||||
|
logger.warning("time_point 为 None,返回 False")
|
||||||
|
return False
|
||||||
|
|
||||||
|
if self.last_message_time is None:
|
||||||
|
logger.debug("没有最后消息时间,返回 False")
|
||||||
|
return False
|
||||||
|
|
||||||
|
has_new = self.last_message_time > time_point
|
||||||
|
logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}")
|
||||||
|
return has_new
|
||||||
|
|
||||||
def get_message_history(
|
def get_message_history(
|
||||||
self,
|
self,
|
||||||
start_time: Optional[float] = None,
|
start_time: Optional[float] = None,
|
||||||
@@ -159,15 +220,9 @@ class ChatObserver:
|
|||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 新消息列表
|
List[Dict[str, Any]]: 新消息列表
|
||||||
"""
|
"""
|
||||||
query = {"chat_id": self.stream_id}
|
new_messages = await self.message_storage.get_messages_after(
|
||||||
if self.last_message_read:
|
self.stream_id,
|
||||||
# 获取ID大于last_message_read的消息
|
self.last_message_read
|
||||||
last_message = db.messages.find_one({"message_id": self.last_message_read})
|
|
||||||
if last_message:
|
|
||||||
query["time"] = {"$gt": last_message["time"]}
|
|
||||||
|
|
||||||
new_messages = list(
|
|
||||||
db.messages.find(query).sort("time", 1)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
@@ -184,30 +239,24 @@ class ChatObserver:
|
|||||||
Returns:
|
Returns:
|
||||||
List[Dict[str, Any]]: 最多5条消息
|
List[Dict[str, Any]]: 最多5条消息
|
||||||
"""
|
"""
|
||||||
query = {
|
new_messages = await self.message_storage.get_messages_before(
|
||||||
"chat_id": self.stream_id,
|
self.stream_id,
|
||||||
"time": {"$lt": time_point}
|
time_point
|
||||||
}
|
|
||||||
|
|
||||||
new_messages = list(
|
|
||||||
db.messages.find(query).sort("time", -1).limit(5) # 倒序获取5条
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 将消息按时间正序排列
|
|
||||||
new_messages.reverse()
|
|
||||||
|
|
||||||
if new_messages:
|
if new_messages:
|
||||||
self.last_message_read = new_messages[-1]["message_id"]
|
self.last_message_read = new_messages[-1]["message_id"]
|
||||||
|
|
||||||
return new_messages
|
return new_messages
|
||||||
|
|
||||||
|
'''主要观察循环'''
|
||||||
async def _update_loop(self):
|
async def _update_loop(self):
|
||||||
"""更新循环"""
|
"""更新循环"""
|
||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
messages = await self._fetch_new_messages_before(start_time)
|
messages = await self._fetch_new_messages_before(start_time)
|
||||||
for message in messages:
|
for message in messages:
|
||||||
self._add_message_to_history(message)
|
await self._add_message_to_history(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"缓冲消息出错: {e}")
|
logger.error(f"缓冲消息出错: {e}")
|
||||||
|
|
||||||
@@ -228,7 +277,7 @@ class ChatObserver:
|
|||||||
if new_messages:
|
if new_messages:
|
||||||
# 处理新消息
|
# 处理新消息
|
||||||
for message in new_messages:
|
for message in new_messages:
|
||||||
self._add_message_to_history(message)
|
await self._add_message_to_history(message)
|
||||||
|
|
||||||
# 设置完成事件
|
# 设置完成事件
|
||||||
self._update_complete.set()
|
self._update_complete.set()
|
||||||
|
|||||||
267
src/plugins/PFC/chat_states.py
Normal file
267
src/plugins/PFC/chat_states.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
from enum import Enum, auto
|
||||||
|
from typing import Optional, Dict, Any, List, Set
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import datetime
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
class ChatState(Enum):
|
||||||
|
"""聊天状态枚举"""
|
||||||
|
NORMAL = auto() # 正常状态
|
||||||
|
NEW_MESSAGE = auto() # 有新消息
|
||||||
|
COLD_CHAT = auto() # 冷场状态
|
||||||
|
ACTIVE_CHAT = auto() # 活跃状态
|
||||||
|
BOT_SPEAKING = auto() # 机器人正在说话
|
||||||
|
USER_SPEAKING = auto() # 用户正在说话
|
||||||
|
SILENT = auto() # 沉默状态
|
||||||
|
ERROR = auto() # 错误状态
|
||||||
|
|
||||||
|
class NotificationType(Enum):
|
||||||
|
"""通知类型枚举"""
|
||||||
|
NEW_MESSAGE = auto() # 新消息通知
|
||||||
|
COLD_CHAT = auto() # 冷场通知
|
||||||
|
ACTIVE_CHAT = auto() # 活跃通知
|
||||||
|
BOT_SPEAKING = auto() # 机器人说话通知
|
||||||
|
USER_SPEAKING = auto() # 用户说话通知
|
||||||
|
MESSAGE_DELETED = auto() # 消息删除通知
|
||||||
|
USER_JOINED = auto() # 用户加入通知
|
||||||
|
USER_LEFT = auto() # 用户离开通知
|
||||||
|
ERROR = auto() # 错误通知
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatStateInfo:
|
||||||
|
"""聊天状态信息"""
|
||||||
|
state: ChatState
|
||||||
|
last_message_time: Optional[float] = None
|
||||||
|
last_message_content: Optional[str] = None
|
||||||
|
last_speaker: Optional[str] = None
|
||||||
|
message_count: int = 0
|
||||||
|
cold_duration: float = 0.0 # 冷场持续时间(秒)
|
||||||
|
active_duration: float = 0.0 # 活跃持续时间(秒)
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Notification:
|
||||||
|
"""通知基类"""
|
||||||
|
type: NotificationType
|
||||||
|
timestamp: float
|
||||||
|
sender: str # 发送者标识
|
||||||
|
target: str # 接收者标识
|
||||||
|
data: Dict[str, Any]
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典格式"""
|
||||||
|
return {
|
||||||
|
"type": self.type.name,
|
||||||
|
"timestamp": self.timestamp,
|
||||||
|
"data": self.data
|
||||||
|
}
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class StateNotification(Notification):
|
||||||
|
"""持续状态通知"""
|
||||||
|
is_active: bool = True
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
base_dict = super().to_dict()
|
||||||
|
base_dict["is_active"] = self.is_active
|
||||||
|
return base_dict
|
||||||
|
|
||||||
|
class NotificationHandler(ABC):
|
||||||
|
"""通知处理器接口"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle_notification(self, notification: Notification):
|
||||||
|
"""处理通知"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class NotificationManager:
|
||||||
|
"""通知管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
# 按接收者和通知类型存储处理器
|
||||||
|
self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {}
|
||||||
|
self._active_states: Set[NotificationType] = set()
|
||||||
|
self._notification_history: List[Notification] = []
|
||||||
|
|
||||||
|
def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||||
|
"""注册通知处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: 接收者标识(例如:"pfc")
|
||||||
|
notification_type: 要处理的通知类型
|
||||||
|
handler: 处理器实例
|
||||||
|
"""
|
||||||
|
if target not in self._handlers:
|
||||||
|
self._handlers[target] = {}
|
||||||
|
if notification_type not in self._handlers[target]:
|
||||||
|
self._handlers[target][notification_type] = []
|
||||||
|
self._handlers[target][notification_type].append(handler)
|
||||||
|
|
||||||
|
def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler):
|
||||||
|
"""注销通知处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
target: 接收者标识
|
||||||
|
notification_type: 通知类型
|
||||||
|
handler: 要注销的处理器实例
|
||||||
|
"""
|
||||||
|
if target in self._handlers and notification_type in self._handlers[target]:
|
||||||
|
handlers = self._handlers[target][notification_type]
|
||||||
|
if handler in handlers:
|
||||||
|
handlers.remove(handler)
|
||||||
|
# 如果该类型的处理器列表为空,删除该类型
|
||||||
|
if not handlers:
|
||||||
|
del self._handlers[target][notification_type]
|
||||||
|
# 如果该目标没有任何处理器,删除该目标
|
||||||
|
if not self._handlers[target]:
|
||||||
|
del self._handlers[target]
|
||||||
|
|
||||||
|
async def send_notification(self, notification: Notification):
|
||||||
|
"""发送通知"""
|
||||||
|
self._notification_history.append(notification)
|
||||||
|
|
||||||
|
# 如果是状态通知,更新活跃状态
|
||||||
|
if isinstance(notification, StateNotification):
|
||||||
|
if notification.is_active:
|
||||||
|
self._active_states.add(notification.type)
|
||||||
|
else:
|
||||||
|
self._active_states.discard(notification.type)
|
||||||
|
|
||||||
|
# 调用目标接收者的处理器
|
||||||
|
target = notification.target
|
||||||
|
if target in self._handlers:
|
||||||
|
handlers = self._handlers[target].get(notification.type, [])
|
||||||
|
for handler in handlers:
|
||||||
|
await handler.handle_notification(notification)
|
||||||
|
|
||||||
|
def get_active_states(self) -> Set[NotificationType]:
|
||||||
|
"""获取当前活跃的状态"""
|
||||||
|
return self._active_states.copy()
|
||||||
|
|
||||||
|
def is_state_active(self, state_type: NotificationType) -> bool:
|
||||||
|
"""检查特定状态是否活跃"""
|
||||||
|
return state_type in self._active_states
|
||||||
|
|
||||||
|
def get_notification_history(self,
|
||||||
|
sender: Optional[str] = None,
|
||||||
|
target: Optional[str] = None,
|
||||||
|
limit: Optional[int] = None) -> List[Notification]:
|
||||||
|
"""获取通知历史
|
||||||
|
|
||||||
|
Args:
|
||||||
|
sender: 过滤特定发送者的通知
|
||||||
|
target: 过滤特定接收者的通知
|
||||||
|
limit: 限制返回数量
|
||||||
|
"""
|
||||||
|
history = self._notification_history
|
||||||
|
|
||||||
|
if sender:
|
||||||
|
history = [n for n in history if n.sender == sender]
|
||||||
|
if target:
|
||||||
|
history = [n for n in history if n.target == target]
|
||||||
|
|
||||||
|
if limit is not None:
|
||||||
|
history = history[-limit:]
|
||||||
|
|
||||||
|
return history
|
||||||
|
|
||||||
|
# 一些常用的通知创建函数
|
||||||
|
def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification:
|
||||||
|
"""创建新消息通知"""
|
||||||
|
return Notification(
|
||||||
|
type=NotificationType.NEW_MESSAGE,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
sender=sender,
|
||||||
|
target=target,
|
||||||
|
data={
|
||||||
|
"message_id": message.get("message_id"),
|
||||||
|
"content": message.get("content"),
|
||||||
|
"sender": message.get("sender"),
|
||||||
|
"time": message.get("time")
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification:
|
||||||
|
"""创建冷场状态通知"""
|
||||||
|
return StateNotification(
|
||||||
|
type=NotificationType.COLD_CHAT,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
sender=sender,
|
||||||
|
target=target,
|
||||||
|
data={"is_cold": is_cold},
|
||||||
|
is_active=is_cold
|
||||||
|
)
|
||||||
|
|
||||||
|
def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification:
|
||||||
|
"""创建活跃状态通知"""
|
||||||
|
return StateNotification(
|
||||||
|
type=NotificationType.ACTIVE_CHAT,
|
||||||
|
timestamp=datetime.now().timestamp(),
|
||||||
|
sender=sender,
|
||||||
|
target=target,
|
||||||
|
data={"is_active": is_active},
|
||||||
|
is_active=is_active
|
||||||
|
)
|
||||||
|
|
||||||
|
class ChatStateManager:
|
||||||
|
"""聊天状态管理器"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.current_state = ChatState.NORMAL
|
||||||
|
self.state_info = ChatStateInfo(state=ChatState.NORMAL)
|
||||||
|
self.state_history: list[ChatStateInfo] = []
|
||||||
|
|
||||||
|
def update_state(self, new_state: ChatState, **kwargs):
|
||||||
|
"""更新聊天状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
new_state: 新的状态
|
||||||
|
**kwargs: 其他状态信息
|
||||||
|
"""
|
||||||
|
self.current_state = new_state
|
||||||
|
self.state_info.state = new_state
|
||||||
|
|
||||||
|
# 更新其他状态信息
|
||||||
|
for key, value in kwargs.items():
|
||||||
|
if hasattr(self.state_info, key):
|
||||||
|
setattr(self.state_info, key, value)
|
||||||
|
|
||||||
|
# 记录状态历史
|
||||||
|
self.state_history.append(self.state_info)
|
||||||
|
|
||||||
|
def get_current_state_info(self) -> ChatStateInfo:
|
||||||
|
"""获取当前状态信息"""
|
||||||
|
return self.state_info
|
||||||
|
|
||||||
|
def get_state_history(self) -> list[ChatStateInfo]:
|
||||||
|
"""获取状态历史"""
|
||||||
|
return self.state_history
|
||||||
|
|
||||||
|
def is_cold_chat(self, threshold: float = 60.0) -> bool:
|
||||||
|
"""判断是否处于冷场状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: 冷场阈值(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否冷场
|
||||||
|
"""
|
||||||
|
if not self.state_info.last_message_time:
|
||||||
|
return True
|
||||||
|
|
||||||
|
current_time = datetime.now().timestamp()
|
||||||
|
return (current_time - self.state_info.last_message_time) > threshold
|
||||||
|
|
||||||
|
def is_active_chat(self, threshold: float = 5.0) -> bool:
|
||||||
|
"""判断是否处于活跃状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
threshold: 活跃阈值(秒)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否活跃
|
||||||
|
"""
|
||||||
|
if not self.state_info.last_message_time:
|
||||||
|
return False
|
||||||
|
|
||||||
|
current_time = datetime.now().timestamp()
|
||||||
|
return (current_time - self.state_info.last_message_time) <= threshold
|
||||||
134
src/plugins/PFC/message_storage.py
Normal file
134
src/plugins/PFC/message_storage.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
from src.common.database import db
|
||||||
|
|
||||||
|
class MessageStorage(ABC):
|
||||||
|
"""消息存储接口"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
"""获取指定消息ID之后的所有消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
message_id: 消息ID,如果为None则获取所有消息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 消息列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
"""获取指定时间点之前的消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
time_point: 时间戳
|
||||||
|
limit: 最大消息数量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[Dict[str, Any]]: 消息列表
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||||
|
"""检查是否有新消息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
chat_id: 聊天ID
|
||||||
|
after_time: 时间戳
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: 是否有新消息
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class MongoDBMessageStorage(MessageStorage):
|
||||||
|
"""MongoDB消息存储实现"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.db = db
|
||||||
|
|
||||||
|
async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
query = {"chat_id": chat_id}
|
||||||
|
|
||||||
|
if message_id:
|
||||||
|
# 获取ID大于message_id的消息
|
||||||
|
last_message = self.db.messages.find_one({"message_id": message_id})
|
||||||
|
if last_message:
|
||||||
|
query["time"] = {"$gt": last_message["time"]}
|
||||||
|
|
||||||
|
return list(
|
||||||
|
self.db.messages.find(query).sort("time", 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
query = {
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"time": {"$lt": time_point}
|
||||||
|
}
|
||||||
|
|
||||||
|
messages = list(
|
||||||
|
self.db.messages.find(query).sort("time", -1).limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将消息按时间正序排列
|
||||||
|
messages.reverse()
|
||||||
|
return messages
|
||||||
|
|
||||||
|
async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||||
|
query = {
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"time": {"$gt": after_time}
|
||||||
|
}
|
||||||
|
|
||||||
|
return self.db.messages.find_one(query) is not None
|
||||||
|
|
||||||
|
# # 创建一个内存消息存储实现,用于测试
|
||||||
|
# class InMemoryMessageStorage(MessageStorage):
|
||||||
|
# """内存消息存储实现,主要用于测试"""
|
||||||
|
|
||||||
|
# def __init__(self):
|
||||||
|
# self.messages: Dict[str, List[Dict[str, Any]]] = {}
|
||||||
|
|
||||||
|
# async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||||
|
# if chat_id not in self.messages:
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# messages = self.messages[chat_id]
|
||||||
|
# if not message_id:
|
||||||
|
# return messages
|
||||||
|
|
||||||
|
# # 找到message_id的索引
|
||||||
|
# try:
|
||||||
|
# index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id)
|
||||||
|
# return messages[index + 1:]
|
||||||
|
# except StopIteration:
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]:
|
||||||
|
# if chat_id not in self.messages:
|
||||||
|
# return []
|
||||||
|
|
||||||
|
# messages = [
|
||||||
|
# m for m in self.messages[chat_id]
|
||||||
|
# if m["time"] < time_point
|
||||||
|
# ]
|
||||||
|
|
||||||
|
# return messages[-limit:]
|
||||||
|
|
||||||
|
# async def has_new_messages(self, chat_id: str, after_time: float) -> bool:
|
||||||
|
# if chat_id not in self.messages:
|
||||||
|
# return False
|
||||||
|
|
||||||
|
# return any(m["time"] > after_time for m in self.messages[chat_id])
|
||||||
|
|
||||||
|
# # 测试辅助方法
|
||||||
|
# def add_message(self, chat_id: str, message: Dict[str, Any]):
|
||||||
|
# """添加测试消息"""
|
||||||
|
# if chat_id not in self.messages:
|
||||||
|
# self.messages[chat_id] = []
|
||||||
|
# self.messages[chat_id].append(message)
|
||||||
|
# self.messages[chat_id].sort(key=lambda m: m["time"])
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
#Prefrontal cortex
|
#Prefrontal cortex
|
||||||
import datetime
|
import datetime
|
||||||
import asyncio
|
import asyncio
|
||||||
from typing import List, Optional, Dict, Any, Tuple, Literal
|
from typing import List, Optional, Dict, Any, Tuple, Literal, Set
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from ..chat.chat_stream import ChatStream
|
from ..chat.chat_stream import ChatStream
|
||||||
@@ -19,7 +19,9 @@ from .pfc_KnowledgeFetcher import KnowledgeFetcher
|
|||||||
from .reply_checker import ReplyChecker
|
from .reply_checker import ReplyChecker
|
||||||
from .pfc_utils import get_items_from_json
|
from .pfc_utils import get_items_from_json
|
||||||
from src.individuality.individuality import Individuality
|
from src.individuality.individuality import Individuality
|
||||||
|
from .chat_states import NotificationHandler, Notification, NotificationType
|
||||||
import time
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
logger = get_module_logger("pfc")
|
logger = get_module_logger("pfc")
|
||||||
|
|
||||||
@@ -42,6 +44,99 @@ class ConversationState(Enum):
|
|||||||
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
ActionType = Literal["direct_reply", "fetch_knowledge", "wait"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DecisionInfo:
|
||||||
|
"""决策信息类,用于收集和管理来自chat_observer的通知信息"""
|
||||||
|
|
||||||
|
# 消息相关
|
||||||
|
last_message_time: Optional[float] = None
|
||||||
|
last_message_content: Optional[str] = None
|
||||||
|
last_message_sender: Optional[str] = None
|
||||||
|
new_messages_count: int = 0
|
||||||
|
unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list)
|
||||||
|
|
||||||
|
# 对话状态
|
||||||
|
is_cold_chat: bool = False
|
||||||
|
cold_chat_duration: float = 0.0
|
||||||
|
last_bot_speak_time: Optional[float] = None
|
||||||
|
last_user_speak_time: Optional[float] = None
|
||||||
|
|
||||||
|
# 对话参与者
|
||||||
|
active_users: Set[str] = field(default_factory=set)
|
||||||
|
bot_id: str = field(default="")
|
||||||
|
|
||||||
|
def update_from_message(self, message: Dict[str, Any]):
|
||||||
|
"""从消息更新信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
message: 消息数据
|
||||||
|
"""
|
||||||
|
self.last_message_time = message["time"]
|
||||||
|
self.last_message_content = message.get("processed_plain_text", "")
|
||||||
|
|
||||||
|
user_info = UserInfo.from_dict(message.get("user_info", {}))
|
||||||
|
self.last_message_sender = user_info.user_id
|
||||||
|
|
||||||
|
if user_info.user_id == self.bot_id:
|
||||||
|
self.last_bot_speak_time = message["time"]
|
||||||
|
else:
|
||||||
|
self.last_user_speak_time = message["time"]
|
||||||
|
self.active_users.add(user_info.user_id)
|
||||||
|
|
||||||
|
self.new_messages_count += 1
|
||||||
|
self.unprocessed_messages.append(message)
|
||||||
|
|
||||||
|
def update_cold_chat_status(self, is_cold: bool, current_time: float):
|
||||||
|
"""更新冷场状态
|
||||||
|
|
||||||
|
Args:
|
||||||
|
is_cold: 是否冷场
|
||||||
|
current_time: 当前时间
|
||||||
|
"""
|
||||||
|
self.is_cold_chat = is_cold
|
||||||
|
if is_cold and self.last_message_time:
|
||||||
|
self.cold_chat_duration = current_time - self.last_message_time
|
||||||
|
|
||||||
|
def get_active_duration(self) -> float:
|
||||||
|
"""获取当前活跃时长
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float: 最后一条消息到现在的时长(秒)
|
||||||
|
"""
|
||||||
|
if not self.last_message_time:
|
||||||
|
return 0.0
|
||||||
|
return time.time() - self.last_message_time
|
||||||
|
|
||||||
|
def get_user_response_time(self) -> Optional[float]:
|
||||||
|
"""获取用户响应时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None
|
||||||
|
"""
|
||||||
|
if not self.last_user_speak_time:
|
||||||
|
return None
|
||||||
|
return time.time() - self.last_user_speak_time
|
||||||
|
|
||||||
|
def get_bot_response_time(self) -> Optional[float]:
|
||||||
|
"""获取机器人响应时间
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None
|
||||||
|
"""
|
||||||
|
if not self.last_bot_speak_time:
|
||||||
|
return None
|
||||||
|
return time.time() - self.last_bot_speak_time
|
||||||
|
|
||||||
|
def clear_unprocessed_messages(self):
|
||||||
|
"""清空未处理消息列表"""
|
||||||
|
self.unprocessed_messages.clear()
|
||||||
|
self.new_messages_count = 0
|
||||||
|
|
||||||
|
|
||||||
|
# Forward reference for type hints
|
||||||
|
DecisionInfoType = DecisionInfo
|
||||||
|
|
||||||
|
|
||||||
class ActionPlanner:
|
class ActionPlanner:
|
||||||
"""行动规划器"""
|
"""行动规划器"""
|
||||||
|
|
||||||
@@ -62,22 +157,24 @@ class ActionPlanner:
|
|||||||
method: str,
|
method: str,
|
||||||
reasoning: str,
|
reasoning: str,
|
||||||
action_history: List[Dict[str, str]] = None,
|
action_history: List[Dict[str, str]] = None,
|
||||||
chat_observer: Optional[ChatObserver] = None, # 添加chat_observer参数
|
decision_info: DecisionInfoType = None # Use DecisionInfoType here
|
||||||
) -> Tuple[str, str]:
|
) -> Tuple[str, str]:
|
||||||
"""规划下一步行动
|
"""规划下一步行动
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
goal: 对话目标
|
goal: 对话目标
|
||||||
|
method: 实现方法
|
||||||
reasoning: 目标原因
|
reasoning: 目标原因
|
||||||
action_history: 行动历史记录
|
action_history: 行动历史记录
|
||||||
|
decision_info: 决策信息
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Tuple[str, str]: (行动类型, 行动原因)
|
Tuple[str, str]: (行动类型, 行动原因)
|
||||||
"""
|
"""
|
||||||
# 构建提示词
|
# 构建提示词
|
||||||
# 获取最近20条消息
|
logger.debug(f"开始规划行动:当前目标: {goal}")
|
||||||
self.chat_observer.waiting_start_time = time.time()
|
|
||||||
|
|
||||||
|
# 获取最近20条消息
|
||||||
messages = self.chat_observer.get_message_history(limit=20)
|
messages = self.chat_observer.get_message_history(limit=20)
|
||||||
chat_history_text = ""
|
chat_history_text = ""
|
||||||
for msg in messages:
|
for msg in messages:
|
||||||
@@ -92,22 +189,42 @@ class ActionPlanner:
|
|||||||
|
|
||||||
# 构建action历史文本
|
# 构建action历史文本
|
||||||
action_history_text = ""
|
action_history_text = ""
|
||||||
if action_history:
|
if action_history and action_history[-1]['action'] == "direct_reply":
|
||||||
if action_history[-1]['action'] == "direct_reply":
|
action_history_text = "你刚刚发言回复了对方"
|
||||||
action_history_text = "你刚刚发言回复了对方"
|
|
||||||
|
# 构建决策信息文本
|
||||||
|
decision_info_text = ""
|
||||||
|
if decision_info:
|
||||||
|
decision_info_text = "当前对话状态:\n"
|
||||||
|
if decision_info.is_cold_chat:
|
||||||
|
decision_info_text += f"对话处于冷场状态,已持续{int(decision_info.cold_chat_duration)}秒\n"
|
||||||
|
|
||||||
|
if decision_info.new_messages_count > 0:
|
||||||
|
decision_info_text += f"有{decision_info.new_messages_count}条新消息未处理\n"
|
||||||
|
|
||||||
|
user_response_time = decision_info.get_user_response_time()
|
||||||
|
if user_response_time:
|
||||||
|
decision_info_text += f"距离用户上次发言已过去{int(user_response_time)}秒\n"
|
||||||
|
|
||||||
|
bot_response_time = decision_info.get_bot_response_time()
|
||||||
|
if bot_response_time:
|
||||||
|
decision_info_text += f"距离你上次发言已过去{int(bot_response_time)}秒\n"
|
||||||
|
|
||||||
|
if decision_info.active_users:
|
||||||
|
decision_info_text += f"当前活跃用户数: {len(decision_info.active_users)}\n"
|
||||||
|
|
||||||
# 获取时间信息
|
prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
||||||
time_info = self.chat_observer.get_time_info()
|
|
||||||
|
|
||||||
prompt = f"""现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动:
|
|
||||||
{personality_text}
|
|
||||||
当前对话目标:{goal}
|
当前对话目标:{goal}
|
||||||
实现该对话目标的方式:{method}
|
实现该对话目标的方式:{method}
|
||||||
产生该对话目标的原因:{reasoning}
|
产生该对话目标的原因:{reasoning}
|
||||||
{time_info}
|
|
||||||
|
{decision_info_text}
|
||||||
|
{action_history_text}
|
||||||
|
|
||||||
最近的对话记录:
|
最近的对话记录:
|
||||||
{chat_history_text}
|
{chat_history_text}
|
||||||
{action_history_text}
|
|
||||||
请你接下去想想要你要做什么,可以发言,可以等待,可以倾听,可以调取知识。注意不同行动类型的要求,不要重复发言:
|
请你接下去想想要你要做什么,可以发言,可以等待,可以倾听,可以调取知识。注意不同行动类型的要求,不要重复发言:
|
||||||
行动类型:
|
行动类型:
|
||||||
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择
|
fetch_knowledge: 需要调取知识,当需要专业知识或特定信息时选择
|
||||||
@@ -413,16 +530,23 @@ class Waiter:
|
|||||||
Returns:
|
Returns:
|
||||||
bool: 是否超时(True表示超时)
|
bool: 是否超时(True表示超时)
|
||||||
"""
|
"""
|
||||||
wait_start_time = self.chat_observer.waiting_start_time
|
# 使用当前时间作为等待开始时间
|
||||||
while not self.chat_observer.new_message_after(wait_start_time):
|
wait_start_time = time.time()
|
||||||
await asyncio.sleep(1)
|
self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间
|
||||||
logger.info("等待中...")
|
|
||||||
# 检查是否超过60秒
|
while True:
|
||||||
|
# 检查是否有新消息
|
||||||
|
if self.chat_observer.new_message_after(wait_start_time):
|
||||||
|
logger.info("等待结束,收到新消息")
|
||||||
|
return False
|
||||||
|
|
||||||
|
# 检查是否超时
|
||||||
if time.time() - wait_start_time > 300:
|
if time.time() - wait_start_time > 300:
|
||||||
logger.info("等待超过300秒,结束对话")
|
logger.info("等待超过300秒,结束对话")
|
||||||
return True
|
return True
|
||||||
logger.info("等待结束")
|
|
||||||
return False
|
await asyncio.sleep(1)
|
||||||
|
logger.info("等待中...")
|
||||||
|
|
||||||
|
|
||||||
class ReplyGenerator:
|
class ReplyGenerator:
|
||||||
@@ -519,16 +643,16 @@ class ReplyGenerator:
|
|||||||
try:
|
try:
|
||||||
content, _ = await self.llm.generate_response_async(prompt)
|
content, _ = await self.llm.generate_response_async(prompt)
|
||||||
logger.info(f"生成的回复: {content}")
|
logger.info(f"生成的回复: {content}")
|
||||||
is_new = self.chat_observer.check()
|
# is_new = self.chat_observer.check()
|
||||||
logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息")
|
# logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息")
|
||||||
|
|
||||||
# 如果有新消息,重新生成回复
|
# 如果有新消息,重新生成回复
|
||||||
if is_new:
|
# if is_new:
|
||||||
logger.info("检测到新消息,重新生成回复")
|
# logger.info("检测到新消息,重新生成回复")
|
||||||
return await self.generate(
|
# return await self.generate(
|
||||||
goal, chat_history, knowledge_cache,
|
# goal, chat_history, knowledge_cache,
|
||||||
None, retry_count
|
# None, retry_count
|
||||||
)
|
# )
|
||||||
|
|
||||||
return content
|
return content
|
||||||
|
|
||||||
@@ -555,12 +679,69 @@ class ReplyGenerator:
|
|||||||
return await self.reply_checker.check(reply, goal, retry_count)
|
return await self.reply_checker.check(reply, goal, retry_count)
|
||||||
|
|
||||||
|
|
||||||
|
class PFCNotificationHandler(NotificationHandler):
|
||||||
|
"""PFC的通知处理器"""
|
||||||
|
|
||||||
|
def __init__(self, conversation: 'Conversation'):
|
||||||
|
self.conversation = conversation
|
||||||
|
self.logger = get_module_logger("pfc_notification")
|
||||||
|
self.decision_info = conversation.decision_info
|
||||||
|
|
||||||
|
async def handle_notification(self, notification: Notification):
|
||||||
|
"""处理通知"""
|
||||||
|
try:
|
||||||
|
if not notification or not hasattr(notification, 'data') or notification.data is None:
|
||||||
|
self.logger.error("收到无效的通知:notification 或 data 为空")
|
||||||
|
return
|
||||||
|
|
||||||
|
if notification.type == NotificationType.NEW_MESSAGE:
|
||||||
|
# 处理新消息通知
|
||||||
|
message = notification.data
|
||||||
|
if not isinstance(message, dict):
|
||||||
|
self.logger.error(f"无效的消息格式: {type(message)}")
|
||||||
|
return
|
||||||
|
|
||||||
|
content = message.get('content', '')
|
||||||
|
self.logger.info(f"收到新消息通知: {content[:30] if content else ''}...")
|
||||||
|
|
||||||
|
# 更新决策信息
|
||||||
|
try:
|
||||||
|
self.decision_info.update_from_message(message)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"更新决策信息失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 触发对话系统更新
|
||||||
|
self.conversation.chat_observer.trigger_update()
|
||||||
|
|
||||||
|
elif notification.type == NotificationType.COLD_CHAT:
|
||||||
|
# 处理冷场通知
|
||||||
|
try:
|
||||||
|
is_cold = bool(notification.data.get("is_cold", False))
|
||||||
|
# 更新决策信息
|
||||||
|
self.decision_info.update_cold_chat_status(is_cold, time.time())
|
||||||
|
|
||||||
|
if is_cold:
|
||||||
|
self.logger.info("检测到对话冷场")
|
||||||
|
else:
|
||||||
|
self.logger.info("对话恢复活跃")
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"处理冷场状态失败: {e}")
|
||||||
|
return
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.error(f"处理通知时出错: {str(e)}")
|
||||||
|
# 添加更详细的错误信息
|
||||||
|
self.logger.error(f"通知类型: {getattr(notification, 'type', None)}")
|
||||||
|
self.logger.error(f"通知数据: {getattr(notification, 'data', None)}")
|
||||||
|
|
||||||
|
|
||||||
class Conversation:
|
class Conversation:
|
||||||
# 类级别的实例管理
|
# 类级别的实例管理
|
||||||
_instances: Dict[str, 'Conversation'] = {}
|
_instances: Dict[str, 'Conversation'] = {}
|
||||||
_instance_lock = asyncio.Lock() # 类级别的全局锁
|
_instance_lock = asyncio.Lock()
|
||||||
_init_events: Dict[str, asyncio.Event] = {} # 初始化完成事件
|
_init_events: Dict[str, asyncio.Event] = {}
|
||||||
_initializing: Dict[str, bool] = {} # 标记是否正在初始化
|
_initializing: Dict[str, bool] = {}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
|
async def get_instance(cls, stream_id: str) -> Optional['Conversation']:
|
||||||
@@ -573,102 +754,89 @@ class Conversation:
|
|||||||
Optional[Conversation]: 对话实例,如果创建或等待失败则返回None
|
Optional[Conversation]: 对话实例,如果创建或等待失败则返回None
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用全局锁来确保线程安全
|
# 检查是否已经有实例
|
||||||
async with cls._instance_lock:
|
if stream_id in cls._instances:
|
||||||
# 如果已经在初始化中,等待初始化完成
|
|
||||||
if stream_id in cls._initializing and cls._initializing[stream_id]:
|
|
||||||
# 释放锁等待初始化
|
|
||||||
cls._instance_lock.release()
|
|
||||||
try:
|
|
||||||
await asyncio.wait_for(cls._init_events[stream_id].wait(), timeout=5.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
logger.error(f"等待实例 {stream_id} 初始化超时")
|
|
||||||
return None
|
|
||||||
finally:
|
|
||||||
await cls._instance_lock.acquire()
|
|
||||||
|
|
||||||
# 如果实例不存在,创建新实例
|
|
||||||
if stream_id not in cls._instances:
|
|
||||||
cls._instances[stream_id] = cls(stream_id)
|
|
||||||
cls._init_events[stream_id] = asyncio.Event()
|
|
||||||
cls._initializing[stream_id] = True
|
|
||||||
logger.info(f"创建新的对话实例: {stream_id}")
|
|
||||||
|
|
||||||
return cls._instances[stream_id]
|
return cls._instances[stream_id]
|
||||||
|
|
||||||
|
async with cls._instance_lock:
|
||||||
|
# 再次检查,防止在获取锁的过程中其他线程创建了实例
|
||||||
|
if stream_id in cls._instances:
|
||||||
|
return cls._instances[stream_id]
|
||||||
|
|
||||||
|
# 如果正在初始化,等待初始化完成
|
||||||
|
if stream_id in cls._initializing and cls._initializing[stream_id]:
|
||||||
|
event = cls._init_events.get(stream_id)
|
||||||
|
if event:
|
||||||
|
try:
|
||||||
|
# 在等待之前释放锁
|
||||||
|
cls._instance_lock.release()
|
||||||
|
await asyncio.wait_for(event.wait(), timeout=10.0) # 增加超时时间到10秒
|
||||||
|
# 重新获取锁
|
||||||
|
await cls._instance_lock.acquire()
|
||||||
|
if stream_id in cls._instances:
|
||||||
|
return cls._instances[stream_id]
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(f"等待实例 {stream_id} 初始化超时")
|
||||||
|
# 清理超时的初始化状态
|
||||||
|
cls._initializing[stream_id] = False
|
||||||
|
if stream_id in cls._init_events:
|
||||||
|
del cls._init_events[stream_id]
|
||||||
|
return None
|
||||||
|
|
||||||
|
# 创建新实例
|
||||||
|
logger.info(f"创建新的对话实例: {stream_id}")
|
||||||
|
cls._initializing[stream_id] = True
|
||||||
|
cls._init_events[stream_id] = asyncio.Event()
|
||||||
|
|
||||||
|
# 在锁保护下创建实例
|
||||||
|
instance = cls(stream_id)
|
||||||
|
cls._instances[stream_id] = instance
|
||||||
|
|
||||||
|
# 启动实例初始化(在后台运行)
|
||||||
|
asyncio.create_task(instance._initialize())
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取对话实例失败: {e}")
|
logger.error(f"获取对话实例失败: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
async def _initialize(self):
|
||||||
|
"""初始化实例(在后台运行)"""
|
||||||
|
try:
|
||||||
|
logger.info(f"开始初始化对话实例: {self.stream_id}")
|
||||||
|
self.chat_observer.start() # 启动观察器
|
||||||
|
await asyncio.sleep(1) # 给观察器一些启动时间
|
||||||
|
|
||||||
|
# 获取初始目标
|
||||||
|
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
||||||
|
|
||||||
|
# 标记初始化完成
|
||||||
|
self.__class__._initializing[self.stream_id] = False
|
||||||
|
if self.stream_id in self.__class__._init_events:
|
||||||
|
self.__class__._init_events[self.stream_id].set()
|
||||||
|
|
||||||
|
# 启动对话循环
|
||||||
|
asyncio.create_task(self._conversation_loop())
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"初始化对话实例失败: {e}")
|
||||||
|
# 清理失败的初始化
|
||||||
|
self.__class__._initializing[self.stream_id] = False
|
||||||
|
if self.stream_id in self.__class__._init_events:
|
||||||
|
self.__class__._init_events[self.stream_id].set()
|
||||||
|
if self.stream_id in self.__class__._instances:
|
||||||
|
del self.__class__._instances[self.stream_id]
|
||||||
|
|
||||||
@classmethod
|
|
||||||
async def remove_instance(cls, stream_id: str):
|
|
||||||
"""删除对话实例
|
|
||||||
|
|
||||||
Args:
|
|
||||||
stream_id: 聊天流ID
|
|
||||||
"""
|
|
||||||
async with cls._instance_lock:
|
|
||||||
if stream_id in cls._instances:
|
|
||||||
# 停止相关组件
|
|
||||||
instance = cls._instances[stream_id]
|
|
||||||
instance.chat_observer.stop()
|
|
||||||
# 删除实例
|
|
||||||
del cls._instances[stream_id]
|
|
||||||
if stream_id in cls._init_events:
|
|
||||||
del cls._init_events[stream_id]
|
|
||||||
if stream_id in cls._initializing:
|
|
||||||
del cls._initializing[stream_id]
|
|
||||||
logger.info(f"已删除对话实例 {stream_id}")
|
|
||||||
|
|
||||||
def __init__(self, stream_id: str):
|
|
||||||
"""初始化对话系统"""
|
|
||||||
self.stream_id = stream_id
|
|
||||||
self.state = ConversationState.INIT
|
|
||||||
self.current_goal: Optional[str] = None
|
|
||||||
self.current_method: Optional[str] = None
|
|
||||||
self.goal_reasoning: Optional[str] = None
|
|
||||||
self.generated_reply: Optional[str] = None
|
|
||||||
self.should_continue = True
|
|
||||||
|
|
||||||
# 初始化聊天观察器
|
|
||||||
self.chat_observer = ChatObserver.get_instance(stream_id)
|
|
||||||
|
|
||||||
# 添加action历史记录
|
|
||||||
self.action_history: List[Dict[str, str]] = []
|
|
||||||
|
|
||||||
# 知识缓存
|
|
||||||
self.knowledge_cache: Dict[str, str] = {} # 确保初始化为字典
|
|
||||||
|
|
||||||
# 初始化各个组件
|
|
||||||
self.goal_analyzer = GoalAnalyzer(self.stream_id)
|
|
||||||
self.action_planner = ActionPlanner(self.stream_id)
|
|
||||||
self.reply_generator = ReplyGenerator(self.stream_id)
|
|
||||||
self.knowledge_fetcher = KnowledgeFetcher()
|
|
||||||
self.direct_sender = DirectMessageSender()
|
|
||||||
self.waiter = Waiter(self.stream_id)
|
|
||||||
|
|
||||||
# 创建聊天流
|
|
||||||
self.chat_stream = chat_manager.get_stream(self.stream_id)
|
|
||||||
|
|
||||||
def _clear_knowledge_cache(self):
|
|
||||||
"""清空知识缓存"""
|
|
||||||
self.knowledge_cache.clear() # 使用clear方法清空字典
|
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
"""开始对话流程"""
|
"""开始对话流程"""
|
||||||
try:
|
try:
|
||||||
logger.info("对话系统启动")
|
logger.info("对话系统启动")
|
||||||
self.should_continue = True
|
self.should_continue = True
|
||||||
self.chat_observer.start() # 启动观察器
|
|
||||||
await asyncio.sleep(1)
|
|
||||||
# 启动对话循环
|
|
||||||
await self._conversation_loop()
|
await self._conversation_loop()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"启动对话系统失败: {e}")
|
logger.error(f"启动对话系统失败: {e}")
|
||||||
raise
|
raise
|
||||||
finally:
|
|
||||||
# 标记初始化完成
|
|
||||||
self._init_events[self.stream_id].set()
|
|
||||||
self._initializing[self.stream_id] = False
|
|
||||||
|
|
||||||
async def _conversation_loop(self):
|
async def _conversation_loop(self):
|
||||||
"""对话循环"""
|
"""对话循环"""
|
||||||
@@ -681,17 +849,21 @@ class Conversation:
|
|||||||
if not await self.chat_observer.wait_for_update():
|
if not await self.chat_observer.wait_for_update():
|
||||||
logger.warning("等待消息更新超时")
|
logger.warning("等待消息更新超时")
|
||||||
|
|
||||||
|
# 使用决策信息来辅助行动规划
|
||||||
action, reason = await self.action_planner.plan(
|
action, reason = await self.action_planner.plan(
|
||||||
self.current_goal,
|
self.current_goal,
|
||||||
self.current_method,
|
self.current_method,
|
||||||
self.goal_reasoning,
|
self.goal_reasoning,
|
||||||
self.action_history, # 传入action历史
|
self.action_history,
|
||||||
self.chat_observer # 传入chat_observer
|
self.decision_info # 传入决策信息
|
||||||
)
|
)
|
||||||
|
|
||||||
# 执行行动
|
# 执行行动
|
||||||
await self._handle_action(action, reason)
|
await self._handle_action(action, reason)
|
||||||
|
|
||||||
|
# 清理已处理的消息
|
||||||
|
self.decision_info.clear_unprocessed_messages()
|
||||||
|
|
||||||
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message:
|
||||||
"""将消息字典转换为Message对象"""
|
"""将消息字典转换为Message对象"""
|
||||||
try:
|
try:
|
||||||
@@ -742,87 +914,6 @@ class Conversation:
|
|||||||
self.current_goal
|
self.current_goal
|
||||||
)
|
)
|
||||||
|
|
||||||
if not is_suitable:
|
|
||||||
logger.warning(f"生成的回复不合适,原因: {reason}")
|
|
||||||
if need_replan:
|
|
||||||
# 尝试切换到其他备选目标
|
|
||||||
alternative_goals = await self.goal_analyzer.get_alternative_goals()
|
|
||||||
if alternative_goals:
|
|
||||||
# 有备选目标,尝试使用下一个目标
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
|
|
||||||
logger.info(f"切换到备选目标: {self.current_goal}")
|
|
||||||
# 使用新目标生成回复
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache
|
|
||||||
)
|
|
||||||
# 检查使用新目标生成的回复是否合适
|
|
||||||
is_suitable, reason, _ = await self.reply_generator.check_reply(
|
|
||||||
self.generated_reply,
|
|
||||||
self.current_goal
|
|
||||||
)
|
|
||||||
if is_suitable:
|
|
||||||
# 如果新目标的回复合适,调整目标优先级
|
|
||||||
await self.goal_analyzer._update_goals(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
self.goal_reasoning
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果新目标还是不合适,重新思考目标
|
|
||||||
self.state = ConversationState.RETHINKING
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# 没有备选目标,重新分析
|
|
||||||
self.state = ConversationState.RETHINKING
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# 重新生成回复
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache,
|
|
||||||
self.generated_reply # 将不合适的回复作为previous_reply传入
|
|
||||||
)
|
|
||||||
|
|
||||||
while self.chat_observer.check():
|
|
||||||
if not is_suitable:
|
|
||||||
logger.warning(f"生成的回复不合适,原因: {reason}")
|
|
||||||
if need_replan:
|
|
||||||
# 尝试切换到其他备选目标
|
|
||||||
alternative_goals = await self.goal_analyzer.get_alternative_goals()
|
|
||||||
if alternative_goals:
|
|
||||||
# 有备选目标,尝试使用下一个目标
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
|
|
||||||
logger.info(f"切换到备选目标: {self.current_goal}")
|
|
||||||
# 使用新目标生成回复
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache
|
|
||||||
)
|
|
||||||
is_suitable = True # 假设使用新目标后回复是合适的
|
|
||||||
else:
|
|
||||||
# 没有备选目标,重新分析
|
|
||||||
self.state = ConversationState.RETHINKING
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# 重新生成回复
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache,
|
|
||||||
self.generated_reply # 将不合适的回复作为previous_reply传入
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._send_reply()
|
await self._send_reply()
|
||||||
|
|
||||||
elif action == "fetch_knowledge":
|
elif action == "fetch_knowledge":
|
||||||
@@ -836,59 +927,6 @@ class Conversation:
|
|||||||
|
|
||||||
if knowledge != "未找到相关知识":
|
if knowledge != "未找到相关知识":
|
||||||
self.knowledge_cache[sources] = knowledge
|
self.knowledge_cache[sources] = knowledge
|
||||||
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache
|
|
||||||
)
|
|
||||||
|
|
||||||
# 检查回复是否合适
|
|
||||||
is_suitable, reason, need_replan = await self.reply_generator.check_reply(
|
|
||||||
self.generated_reply,
|
|
||||||
self.current_goal
|
|
||||||
)
|
|
||||||
|
|
||||||
if not is_suitable:
|
|
||||||
logger.warning(f"生成的回复不合适,原因: {reason}")
|
|
||||||
if need_replan:
|
|
||||||
# 尝试切换到其他备选目标
|
|
||||||
alternative_goals = await self.goal_analyzer.get_alternative_goals()
|
|
||||||
if alternative_goals:
|
|
||||||
# 有备选目标,尝试使用
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = alternative_goals[0]
|
|
||||||
logger.info(f"切换到备选目标: {self.current_goal}")
|
|
||||||
# 使用新目标获取知识并生成回复
|
|
||||||
knowledge, sources = await self.knowledge_fetcher.fetch(
|
|
||||||
self.current_goal,
|
|
||||||
[self._convert_to_message(msg) for msg in messages]
|
|
||||||
)
|
|
||||||
if knowledge != "未找到相关知识":
|
|
||||||
self.knowledge_cache[sources] = knowledge
|
|
||||||
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 没有备选目标,重新分析
|
|
||||||
self.state = ConversationState.RETHINKING
|
|
||||||
self.current_goal, self.current_method, self.goal_reasoning = await self.goal_analyzer.analyze_goal()
|
|
||||||
return
|
|
||||||
else:
|
|
||||||
# 重新生成回复
|
|
||||||
self.generated_reply = await self.reply_generator.generate(
|
|
||||||
self.current_goal,
|
|
||||||
self.current_method,
|
|
||||||
[self._convert_to_message(msg) for msg in messages],
|
|
||||||
self.knowledge_cache,
|
|
||||||
self.generated_reply # 将不合适的回复作为previous_reply传入
|
|
||||||
)
|
|
||||||
|
|
||||||
await self._send_reply()
|
|
||||||
|
|
||||||
elif action == "rethink_goal":
|
elif action == "rethink_goal":
|
||||||
self.state = ConversationState.RETHINKING
|
self.state = ConversationState.RETHINKING
|
||||||
|
|||||||
@@ -25,9 +25,9 @@ config_config = LogConfig(
|
|||||||
logger = get_module_logger("config", config=config_config)
|
logger = get_module_logger("config", config=config_config)
|
||||||
|
|
||||||
#考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
#考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
|
||||||
is_test = False
|
is_test = True
|
||||||
mai_version_main = "0.6.1"
|
mai_version_main = "0.6.2"
|
||||||
mai_version_fix = ""
|
mai_version_fix = "snapshot-1"
|
||||||
if mai_version_fix:
|
if mai_version_fix:
|
||||||
if is_test:
|
if is_test:
|
||||||
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
|
mai_version = f"test-{mai_version_main}-{mai_version_fix}"
|
||||||
@@ -441,6 +441,7 @@ class BotConfig:
|
|||||||
config.emoji_response_penalty = willing_config.get(
|
config.emoji_response_penalty = willing_config.get(
|
||||||
"emoji_response_penalty", config.emoji_response_penalty
|
"emoji_response_penalty", config.emoji_response_penalty
|
||||||
)
|
)
|
||||||
|
if config.INNER_VERSION in SpecifierSet(">=1.2.5"):
|
||||||
config.mentioned_bot_inevitable_reply = willing_config.get(
|
config.mentioned_bot_inevitable_reply = willing_config.get(
|
||||||
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
|
"mentioned_bot_inevitable_reply", config.mentioned_bot_inevitable_reply
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[inner]
|
[inner]
|
||||||
version = "1.2.4"
|
version = "1.2.5"
|
||||||
|
|
||||||
|
|
||||||
#以下是给开发人员阅读的,一般用户不需要阅读
|
#以下是给开发人员阅读的,一般用户不需要阅读
|
||||||
|
|||||||
Reference in New Issue
Block a user