diff --git a/bot.py b/bot.py index ca214967e..5b12b0389 100644 --- a/bot.py +++ b/bot.py @@ -196,7 +196,7 @@ def raw_main(): # 安装崩溃日志处理器 install_crash_handler() - + check_eula() print("检查EULA和隐私条款完成") easter_egg() diff --git a/src/common/crash_logger.py b/src/common/crash_logger.py index 658e1bb02..d1e4fb51f 100644 --- a/src/common/crash_logger.py +++ b/src/common/crash_logger.py @@ -4,69 +4,66 @@ import logging from pathlib import Path from logging.handlers import RotatingFileHandler + def setup_crash_logger(): """设置崩溃日志记录器""" # 创建logs/crash目录(如果不存在) crash_log_dir = Path("logs/crash") crash_log_dir.mkdir(parents=True, exist_ok=True) - + # 创建日志记录器 - crash_logger = logging.getLogger('crash_logger') + crash_logger = logging.getLogger("crash_logger") crash_logger.setLevel(logging.ERROR) - + # 设置日志格式 formatter = logging.Formatter( - '%(asctime)s - %(name)s - %(levelname)s\n' - '异常类型: %(exc_info)s\n' - '详细信息:\n%(message)s\n' - '-------------------\n' + "%(asctime)s - %(name)s - %(levelname)s\n异常类型: %(exc_info)s\n详细信息:\n%(message)s\n-------------------\n" ) - + # 创建按大小轮转的文件处理器(最大10MB,保留5个备份) log_file = crash_log_dir / "crash.log" file_handler = RotatingFileHandler( log_file, - maxBytes=10*1024*1024, # 10MB + maxBytes=10 * 1024 * 1024, # 10MB backupCount=5, - encoding='utf-8' + encoding="utf-8", ) file_handler.setFormatter(formatter) crash_logger.addHandler(file_handler) - + return crash_logger + def log_crash(exc_type, exc_value, exc_traceback): """记录崩溃信息到日志文件""" if exc_type is None: return - + # 获取崩溃日志记录器 - crash_logger = logging.getLogger('crash_logger') - + crash_logger = logging.getLogger("crash_logger") + # 获取完整的异常堆栈信息 - stack_trace = ''.join(traceback.format_exception(exc_type, exc_value, exc_traceback)) - + stack_trace = "".join(traceback.format_exception(exc_type, exc_value, exc_traceback)) + # 记录崩溃信息 - crash_logger.error( - stack_trace, - exc_info=(exc_type, exc_value, exc_traceback) - ) + crash_logger.error(stack_trace, exc_info=(exc_type, exc_value, exc_traceback)) + def install_crash_handler(): """安装全局异常处理器""" # 设置崩溃日志记录器 setup_crash_logger() - + # 保存原始的异常处理器 original_hook = sys.excepthook - + def exception_handler(exc_type, exc_value, exc_traceback): """全局异常处理器""" # 记录崩溃信息 log_crash(exc_type, exc_value, exc_traceback) - + # 调用原始的异常处理器 original_hook(exc_type, exc_value, exc_traceback) - + # 设置全局异常处理器 - sys.excepthook = exception_handler \ No newline at end of file + sys.excepthook = exception_handler diff --git a/src/common/server.py b/src/common/server.py new file mode 100644 index 000000000..fd1f3ff18 --- /dev/null +++ b/src/common/server.py @@ -0,0 +1,73 @@ +from fastapi import FastAPI, APIRouter +from typing import Optional, Union +from uvicorn import Config, Server as UvicornServer +import os + + +class Server: + def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + self.app = FastAPI(title=app_name) + self._host: str = "127.0.0.1" + self._port: int = 8080 + self._server: Optional[UvicornServer] = None + self.set_address(host, port) + + def register_router(self, router: APIRouter, prefix: str = ""): + """注册路由 + + APIRouter 用于对相关的路由端点进行分组和模块化管理: + 1. 可以将相关的端点组织在一起,便于管理 + 2. 支持添加统一的路由前缀 + 3. 可以为一组路由添加共同的依赖项、标签等 + + 示例: + router = APIRouter() + + @router.get("/users") + def get_users(): + return {"users": [...]} + + @router.post("/users") + def create_user(): + return {"msg": "user created"} + + # 注册路由,添加前缀 "/api/v1" + server.register_router(router, prefix="/api/v1") + """ + self.app.include_router(router, prefix=prefix) + + def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + """设置服务器地址和端口""" + if host: + self._host = host + if port: + self._port = port + + async def run(self): + """启动服务器""" + config = Config(app=self.app, host=self._host, port=self._port) + self._server = UvicornServer(config=config) + try: + await self._server.serve() + except KeyboardInterrupt: + await self.shutdown() + raise + except Exception as e: + await self.shutdown() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.shutdown() + + async def shutdown(self): + """安全关闭服务器""" + if self._server: + self._server.should_exit = True + await self._server.shutdown() + self._server = None + + def get_app(self) -> FastAPI: + """获取 FastAPI 实例""" + return self.app + + +global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) diff --git a/src/main.py b/src/main.py index aa6f908bf..d94cfce64 100644 --- a/src/main.py +++ b/src/main.py @@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot from .common.logger import get_module_logger from .plugins.remote import heartbeat_thread # noqa: F401 from .individuality.individuality import Individuality - +from .common.server import global_server logger = get_module_logger("main") @@ -33,6 +33,7 @@ class MainSystem: from .plugins.message import global_api self.app = global_api + self.server = global_server async def initialize(self): """初始化系统组件""" @@ -126,6 +127,7 @@ class MainSystem: emoji_manager.start_periodic_check_register(), # emoji_manager.start_periodic_register(), self.app.run(), + self.server.run(), ] await asyncio.gather(*tasks) diff --git a/src/plugins/PFC/action_planner.py b/src/plugins/PFC/action_planner.py index ad69fea1d..43b0749a1 100644 --- a/src/plugins/PFC/action_planner.py +++ b/src/plugins/PFC/action_planner.py @@ -10,6 +10,7 @@ from .conversation_info import ConversationInfo logger = get_module_logger("action_planner") + class ActionPlannerInfo: def __init__(self): self.done_action = [] @@ -20,68 +21,57 @@ class ActionPlannerInfo: class ActionPlanner: """行动规划器""" - + def __init__(self, stream_id: str): self.llm = LLM_request( - model=global_config.llm_normal, - temperature=0.7, - max_tokens=1000, - request_type="action_planning" + model=global_config.llm_normal, temperature=0.7, max_tokens=1000, request_type="action_planning" ) - self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) + self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.name = global_config.BOT_NICKNAME self.chat_observer = ChatObserver.get_instance(stream_id) - - async def plan( - self, - observation_info: ObservationInfo, - conversation_info: ConversationInfo - ) -> Tuple[str, str]: + + async def plan(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> Tuple[str, str]: """规划下一步行动 - + Args: observation_info: 决策信息 conversation_info: 对话信息 - + Returns: Tuple[str, str]: (行动类型, 行动原因) """ # 构建提示词 logger.debug(f"开始规划行动:当前目标: {conversation_info.goal_list}") - - #构建对话目标 + + # 构建对话目标 if conversation_info.goal_list: goal, reasoning = conversation_info.goal_list[-1] else: goal = "目前没有明确对话目标" reasoning = "目前没有明确对话目标,最好思考一个对话目标" - - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - + if observation_info.new_messages_count > 0: new_messages_list = observation_info.unprocessed_messages - + chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n" for msg in new_messages_list: chat_history_text += f"{msg}\n" - + observation_info.clear_unprocessed_messages() - - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + # 构建action历史文本 action_history_list = conversation_info.done_action action_history_text = "你之前做的事情是:" for action in action_history_list: action_history_text += f"{action}\n" - - prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下内容,根据信息决定下一步行动: @@ -111,29 +101,27 @@ rethink_goal: 重新思考对话目标,当发现对话目标不合适时选择 try: content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 使用简化函数提取JSON内容 success, result = get_items_from_json( - content, - "action", "reason", - default_values={"action": "direct_reply", "reason": "没有明确原因"} + content, "action", "reason", default_values={"action": "direct_reply", "reason": "没有明确原因"} ) - + if not success: return "direct_reply", "JSON解析失败,选择直接回复" - + action = result["action"] reason = result["reason"] - + # 验证action类型 if action not in ["direct_reply", "fetch_knowledge", "wait", "listening", "rethink_goal"]: logger.warning(f"未知的行动类型: {action},默认使用listening") action = "listening" - + logger.info(f"规划的行动: {action}") logger.info(f"行动原因: {reason}") return action, reason - + except Exception as e: logger.error(f"规划行动时出错: {str(e)}") - return "direct_reply", "发生错误,选择直接回复" \ No newline at end of file + return "direct_reply", "发生错误,选择直接回复" diff --git a/src/plugins/PFC/chat_observer.py b/src/plugins/PFC/chat_observer.py index 93618cf2d..c96bc47b1 100644 --- a/src/plugins/PFC/chat_observer.py +++ b/src/plugins/PFC/chat_observer.py @@ -17,20 +17,20 @@ class ChatObserver: _instances: Dict[str, "ChatObserver"] = {} @classmethod - def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> 'ChatObserver': + def get_instance(cls, stream_id: str, message_storage: Optional[MessageStorage] = None) -> "ChatObserver": """获取或创建观察器实例 Args: stream_id: 聊天流ID message_storage: 消息存储实现,如果为None则使用MongoDB实现 - + Returns: ChatObserver: 观察器实例 """ if stream_id not in cls._instances: cls._instances[stream_id] = cls(stream_id, message_storage) return cls._instances[stream_id] - + def __init__(self, stream_id: str, message_storage: Optional[MessageStorage] = None): """初始化观察器 @@ -43,15 +43,15 @@ class ChatObserver: self.stream_id = stream_id self.message_storage = message_storage or MongoDBMessageStorage() - + self.last_user_speak_time: Optional[float] = None # 对方上次发言时间 - self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 - self.last_check_time: float = time.time() # 上次查看聊天记录时间 - self.last_message_read: Optional[str] = None # 最后读取的消息ID - self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 - - self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 - + self.last_bot_speak_time: Optional[float] = None # 机器人上次发言时间 + self.last_check_time: float = time.time() # 上次查看聊天记录时间 + self.last_message_read: Optional[str] = None # 最后读取的消息ID + self.last_message_time: Optional[float] = None # 最后一条消息的时间戳 + + self.waiting_start_time: float = time.time() # 等待开始时间,初始化为当前时间 + # 消息历史记录 self.message_history: List[Dict[str, Any]] = [] # 所有消息历史 self.last_message_id: Optional[str] = None # 最后一条消息的ID @@ -62,20 +62,20 @@ class ChatObserver: self._task: Optional[asyncio.Task] = None self._update_event = asyncio.Event() # 触发更新的事件 self._update_complete = asyncio.Event() # 更新完成的事件 - + # 通知管理器 self.notification_manager = NotificationManager() - + # 冷场检查配置 self.cold_chat_threshold: float = 60.0 # 60秒无消息判定为冷场 self.last_cold_chat_check: float = time.time() self.is_cold_chat_state: bool = False - + self.update_event = asyncio.Event() self.update_interval = 5 # 更新间隔(秒) self.message_cache = [] self.update_running = False - + async def check(self) -> bool: """检查距离上一次观察之后是否有了新消息 @@ -83,21 +83,18 @@ class ChatObserver: bool: 是否有新消息 """ logger.debug(f"检查距离上一次观察之后是否有了新消息: {self.last_check_time}") - - new_message_exists = await self.message_storage.has_new_messages( - self.stream_id, - self.last_check_time - ) - + + new_message_exists = await self.message_storage.has_new_messages(self.stream_id, self.last_check_time) + if new_message_exists: logger.debug("发现新消息") self.last_check_time = time.time() return new_message_exists - + async def _add_message_to_history(self, message: Dict[str, Any]): """添加消息到历史记录并发送通知 - + Args: message: 消息数据 """ @@ -112,76 +109,65 @@ class ChatObserver: self.last_bot_speak_time = message["time"] else: self.last_user_speak_time = message["time"] - + # 发送新消息通知 - notification = create_new_message_notification( - sender="chat_observer", - target="pfc", - message=message - ) + notification = create_new_message_notification(sender="chat_observer", target="pfc", message=message) await self.notification_manager.send_notification(notification) - + # 检查并更新冷场状态 await self._check_cold_chat() - + async def _check_cold_chat(self): """检查是否处于冷场状态并发送通知""" current_time = time.time() - + # 每10秒检查一次冷场状态 if current_time - self.last_cold_chat_check < 10: return - + self.last_cold_chat_check = current_time - + # 判断是否冷场 is_cold = False if self.last_message_time is None: is_cold = True else: is_cold = (current_time - self.last_message_time) > self.cold_chat_threshold - + # 如果冷场状态发生变化,发送通知 if is_cold != self.is_cold_chat_state: self.is_cold_chat_state = is_cold - notification = create_cold_chat_notification( - sender="chat_observer", - target="pfc", - is_cold=is_cold - ) + notification = create_cold_chat_notification(sender="chat_observer", target="pfc", is_cold=is_cold) await self.notification_manager.send_notification(notification) - + async def get_new_message(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]: """获取上一次观察的时间点后的新消息,插入到历史记录中,并返回新消息和历史记录两个对象""" - messages = await self.message_storage.get_messages_after( - self.stream_id, - self.last_message_read - ) + messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) for message in messages: await self._add_message_to_history(message) return messages, self.message_history - + def new_message_after(self, time_point: float) -> bool: """判断是否在指定时间点后有新消息 - + Args: time_point: 时间戳 - + Returns: bool: 是否有新消息 """ if time_point is None: logger.warning("time_point 为 None,返回 False") return False - + if self.last_message_time is None: logger.debug("没有最后消息时间,返回 False") return False - + has_new = self.last_message_time > time_point logger.debug(f"判断是否在指定时间点后有新消息: {self.last_message_time} > {time_point} = {has_new}") return has_new - + def get_message_history( self, start_time: Optional[float] = None, @@ -224,11 +210,8 @@ class ChatObserver: Returns: List[Dict[str, Any]]: 新消息列表 """ - new_messages = await self.message_storage.get_messages_after( - self.stream_id, - self.last_message_read - ) - + new_messages = await self.message_storage.get_messages_after(self.stream_id, self.last_message_read) + if new_messages: self.last_message_read = new_messages[-1]["message_id"] @@ -243,17 +226,15 @@ class ChatObserver: Returns: List[Dict[str, Any]]: 最多5条消息 """ - new_messages = await self.message_storage.get_messages_before( - self.stream_id, - time_point - ) - + new_messages = await self.message_storage.get_messages_before(self.stream_id, time_point) + if new_messages: self.last_message_read = new_messages[-1]["message_id"] return new_messages - - '''主要观察循环''' + + """主要观察循环""" + async def _update_loop(self): """更新循环""" try: @@ -282,7 +263,7 @@ class ChatObserver: # 处理新消息 for message in new_messages: await self._add_message_to_history(message) - + # 设置完成事件 self._update_complete.set() @@ -379,7 +360,7 @@ class ChatObserver: if not self.update_running: self.update_running = True asyncio.create_task(self._periodic_update()) - + async def _periodic_update(self): """定期更新消息历史""" try: @@ -388,53 +369,52 @@ class ChatObserver: await asyncio.sleep(self.update_interval) except Exception as e: logger.error(f"定期更新消息历史时出错: {str(e)}") - + async def _update_message_history(self) -> bool: """更新消息历史 - + Returns: bool: 是否有新消息 """ try: - messages = await self.message_storage.get_messages_for_stream( - self.stream_id, - limit=50 - ) - + messages = await self.message_storage.get_messages_for_stream(self.stream_id, limit=50) + if not messages: return False - + # 检查是否有新消息 has_new_messages = False - if messages and (not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"]): + if messages and ( + not self.message_cache or messages[0]["message_id"] != self.message_cache[0]["message_id"] + ): has_new_messages = True - + self.message_cache = messages - + if has_new_messages: self.update_event.set() self.update_event.clear() return True return False - + except Exception as e: logger.error(f"更新消息历史时出错: {str(e)}") return False - + def get_cached_messages(self, limit: int = 50) -> List[Dict[str, Any]]: """获取缓存的消息历史 - + Args: limit: 获取的最大消息数量,默认50 - + Returns: List[Dict[str, Any]]: 缓存的消息历史列表 - """ + """ return self.message_cache[:limit] - + def get_last_message(self) -> Optional[Dict[str, Any]]: """获取最后一条消息 - + Returns: Optional[Dict[str, Any]]: 最后一条消息,如果没有则返回None """ diff --git a/src/plugins/PFC/chat_states.py b/src/plugins/PFC/chat_states.py index bb7cfc4a6..b28ca69a6 100644 --- a/src/plugins/PFC/chat_states.py +++ b/src/plugins/PFC/chat_states.py @@ -4,32 +4,38 @@ from dataclasses import dataclass from datetime import datetime from abc import ABC, abstractmethod + class ChatState(Enum): """聊天状态枚举""" - NORMAL = auto() # 正常状态 - NEW_MESSAGE = auto() # 有新消息 - COLD_CHAT = auto() # 冷场状态 - ACTIVE_CHAT = auto() # 活跃状态 - BOT_SPEAKING = auto() # 机器人正在说话 - USER_SPEAKING = auto() # 用户正在说话 - SILENT = auto() # 沉默状态 - ERROR = auto() # 错误状态 + + NORMAL = auto() # 正常状态 + NEW_MESSAGE = auto() # 有新消息 + COLD_CHAT = auto() # 冷场状态 + ACTIVE_CHAT = auto() # 活跃状态 + BOT_SPEAKING = auto() # 机器人正在说话 + USER_SPEAKING = auto() # 用户正在说话 + SILENT = auto() # 沉默状态 + ERROR = auto() # 错误状态 + class NotificationType(Enum): """通知类型枚举""" - NEW_MESSAGE = auto() # 新消息通知 - COLD_CHAT = auto() # 冷场通知 - ACTIVE_CHAT = auto() # 活跃通知 - BOT_SPEAKING = auto() # 机器人说话通知 - USER_SPEAKING = auto() # 用户说话通知 - MESSAGE_DELETED = auto() # 消息删除通知 - USER_JOINED = auto() # 用户加入通知 - USER_LEFT = auto() # 用户离开通知 - ERROR = auto() # 错误通知 + + NEW_MESSAGE = auto() # 新消息通知 + COLD_CHAT = auto() # 冷场通知 + ACTIVE_CHAT = auto() # 活跃通知 + BOT_SPEAKING = auto() # 机器人说话通知 + USER_SPEAKING = auto() # 用户说话通知 + MESSAGE_DELETED = auto() # 消息删除通知 + USER_JOINED = auto() # 用户加入通知 + USER_LEFT = auto() # 用户离开通知 + ERROR = auto() # 错误通知 + @dataclass class ChatStateInfo: """聊天状态信息""" + state: ChatState last_message_time: Optional[float] = None last_message_content: Optional[str] = None @@ -38,53 +44,55 @@ class ChatStateInfo: cold_duration: float = 0.0 # 冷场持续时间(秒) active_duration: float = 0.0 # 活跃持续时间(秒) + @dataclass class Notification: """通知基类""" + type: NotificationType timestamp: float - sender: str # 发送者标识 - target: str # 接收者标识 + sender: str # 发送者标识 + target: str # 接收者标识 data: Dict[str, Any] - + def to_dict(self) -> Dict[str, Any]: """转换为字典格式""" - return { - "type": self.type.name, - "timestamp": self.timestamp, - "data": self.data - } + return {"type": self.type.name, "timestamp": self.timestamp, "data": self.data} + @dataclass class StateNotification(Notification): """持续状态通知""" + is_active: bool = True - + def to_dict(self) -> Dict[str, Any]: base_dict = super().to_dict() base_dict["is_active"] = self.is_active return base_dict + class NotificationHandler(ABC): """通知处理器接口""" - + @abstractmethod async def handle_notification(self, notification: Notification): """处理通知""" pass + class NotificationManager: """通知管理器""" - + def __init__(self): # 按接收者和通知类型存储处理器 self._handlers: Dict[str, Dict[NotificationType, List[NotificationHandler]]] = {} self._active_states: Set[NotificationType] = set() self._notification_history: List[Notification] = [] - + def register_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): """注册通知处理器 - + Args: target: 接收者标识(例如:"pfc") notification_type: 要处理的通知类型 @@ -95,10 +103,10 @@ class NotificationManager: if notification_type not in self._handlers[target]: self._handlers[target][notification_type] = [] self._handlers[target][notification_type].append(handler) - + def unregister_handler(self, target: str, notification_type: NotificationType, handler: NotificationHandler): """注销通知处理器 - + Args: target: 接收者标识 notification_type: 通知类型 @@ -114,56 +122,56 @@ class NotificationManager: # 如果该目标没有任何处理器,删除该目标 if not self._handlers[target]: del self._handlers[target] - + async def send_notification(self, notification: Notification): """发送通知""" self._notification_history.append(notification) - + # 如果是状态通知,更新活跃状态 if isinstance(notification, StateNotification): if notification.is_active: self._active_states.add(notification.type) else: self._active_states.discard(notification.type) - + # 调用目标接收者的处理器 target = notification.target if target in self._handlers: handlers = self._handlers[target].get(notification.type, []) for handler in handlers: await handler.handle_notification(notification) - + def get_active_states(self) -> Set[NotificationType]: """获取当前活跃的状态""" return self._active_states.copy() - + def is_state_active(self, state_type: NotificationType) -> bool: """检查特定状态是否活跃""" return state_type in self._active_states - - def get_notification_history(self, - sender: Optional[str] = None, - target: Optional[str] = None, - limit: Optional[int] = None) -> List[Notification]: + + def get_notification_history( + self, sender: Optional[str] = None, target: Optional[str] = None, limit: Optional[int] = None + ) -> List[Notification]: """获取通知历史 - + Args: sender: 过滤特定发送者的通知 target: 过滤特定接收者的通知 limit: 限制返回数量 """ history = self._notification_history - + if sender: history = [n for n in history if n.sender == sender] if target: history = [n for n in history if n.target == target] - + if limit is not None: history = history[-limit:] - + return history + # 一些常用的通知创建函数 def create_new_message_notification(sender: str, target: str, message: Dict[str, Any]) -> Notification: """创建新消息通知""" @@ -176,10 +184,11 @@ def create_new_message_notification(sender: str, target: str, message: Dict[str, "message_id": message.get("message_id"), "content": message.get("content"), "sender": message.get("sender"), - "time": message.get("time") - } + "time": message.get("time"), + }, ) + def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> StateNotification: """创建冷场状态通知""" return StateNotification( @@ -188,9 +197,10 @@ def create_cold_chat_notification(sender: str, target: str, is_cold: bool) -> St sender=sender, target=target, data={"is_cold": is_cold}, - is_active=is_cold + is_active=is_cold, ) + def create_active_chat_notification(sender: str, target: str, is_active: bool) -> StateNotification: """创建活跃状态通知""" return StateNotification( @@ -199,69 +209,70 @@ def create_active_chat_notification(sender: str, target: str, is_active: bool) - sender=sender, target=target, data={"is_active": is_active}, - is_active=is_active + is_active=is_active, ) + class ChatStateManager: """聊天状态管理器""" - + def __init__(self): self.current_state = ChatState.NORMAL self.state_info = ChatStateInfo(state=ChatState.NORMAL) self.state_history: list[ChatStateInfo] = [] - + def update_state(self, new_state: ChatState, **kwargs): """更新聊天状态 - + Args: new_state: 新的状态 **kwargs: 其他状态信息 """ self.current_state = new_state self.state_info.state = new_state - + # 更新其他状态信息 for key, value in kwargs.items(): if hasattr(self.state_info, key): setattr(self.state_info, key, value) - + # 记录状态历史 self.state_history.append(self.state_info) - + def get_current_state_info(self) -> ChatStateInfo: """获取当前状态信息""" return self.state_info - + def get_state_history(self) -> list[ChatStateInfo]: """获取状态历史""" return self.state_history - + def is_cold_chat(self, threshold: float = 60.0) -> bool: """判断是否处于冷场状态 - + Args: threshold: 冷场阈值(秒) - + Returns: bool: 是否冷场 """ if not self.state_info.last_message_time: return True - + current_time = datetime.now().timestamp() return (current_time - self.state_info.last_message_time) > threshold - + def is_active_chat(self, threshold: float = 5.0) -> bool: """判断是否处于活跃状态 - + Args: threshold: 活跃阈值(秒) - + Returns: bool: 是否活跃 """ if not self.state_info.last_message_time: return False - + current_time = datetime.now().timestamp() - return (current_time - self.state_info.last_message_time) <= threshold \ No newline at end of file + return (current_time - self.state_info.last_message_time) <= threshold diff --git a/src/plugins/PFC/conversation.py b/src/plugins/PFC/conversation.py index dda380491..40a729671 100644 --- a/src/plugins/PFC/conversation.py +++ b/src/plugins/PFC/conversation.py @@ -20,23 +20,23 @@ logger = get_module_logger("pfc_conversation") class Conversation: """对话类,负责管理单个对话的状态和行为""" - + def __init__(self, stream_id: str): """初始化对话实例 - + Args: stream_id: 聊天流ID """ self.stream_id = stream_id self.state = ConversationState.INIT self.should_continue = False - + # 回复相关 self.generated_reply = "" - + async def _initialize(self): """初始化实例,注册所有组件""" - + try: self.action_planner = ActionPlanner(self.stream_id) self.goal_analyzer = GoalAnalyzer(self.stream_id) @@ -44,37 +44,35 @@ class Conversation: self.knowledge_fetcher = KnowledgeFetcher() self.waiter = Waiter(self.stream_id) self.direct_sender = DirectMessageSender() - + # 获取聊天流信息 self.chat_stream = chat_manager.get_stream(self.stream_id) - + self.stop_action_planner = False except Exception as e: logger.error(f"初始化对话实例:注册运行组件失败: {e}") logger.error(traceback.format_exc()) raise - - + try: - #决策所需要的信息,包括自身自信和观察信息两部分 - #注册观察器和观测信息 + # 决策所需要的信息,包括自身自信和观察信息两部分 + # 注册观察器和观测信息 self.chat_observer = ChatObserver.get_instance(self.stream_id) self.chat_observer.start() self.observation_info = ObservationInfo() self.observation_info.bind_to_chat_observer(self.stream_id) - - #对话信息 + + # 对话信息 self.conversation_info = ConversationInfo() except Exception as e: logger.error(f"初始化对话实例:注册信息组件失败: {e}") logger.error(traceback.format_exc()) raise - + # 组件准备完成,启动该论对话 self.should_continue = True asyncio.create_task(self.start()) - - + async def start(self): """开始对话流程""" try: @@ -83,17 +81,13 @@ class Conversation: except Exception as e: logger.error(f"启动对话系统失败: {e}") raise - - + async def _plan_and_action_loop(self): """思考步,PFC核心循环模块""" # 获取最近的消息历史 while self.should_continue: # 使用决策信息来辅助行动规划 - action, reason = await self.action_planner.plan( - self.observation_info, - self.conversation_info - ) + action, reason = await self.action_planner.plan(self.observation_info, self.conversation_info) if self._check_new_messages_after_planning(): continue @@ -107,93 +101,90 @@ class Conversation: # 如果需要,可以在这里添加逻辑来根据新消息重新决定行动 return True return False - - + def _convert_to_message(self, msg_dict: Dict[str, Any]) -> Message: """将消息字典转换为Message对象""" try: chat_info = msg_dict.get("chat_info", {}) chat_stream = ChatStream.from_dict(chat_info) user_info = UserInfo.from_dict(msg_dict.get("user_info", {})) - + return Message( message_id=msg_dict["message_id"], chat_stream=chat_stream, time=msg_dict["time"], user_info=user_info, processed_plain_text=msg_dict.get("processed_plain_text", ""), - detailed_plain_text=msg_dict.get("detailed_plain_text", "") + detailed_plain_text=msg_dict.get("detailed_plain_text", ""), ) except Exception as e: logger.warning(f"转换消息时出错: {e}") raise - async def _handle_action(self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo): + async def _handle_action( + self, action: str, reason: str, observation_info: ObservationInfo, conversation_info: ConversationInfo + ): """处理规划的行动""" logger.info(f"执行行动: {action}, 原因: {reason}") - + # 记录action历史,先设置为stop,完成后再设置为done - conversation_info.done_action.append({ - "action": action, - "reason": reason, - "status": "start", - "time": datetime.datetime.now().strftime("%H:%M:%S") - }) - - + conversation_info.done_action.append( + { + "action": action, + "reason": reason, + "status": "start", + "time": datetime.datetime.now().strftime("%H:%M:%S"), + } + ) + if action == "direct_reply": self.state = ConversationState.GENERATING - self.generated_reply = await self.reply_generator.generate( - observation_info, - conversation_info - ) - + self.generated_reply = await self.reply_generator.generate(observation_info, conversation_info) + # # 检查回复是否合适 # is_suitable, reason, need_replan = await self.reply_generator.check_reply( # self.generated_reply, # self.current_goal # ) - + if self._check_new_messages_after_planning(): return None - + await self._send_reply() - - conversation_info.done_action.append({ - "action": action, - "reason": reason, - "status": "done", - "time": datetime.datetime.now().strftime("%H:%M:%S") - }) - + + conversation_info.done_action.append( + { + "action": action, + "reason": reason, + "status": "done", + "time": datetime.datetime.now().strftime("%H:%M:%S"), + } + ) + elif action == "fetch_knowledge": self.state = ConversationState.FETCHING knowledge = "TODO:知识" topic = "TODO:关键词" - + logger.info(f"假装获取到知识{knowledge},关键词是: {topic}") - + if knowledge: if topic not in self.conversation_info.knowledge_list: - self.conversation_info.knowledge_list.append({ - "topic": topic, - "knowledge": knowledge - }) + self.conversation_info.knowledge_list.append({"topic": topic, "knowledge": knowledge}) else: self.conversation_info.knowledge_list[topic] += knowledge - + elif action == "rethink_goal": self.state = ConversationState.RETHINKING await self.goal_analyzer.analyze_goal(conversation_info, observation_info) - elif action == "listening": self.state = ConversationState.LISTENING logger.info("倾听对方发言...") if await self.waiter.wait(): # 如果返回True表示超时 await self._send_timeout_message() await self._stop_conversation() - + else: # wait self.state = ConversationState.WAITING logger.info("等待更多信息...") @@ -207,12 +198,10 @@ class Conversation: messages = self.chat_observer.get_cached_messages(limit=1) if not messages: return - + latest_message = self._convert_to_message(messages[0]) await self.direct_sender.send_message( - chat_stream=self.chat_stream, - content="TODO:超时消息", - reply_to_message=latest_message + chat_stream=self.chat_stream, content="TODO:超时消息", reply_to_message=latest_message ) except Exception as e: logger.error(f"发送超时消息失败: {str(e)}") @@ -222,24 +211,22 @@ class Conversation: if not self.generated_reply: logger.warning("没有生成回复") return - + messages = self.chat_observer.get_cached_messages(limit=1) if not messages: logger.warning("没有最近的消息可以回复") return - + latest_message = self._convert_to_message(messages[0]) try: await self.direct_sender.send_message( - chat_stream=self.chat_stream, - content=self.generated_reply, - reply_to_message=latest_message + chat_stream=self.chat_stream, content=self.generated_reply, reply_to_message=latest_message ) self.chat_observer.trigger_update() # 触发立即更新 if not await self.chat_observer.wait_for_update(): logger.warning("等待消息更新超时") - + self.state = ConversationState.ANALYZING except Exception as e: logger.error(f"发送消息失败: {str(e)}") - self.state = ConversationState.ANALYZING \ No newline at end of file + self.state = ConversationState.ANALYZING diff --git a/src/plugins/PFC/conversation_info.py b/src/plugins/PFC/conversation_info.py index 5b8262a16..cae9f0b34 100644 --- a/src/plugins/PFC/conversation_info.py +++ b/src/plugins/PFC/conversation_info.py @@ -1,8 +1,6 @@ - - class ConversationInfo: def __init__(self): self.done_action = [] self.goal_list = [] self.knowledge_list = [] - self.memory_list = [] \ No newline at end of file + self.memory_list = [] diff --git a/src/plugins/PFC/message_sender.py b/src/plugins/PFC/message_sender.py index 6df1e7ded..76b07945f 100644 --- a/src/plugins/PFC/message_sender.py +++ b/src/plugins/PFC/message_sender.py @@ -7,12 +7,13 @@ from src.plugins.chat.message import MessageSending logger = get_module_logger("message_sender") + class DirectMessageSender: """直接消息发送器""" - + def __init__(self): pass - + async def send_message( self, chat_stream: ChatStream, @@ -20,7 +21,7 @@ class DirectMessageSender: reply_to_message: Optional[Message] = None, ) -> None: """发送消息到聊天流 - + Args: chat_stream: 聊天流 content: 消息内容 @@ -29,21 +30,18 @@ class DirectMessageSender: try: # 创建消息内容 segments = [Seg(type="text", data={"text": content})] - + # 检查是否需要引用回复 if reply_to_message: reply_id = reply_to_message.message_id - message_sending = MessageSending( - segments=segments, - reply_to_id=reply_id - ) + message_sending = MessageSending(segments=segments, reply_to_id=reply_id) else: message_sending = MessageSending(segments=segments) - + # 发送消息 await chat_stream.send_message(message_sending) logger.info(f"消息已发送: {content}") - + except Exception as e: logger.error(f"发送消息失败: {str(e)}") - raise \ No newline at end of file + raise diff --git a/src/plugins/PFC/message_storage.py b/src/plugins/PFC/message_storage.py index 3c7cab8b3..88f409641 100644 --- a/src/plugins/PFC/message_storage.py +++ b/src/plugins/PFC/message_storage.py @@ -2,133 +2,126 @@ from abc import ABC, abstractmethod from typing import List, Dict, Any, Optional from src.common.database import db + class MessageStorage(ABC): """消息存储接口""" - + @abstractmethod async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: """获取指定消息ID之后的所有消息 - + Args: chat_id: 聊天ID message_id: 消息ID,如果为None则获取所有消息 - + Returns: List[Dict[str, Any]]: 消息列表 """ pass - + @abstractmethod async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: """获取指定时间点之前的消息 - + Args: chat_id: 聊天ID time_point: 时间戳 limit: 最大消息数量 - + Returns: List[Dict[str, Any]]: 消息列表 """ pass - + @abstractmethod async def has_new_messages(self, chat_id: str, after_time: float) -> bool: """检查是否有新消息 - + Args: chat_id: 聊天ID after_time: 时间戳 - + Returns: bool: 是否有新消息 """ pass + class MongoDBMessageStorage(MessageStorage): """MongoDB消息存储实现""" - + def __init__(self): self.db = db - + async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: query = {"chat_id": chat_id} - + if message_id: # 获取ID大于message_id的消息 last_message = self.db.messages.find_one({"message_id": message_id}) if last_message: query["time"] = {"$gt": last_message["time"]} - - return list( - self.db.messages.find(query).sort("time", 1) - ) - + + return list(self.db.messages.find(query).sort("time", 1)) + async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: - query = { - "chat_id": chat_id, - "time": {"$lt": time_point} - } - - messages = list( - self.db.messages.find(query).sort("time", -1).limit(limit) - ) - + query = {"chat_id": chat_id, "time": {"$lt": time_point}} + + messages = list(self.db.messages.find(query).sort("time", -1).limit(limit)) + # 将消息按时间正序排列 messages.reverse() return messages - + async def has_new_messages(self, chat_id: str, after_time: float) -> bool: - query = { - "chat_id": chat_id, - "time": {"$gt": after_time} - } - + query = {"chat_id": chat_id, "time": {"$gt": after_time}} + return self.db.messages.find_one(query) is not None + # # 创建一个内存消息存储实现,用于测试 # class InMemoryMessageStorage(MessageStorage): # """内存消息存储实现,主要用于测试""" - + # def __init__(self): # self.messages: Dict[str, List[Dict[str, Any]]] = {} - + # async def get_messages_after(self, chat_id: str, message_id: Optional[str] = None) -> List[Dict[str, Any]]: # if chat_id not in self.messages: # return [] - + # messages = self.messages[chat_id] # if not message_id: # return messages - + # # 找到message_id的索引 # try: # index = next(i for i, m in enumerate(messages) if m["message_id"] == message_id) # return messages[index + 1:] # except StopIteration: # return [] - + # async def get_messages_before(self, chat_id: str, time_point: float, limit: int = 5) -> List[Dict[str, Any]]: # if chat_id not in self.messages: # return [] - + # messages = [ # m for m in self.messages[chat_id] # if m["time"] < time_point # ] - + # return messages[-limit:] - + # async def has_new_messages(self, chat_id: str, after_time: float) -> bool: # if chat_id not in self.messages: # return False - + # return any(m["time"] > after_time for m in self.messages[chat_id]) - + # # 测试辅助方法 # def add_message(self, chat_id: str, message: Dict[str, Any]): # """添加测试消息""" # if chat_id not in self.messages: # self.messages[chat_id] = [] # self.messages[chat_id].append(message) -# self.messages[chat_id].sort(key=lambda m: m["time"]) \ No newline at end of file +# self.messages[chat_id].sort(key=lambda m: m["time"]) diff --git a/src/plugins/PFC/notification_handler.py b/src/plugins/PFC/notification_handler.py index 38c0d0dee..1131d18bf 100644 --- a/src/plugins/PFC/notification_handler.py +++ b/src/plugins/PFC/notification_handler.py @@ -7,25 +7,26 @@ if TYPE_CHECKING: logger = get_module_logger("notification_handler") + class PFCNotificationHandler(NotificationHandler): """PFC通知处理器""" - - def __init__(self, conversation: 'Conversation'): + + def __init__(self, conversation: "Conversation"): """初始化PFC通知处理器 - + Args: conversation: 对话实例 """ self.conversation = conversation - + async def handle_notification(self, notification: Notification): """处理通知 - + Args: notification: 通知对象 """ logger.debug(f"收到通知: {notification.type.name}, 数据: {notification.data}") - + # 根据通知类型执行不同的处理 if notification.type == NotificationType.NEW_MESSAGE: # 新消息通知 @@ -38,34 +39,33 @@ class PFCNotificationHandler(NotificationHandler): await self._handle_command(notification) else: logger.warning(f"未知的通知类型: {notification.type.name}") - + async def _handle_new_message(self, notification: Notification): """处理新消息通知 - + Args: notification: 通知对象 """ - + # 更新决策信息 observation_info = self.conversation.observation_info observation_info.last_message_time = notification.data.get("time", 0) observation_info.add_unprocessed_message(notification.data) - + # 手动触发观察器更新 self.conversation.chat_observer.trigger_update() - + async def _handle_cold_chat(self, notification: Notification): """处理冷聊天通知 - + Args: notification: 通知对象 """ # 获取冷聊天信息 cold_duration = notification.data.get("duration", 0) - + # 更新决策信息 observation_info = self.conversation.observation_info observation_info.conversation_cold_duration = cold_duration - + logger.info(f"对话已冷: {cold_duration}秒") - \ No newline at end of file diff --git a/src/plugins/PFC/observation_info.py b/src/plugins/PFC/observation_info.py index 2967f10e3..d0eee2236 100644 --- a/src/plugins/PFC/observation_info.py +++ b/src/plugins/PFC/observation_info.py @@ -1,5 +1,5 @@ -#Programmable Friendly Conversationalist -#Prefrontal cortex +# Programmable Friendly Conversationalist +# Prefrontal cortex from typing import List, Optional, Dict, Any, Set from ..message.message_base import UserInfo import time @@ -10,26 +10,27 @@ from .chat_states import NotificationHandler logger = get_module_logger("observation_info") + class ObservationInfoHandler(NotificationHandler): """ObservationInfo的通知处理器""" - - def __init__(self, observation_info: 'ObservationInfo'): + + def __init__(self, observation_info: "ObservationInfo"): """初始化处理器 - + Args: observation_info: 要更新的ObservationInfo实例 """ self.observation_info = observation_info - + async def handle_notification(self, notification: Dict[str, Any]): """处理通知 - + Args: notification: 通知数据 """ notification_type = notification.get("type") data = notification.get("data", {}) - + if notification_type == "NEW_MESSAGE": # 处理新消息通知 logger.debug(f"收到新消息通知data: {data}") @@ -37,62 +38,62 @@ class ObservationInfoHandler(NotificationHandler): self.observation_info.update_from_message(message) # self.observation_info.has_unread_messages = True # self.observation_info.new_unread_message.append(message.get("processed_plain_text", "")) - + elif notification_type == "COLD_CHAT": # 处理冷场通知 is_cold = data.get("is_cold", False) self.observation_info.update_cold_chat_status(is_cold, time.time()) - + elif notification_type == "ACTIVE_CHAT": # 处理活跃通知 is_active = data.get("is_active", False) self.observation_info.is_cold = not is_active - + elif notification_type == "BOT_SPEAKING": # 处理机器人说话通知 self.observation_info.is_typing = False self.observation_info.last_bot_speak_time = time.time() - + elif notification_type == "USER_SPEAKING": # 处理用户说话通知 self.observation_info.is_typing = False self.observation_info.last_user_speak_time = time.time() - + elif notification_type == "MESSAGE_DELETED": # 处理消息删除通知 message_id = data.get("message_id") self.observation_info.unprocessed_messages = [ - msg for msg in self.observation_info.unprocessed_messages - if msg.get("message_id") != message_id + msg for msg in self.observation_info.unprocessed_messages if msg.get("message_id") != message_id ] - + elif notification_type == "USER_JOINED": # 处理用户加入通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.add(user_id) - + elif notification_type == "USER_LEFT": # 处理用户离开通知 user_id = data.get("user_id") if user_id: self.observation_info.active_users.discard(user_id) - + elif notification_type == "ERROR": # 处理错误通知 error_msg = data.get("error", "") logger.error(f"收到错误通知: {error_msg}") + @dataclass class ObservationInfo: """决策信息类,用于收集和管理来自chat_observer的通知信息""" - - #data_list + + # data_list chat_history: List[str] = field(default_factory=list) unprocessed_messages: List[Dict[str, Any]] = field(default_factory=list) active_users: Set[str] = field(default_factory=set) - - #data + + # data last_bot_speak_time: Optional[float] = None last_user_speak_time: Optional[float] = None last_message_time: Optional[float] = None @@ -101,78 +102,70 @@ class ObservationInfo: bot_id: Optional[str] = None new_messages_count: int = 0 cold_chat_duration: float = 0.0 - - #state + + # state is_typing: bool = False has_unread_messages: bool = False is_cold_chat: bool = False changed: bool = False - + # #spec # meta_plan_trigger: bool = False - + def __post_init__(self): """初始化后创建handler""" self.chat_observer = None self.handler = ObservationInfoHandler(self) - + def bind_to_chat_observer(self, stream_id: str): """绑定到指定的chat_observer - + Args: stream_id: 聊天流ID """ self.chat_observer = ChatObserver.get_instance(stream_id) self.chat_observer.notification_manager.register_handler( - target="observation_info", - notification_type="NEW_MESSAGE", - handler=self.handler + target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler ) self.chat_observer.notification_manager.register_handler( - target="observation_info", - notification_type="COLD_CHAT", - handler=self.handler + target="observation_info", notification_type="COLD_CHAT", handler=self.handler ) - + def unbind_from_chat_observer(self): """解除与chat_observer的绑定""" if self.chat_observer: self.chat_observer.notification_manager.unregister_handler( - target="observation_info", - notification_type="NEW_MESSAGE", - handler=self.handler + target="observation_info", notification_type="NEW_MESSAGE", handler=self.handler ) self.chat_observer.notification_manager.unregister_handler( - target="observation_info", - notification_type="COLD_CHAT", - handler=self.handler + target="observation_info", notification_type="COLD_CHAT", handler=self.handler ) self.chat_observer = None - + def update_from_message(self, message: Dict[str, Any]): """从消息更新信息 - + Args: message: 消息数据 """ logger.debug(f"更新信息from_message: {message}") self.last_message_time = message["time"] self.last_message_content = message.get("processed_plain_text", "") - + user_info = UserInfo.from_dict(message.get("user_info", {})) self.last_message_sender = user_info.user_id - + if user_info.user_id == self.bot_id: self.last_bot_speak_time = message["time"] else: self.last_user_speak_time = message["time"] self.active_users.add(user_info.user_id) - + self.new_messages_count += 1 self.unprocessed_messages.append(message) - + self.update_changed() - + def update_changed(self): """更新changed状态""" self.changed = True @@ -180,7 +173,7 @@ class ObservationInfo: def update_cold_chat_status(self, is_cold: bool, current_time: float): """更新冷场状态 - + Args: is_cold: 是否冷场 current_time: 当前时间 @@ -188,37 +181,37 @@ class ObservationInfo: self.is_cold_chat = is_cold if is_cold and self.last_message_time: self.cold_chat_duration = current_time - self.last_message_time - + def get_active_duration(self) -> float: """获取当前活跃时长 - + Returns: float: 最后一条消息到现在的时长(秒) """ if not self.last_message_time: return 0.0 return time.time() - self.last_message_time - + def get_user_response_time(self) -> Optional[float]: """获取用户响应时间 - + Returns: Optional[float]: 用户最后发言到现在的时长(秒),如果没有用户发言则返回None """ if not self.last_user_speak_time: return None return time.time() - self.last_user_speak_time - + def get_bot_response_time(self) -> Optional[float]: """获取机器人响应时间 - + Returns: Optional[float]: 机器人最后发言到现在的时长(秒),如果没有机器人发言则返回None """ if not self.last_bot_speak_time: return None return time.time() - self.last_bot_speak_time - + def clear_unprocessed_messages(self): """清空未处理消息列表""" # 将未处理消息添加到历史记录中 @@ -229,10 +222,10 @@ class ObservationInfo: self.has_unread_messages = False self.unprocessed_messages.clear() self.new_messages_count = 0 - + def add_unprocessed_message(self, message: Dict[str, Any]): """添加未处理的消息 - + Args: message: 消息数据 """ @@ -241,6 +234,6 @@ class ObservationInfo: if message_id and not any(m.get("message_id") == message_id for m in self.unprocessed_messages): self.unprocessed_messages.append(message) self.new_messages_count += 1 - + # 同时更新其他消息相关信息 - self.update_from_message(message) \ No newline at end of file + self.update_from_message(message) diff --git a/src/plugins/PFC/pfc.py b/src/plugins/PFC/pfc.py index 62b28acb4..3436dce8f 100644 --- a/src/plugins/PFC/pfc.py +++ b/src/plugins/PFC/pfc.py @@ -49,43 +49,40 @@ class GoalAnalyzer: Args: conversation_info: 对话信息 observation_info: 观察信息 - + Returns: Tuple[str, str, str]: (目标, 方法, 原因) """ - #构建对话目标 + # 构建对话目标 goal_list = conversation_info.goal_list goal_text = "" for goal, reason in goal_list: goal_text += f"目标:{goal};" goal_text += f"原因:{reason}\n" - - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - + if observation_info.new_messages_count > 0: new_messages_list = observation_info.unprocessed_messages - + chat_history_text += f"有{observation_info.new_messages_count}条新消息:\n" for msg in new_messages_list: chat_history_text += f"{msg}\n" - + observation_info.clear_unprocessed_messages() - - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + # 构建action历史文本 action_history_list = conversation_info.done_action action_history_text = "你之前做的事情是:" for action in action_history_list: action_history_text += f"{action}\n" - - + prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请分析以下聊天记录,并根据你的性格特征确定多个明确的对话目标。 这些目标应该反映出对话的不同方面和意图。 @@ -119,20 +116,15 @@ class GoalAnalyzer: logger.debug(f"发送到LLM的提示词: {prompt}") content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 使用简化函数提取JSON内容 success, result = get_items_from_json( - content, - "goal", "reasoning", - required_types={"goal": str, "reasoning": str} + content, "goal", "reasoning", required_types={"goal": str, "reasoning": str} ) - #TODO - - + # TODO + conversation_info.goal_list.append(result) - - async def _update_goals(self, new_goal: str, method: str, reasoning: str): """更新目标列表 @@ -229,24 +221,26 @@ class GoalAnalyzer: try: content, _ = await self.llm.generate_response_async(prompt) logger.debug(f"LLM原始返回内容: {content}") - + # 尝试解析JSON success, result = get_items_from_json( content, - "goal_achieved", "stop_conversation", "reason", - required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str} + "goal_achieved", + "stop_conversation", + "reason", + required_types={"goal_achieved": bool, "stop_conversation": bool, "reason": str}, ) if not success: logger.error("无法解析对话分析结果JSON") return False, False, "解析结果失败" - + goal_achieved = result["goal_achieved"] stop_conversation = result["stop_conversation"] reason = result["reason"] - + return goal_achieved, stop_conversation, reason - + except Exception as e: logger.error(f"分析对话状态时出错: {str(e)}") return False, False, f"分析出错: {str(e)}" @@ -269,23 +263,22 @@ class Waiter: # 使用当前时间作为等待开始时间 wait_start_time = time.time() self.chat_observer.waiting_start_time = wait_start_time # 设置等待开始时间 - + while True: # 检查是否有新消息 if self.chat_observer.new_message_after(wait_start_time): logger.info("等待结束,收到新消息") return False - + # 检查是否超时 if time.time() - wait_start_time > 300: logger.info("等待超过300秒,结束对话") return True - + await asyncio.sleep(1) logger.info("等待中...") - class DirectMessageSender: """直接发送消息到平台的发送器""" diff --git a/src/plugins/PFC/pfc_manager.py b/src/plugins/PFC/pfc_manager.py index 9a36bef19..5be15a100 100644 --- a/src/plugins/PFC/pfc_manager.py +++ b/src/plugins/PFC/pfc_manager.py @@ -5,33 +5,34 @@ import traceback logger = get_module_logger("pfc_manager") + class PFCManager: """PFC对话管理器,负责管理所有对话实例""" - + # 单例模式 _instance = None - + # 会话实例管理 _instances: Dict[str, Conversation] = {} _initializing: Dict[str, bool] = {} - + @classmethod - def get_instance(cls) -> 'PFCManager': + def get_instance(cls) -> "PFCManager": """获取管理器单例 - + Returns: PFCManager: 管理器实例 """ if cls._instance is None: cls._instance = PFCManager() return cls._instance - + async def get_or_create_conversation(self, stream_id: str) -> Optional[Conversation]: """获取或创建对话实例 - + Args: stream_id: 聊天流ID - + Returns: Optional[Conversation]: 对话实例,创建失败则返回None """ @@ -39,11 +40,11 @@ class PFCManager: if stream_id in self._initializing and self._initializing[stream_id]: logger.debug(f"会话实例正在初始化中: {stream_id}") return None - + if stream_id in self._instances: logger.debug(f"使用现有会话实例: {stream_id}") return self._instances[stream_id] - + try: # 创建新实例 logger.info(f"创建新的对话实例: {stream_id}") @@ -51,47 +52,45 @@ class PFCManager: # 创建实例 conversation_instance = Conversation(stream_id) self._instances[stream_id] = conversation_instance - + # 启动实例初始化 await self._initialize_conversation(conversation_instance) except Exception as e: logger.error(f"创建会话实例失败: {stream_id}, 错误: {e}") return None - + return conversation_instance - async def _initialize_conversation(self, conversation: Conversation): """初始化会话实例 - + Args: conversation: 要初始化的会话实例 """ stream_id = conversation.stream_id - + try: logger.info(f"开始初始化会话实例: {stream_id}") # 启动初始化流程 await conversation._initialize() - + # 标记初始化完成 self._initializing[stream_id] = False - + logger.info(f"会话实例 {stream_id} 初始化完成") - + except Exception as e: logger.error(f"管理器初始化会话实例失败: {stream_id}, 错误: {e}") logger.error(traceback.format_exc()) # 清理失败的初始化 - async def get_conversation(self, stream_id: str) -> Optional[Conversation]: """获取已存在的会话实例 - + Args: stream_id: 聊天流ID - + Returns: Optional[Conversation]: 会话实例,不存在则返回None """ - return self._instances.get(stream_id) \ No newline at end of file + return self._instances.get(stream_id) diff --git a/src/plugins/PFC/pfc_types.py b/src/plugins/PFC/pfc_types.py index d7ad8e91f..7391c448d 100644 --- a/src/plugins/PFC/pfc_types.py +++ b/src/plugins/PFC/pfc_types.py @@ -4,6 +4,7 @@ from typing import Literal class ConversationState(Enum): """对话状态""" + INIT = "初始化" RETHINKING = "重新思考" ANALYZING = "分析历史" @@ -18,4 +19,4 @@ class ConversationState(Enum): JUDGING = "判断" -ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] \ No newline at end of file +ActionType = Literal["direct_reply", "fetch_knowledge", "wait"] diff --git a/src/plugins/PFC/reply_generator.py b/src/plugins/PFC/reply_generator.py index beec9dd3e..00ac7c413 100644 --- a/src/plugins/PFC/reply_generator.py +++ b/src/plugins/PFC/reply_generator.py @@ -13,33 +13,26 @@ logger = get_module_logger("reply_generator") class ReplyGenerator: """回复生成器""" - + def __init__(self, stream_id: str): self.llm = LLM_request( - model=global_config.llm_normal, - temperature=0.7, - max_tokens=300, - request_type="reply_generation" + model=global_config.llm_normal, temperature=0.7, max_tokens=300, request_type="reply_generation" ) - self.personality_info = Individuality.get_instance().get_prompt(type = "personality", x_person = 2, level = 2) + self.personality_info = Individuality.get_instance().get_prompt(type="personality", x_person=2, level=2) self.name = global_config.BOT_NICKNAME self.chat_observer = ChatObserver.get_instance(stream_id) self.reply_checker = ReplyChecker(stream_id) - - async def generate( - self, - observation_info: ObservationInfo, - conversation_info: ConversationInfo - ) -> str: + + async def generate(self, observation_info: ObservationInfo, conversation_info: ConversationInfo) -> str: """生成回复 - + Args: goal: 对话目标 chat_history: 聊天历史 knowledge_cache: 知识缓存 previous_reply: 上一次生成的回复(如果有) retry_count: 当前重试次数 - + Returns: str: 生成的回复 """ @@ -51,22 +44,21 @@ class ReplyGenerator: for goal, reason in goal_list: goal_text += f"目标:{goal};" goal_text += f"原因:{reason}\n" - + # 获取聊天历史记录 chat_history_list = observation_info.chat_history chat_history_text = "" for msg in chat_history_list: chat_history_text += f"{msg}\n" - - + # 整理知识缓存 knowledge_text = "" knowledge_list = conversation_info.knowledge_list for knowledge in knowledge_list: knowledge_text += f"知识:{knowledge}\n" - + personality_text = f"你的名字是{self.name},{self.personality_info}" - + prompt = f"""{personality_text}。现在你在参与一场QQ聊天,请根据以下信息生成回复: 当前对话目标:{goal_text} @@ -92,7 +84,7 @@ class ReplyGenerator: logger.info(f"生成的回复: {content}") # is_new = self.chat_observer.check() # logger.debug(f"再看一眼聊天记录,{'有' if is_new else '没有'}新消息") - + # 如果有新消息,重新生成回复 # if is_new: # logger.info("检测到新消息,重新生成回复") @@ -100,27 +92,22 @@ class ReplyGenerator: # goal, chat_history, knowledge_cache, # None, retry_count # ) - + return content - + except Exception as e: logger.error(f"生成回复时出错: {e}") return "抱歉,我现在有点混乱,让我重新思考一下..." - async def check_reply( - self, - reply: str, - goal: str, - retry_count: int = 0 - ) -> Tuple[bool, str, bool]: + async def check_reply(self, reply: str, goal: str, retry_count: int = 0) -> Tuple[bool, str, bool]: """检查回复是否合适 - + Args: reply: 生成的回复 goal: 对话目标 retry_count: 当前重试次数 - + Returns: Tuple[bool, str, bool]: (是否合适, 原因, 是否需要重新规划) """ - return await self.reply_checker.check(reply, goal, retry_count) \ No newline at end of file + return await self.reply_checker.check(reply, goal, retry_count) diff --git a/src/plugins/PFC/waiter.py b/src/plugins/PFC/waiter.py index 0e1bf59f3..66f98e9c3 100644 --- a/src/plugins/PFC/waiter.py +++ b/src/plugins/PFC/waiter.py @@ -3,43 +3,44 @@ from .chat_observer import ChatObserver logger = get_module_logger("waiter") + class Waiter: """等待器,用于等待对话流中的事件""" - + def __init__(self, stream_id: str): self.stream_id = stream_id self.chat_observer = ChatObserver.get_instance(stream_id) - + async def wait(self, timeout: float = 20.0) -> bool: """等待用户回复或超时 - + Args: timeout: 超时时间(秒) - + Returns: bool: 如果因为超时返回则为True,否则为False """ try: message_before = self.chat_observer.get_last_message() - + # 等待新消息 logger.debug(f"等待新消息,超时时间: {timeout}秒") - + is_timeout = await self.chat_observer.wait_for_update(timeout=timeout) if is_timeout: logger.debug("等待超时,没有收到新消息") return True - + # 检查是否是新消息 message_after = self.chat_observer.get_last_message() if message_before and message_after and message_before.get("message_id") == message_after.get("message_id"): # 如果消息ID相同,说明没有新消息 logger.debug("没有收到新消息") return True - + logger.debug("收到新消息") return False - + except Exception as e: logger.error(f"等待时出错: {str(e)}") - return True \ No newline at end of file + return True diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 42234da8e..43d329ff3 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -30,7 +30,7 @@ class ChatBot: self.think_flow_chat = ThinkFlowChat() self.reasoning_chat = ReasoningChat() self.only_process_chat = MessageProcessor() - + # 创建初始化PFC管理器的任务,会在_ensure_started时执行 self.pfc_manager = PFCManager.get_instance() @@ -38,7 +38,7 @@ class ChatBot: """确保所有任务已启动""" if not self._started: logger.info("确保ChatBot所有任务已启动") - + self._started = True async def _create_PFC_chat(self, message: MessageRecv): @@ -46,7 +46,6 @@ class ChatBot: chat_id = str(message.chat_stream.stream_id) if global_config.enable_pfc_chatting: - await self.pfc_manager.get_or_create_conversation(chat_id) except Exception as e: @@ -80,7 +79,7 @@ class ChatBot: try: # 确保所有任务已启动 await self._ensure_started() - + message = MessageRecv(message_data) groupinfo = message.message_info.group_info userinfo = message.message_info.user_info diff --git a/src/plugins/config/config.py b/src/plugins/config/config.py index eccb3bc0b..23e277498 100644 --- a/src/plugins/config/config.py +++ b/src/plugins/config/config.py @@ -24,7 +24,7 @@ config_config = LogConfig( # 配置主程序日志格式 logger = get_module_logger("config", config=config_config) -#考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 +# 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 is_test = True mai_version_main = "0.6.2" mai_version_fix = "snapshot-1" diff --git a/src/plugins/message/__init__.py b/src/plugins/message/__init__.py index bee5c5e58..286ef2310 100644 --- a/src/plugins/message/__init__.py +++ b/src/plugins/message/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.1.0" -from .api import BaseMessageAPI, global_api +from .api import global_api from .message_base import ( Seg, GroupInfo, @@ -14,7 +14,6 @@ from .message_base import ( ) __all__ = [ - "BaseMessageAPI", "Seg", "global_api", "GroupInfo", diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 2a6a2b6fc..0c3e3a5a1 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -1,7 +1,8 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect -from typing import Dict, Any, Callable, List, Set +from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase +from src.common.server import global_server import aiohttp import asyncio import uvicorn @@ -49,13 +50,22 @@ class MessageServer(BaseMessageHandler): _class_handlers: List[Callable] = [] # 类级别的消息处理器 - def __init__(self, host: str = "0.0.0.0", port: int = 18000, enable_token=False): + def __init__( + self, + host: str = "0.0.0.0", + port: int = 18000, + enable_token=False, + app: Optional[FastAPI] = None, + path: str = "/ws", + ): super().__init__() # 将类级别的处理器添加到实例处理器中 self.message_handlers.extend(self._class_handlers) - self.app = FastAPI() self.host = host self.port = port + self.path = path + self.app = app or FastAPI() + self.own_app = app is None # 标记是否使用自己创建的app self.active_websockets: Set[WebSocket] = set() self.platform_websockets: Dict[str, WebSocket] = {} # 平台到websocket的映射 self.valid_tokens: Set[str] = set() @@ -63,28 +73,6 @@ class MessageServer(BaseMessageHandler): self._setup_routes() self._running = False - @classmethod - def register_class_handler(cls, handler: Callable): - """注册类级别的消息处理器""" - if handler not in cls._class_handlers: - cls._class_handlers.append(handler) - - def register_message_handler(self, handler: Callable): - """注册实例级别的消息处理器""" - if handler not in self.message_handlers: - self.message_handlers.append(handler) - - async def verify_token(self, token: str) -> bool: - if not self.enable_token: - return True - return token in self.valid_tokens - - def add_valid_token(self, token: str): - self.valid_tokens.add(token) - - def remove_valid_token(self, token: str): - self.valid_tokens.discard(token) - def _setup_routes(self): @self.app.post("/api/message") async def handle_message(message: Dict[str, Any]): @@ -125,6 +113,90 @@ class MessageServer(BaseMessageHandler): finally: self._remove_websocket(websocket, platform) + @classmethod + def register_class_handler(cls, handler: Callable): + """注册类级别的消息处理器""" + if handler not in cls._class_handlers: + cls._class_handlers.append(handler) + + def register_message_handler(self, handler: Callable): + """注册实例级别的消息处理器""" + if handler not in self.message_handlers: + self.message_handlers.append(handler) + + async def verify_token(self, token: str) -> bool: + if not self.enable_token: + return True + return token in self.valid_tokens + + def add_valid_token(self, token: str): + self.valid_tokens.add(token) + + def remove_valid_token(self, token: str): + self.valid_tokens.discard(token) + + def run_sync(self): + """同步方式运行服务器""" + if not self.own_app: + raise RuntimeError("当使用外部FastAPI实例时,请使用该实例的运行方法") + uvicorn.run(self.app, host=self.host, port=self.port) + + async def run(self): + """异步方式运行服务器""" + self._running = True + try: + if self.own_app: + # 如果使用自己的 FastAPI 实例,运行 uvicorn 服务器 + config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") + self.server = uvicorn.Server(config) + await self.server.serve() + else: + # 如果使用外部 FastAPI 实例,保持运行状态以处理消息 + while self._running: + await asyncio.sleep(1) + except KeyboardInterrupt: + await self.stop() + raise + except Exception as e: + await self.stop() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.stop() + + async def start_server(self): + """启动服务器的异步方法""" + if not self._running: + self._running = True + await self.run() + + async def stop(self): + """停止服务器""" + # 清理platform映射 + self.platform_websockets.clear() + + # 取消所有后台任务 + for task in self.background_tasks: + task.cancel() + # 等待所有任务完成 + await asyncio.gather(*self.background_tasks, return_exceptions=True) + self.background_tasks.clear() + + # 关闭所有WebSocket连接 + for websocket in self.active_websockets: + await websocket.close() + self.active_websockets.clear() + + if hasattr(self, "server") and self.own_app: + self._running = False + # 正确关闭 uvicorn 服务器 + self.server.should_exit = True + await self.server.shutdown() + # 等待服务器完全停止 + if hasattr(self.server, "started") and self.server.started: + await self.server.main_loop() + # 清理处理程序 + self.message_handlers.clear() + def _remove_websocket(self, websocket: WebSocket, platform: str): """从所有集合中移除websocket""" if websocket in self.active_websockets: @@ -161,54 +233,6 @@ class MessageServer(BaseMessageHandler): async def send_message(self, message: MessageBase): await self.broadcast_to_platform(message.message_info.platform, message.to_dict()) - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - # 清理platform映射 - self.platform_websockets.clear() - - # 取消所有后台任务 - for task in self.background_tasks: - task.cancel() - # 等待所有任务完成 - await asyncio.gather(*self.background_tasks, return_exceptions=True) - self.background_tasks.clear() - - # 关闭所有WebSocket连接 - for websocket in self.active_websockets: - await websocket.close() - self.active_websockets.clear() - - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - async def send_message_REST(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: """发送消息到指定端点""" async with aiohttp.ClientSession() as session: @@ -219,105 +243,4 @@ class MessageServer(BaseMessageHandler): raise e -class BaseMessageAPI: - def __init__(self, host: str = "0.0.0.0", port: int = 18000): - self.app = FastAPI() - self.host = host - self.port = port - self.message_handlers: List[Callable] = [] - self.cache = [] - self._setup_routes() - self._running = False - - def _setup_routes(self): - """设置基础路由""" - - @self.app.post("/api/message") - async def handle_message(message: Dict[str, Any]): - try: - # 创建后台任务处理消息 - asyncio.create_task(self._background_message_handler(message)) - return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - async def _background_message_handler(self, message: Dict[str, Any]): - """后台处理单个消息""" - try: - await self.process_single_message(message) - except Exception as e: - logger.error(f"Background message processing failed: {str(e)}") - logger.error(traceback.format_exc()) - - def register_message_handler(self, handler: Callable): - """注册消息处理函数""" - self.message_handlers.append(handler) - - async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: - """发送消息到指定端点""" - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: - return await response.json() - except Exception: - # logger.error(f"发送消息失败: {str(e)}") - pass - - async def process_single_message(self, message: Dict[str, Any]): - """处理单条消息""" - tasks = [] - for handler in self.message_handlers: - try: - tasks.append(handler(message)) - except Exception as e: - logger.error(str(e)) - logger.error(traceback.format_exc()) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - - def start(self): - """启动服务器的便捷方法""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self.start_server()) - except KeyboardInterrupt: - pass - finally: - loop.close() - - -global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())