From 53bb77686b61eab7a7bb42023a034feba79db7f7 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 21:04:55 +0800 Subject: [PATCH 1/6] =?UTF-8?q?Revert=20"feat(chromadb):=20=E6=B7=BB?= =?UTF-8?q?=E5=8A=A0=E5=85=A8=E5=B1=80=E9=94=81=E4=BB=A5=E4=BF=9D=E6=8A=A4?= =?UTF-8?q?=20ChromaDB=20=E6=93=8D=E4=BD=9C=EF=BC=8C=E7=A1=AE=E4=BF=9D?= =?UTF-8?q?=E7=BA=BF=E7=A8=8B=E5=AE=89=E5=85=A8"?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This reverts commit e7cb04bfdde573096d32060c72a4b643bf2e66da. From 43e25378c8871f2e47596fe1697553c3abc85245 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 21:33:50 +0800 Subject: [PATCH 2/6] =?UTF-8?q?feat(napcat):=20=E6=B7=BB=E5=8A=A0=E4=BA=8B?= =?UTF-8?q?=E4=BB=B6=E5=A4=84=E7=90=86=E8=BF=87=E6=BB=A4=E6=9C=BA=E5=88=B6?= =?UTF-8?q?=EF=BC=8C=E6=94=AF=E6=8C=81=E9=BB=91=E7=99=BD=E5=90=8D=E5=8D=95?= =?UTF-8?q?=E5=8A=9F=E8=83=BD?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/built_in/napcat_adapter/plugin.py | 94 +++++++++++++++++++ .../src/handlers/to_core/message_handler.py | 84 +---------------- 2 files changed, 99 insertions(+), 79 deletions(-) diff --git a/src/plugins/built_in/napcat_adapter/plugin.py b/src/plugins/built_in/napcat_adapter/plugin.py index 191437abf..fb1a7ff35 100644 --- a/src/plugins/built_in/napcat_adapter/plugin.py +++ b/src/plugins/built_in/napcat_adapter/plugin.py @@ -88,6 +88,93 @@ class NapcatAdapter(BaseAdapter): # 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS handler_utils.register_adapter(self) + def _should_process_event(self, raw: Dict[str, Any]) -> bool: + """ + 检查事件是否应该被处理(黑白名单过滤) + + 此方法在 from_platform_message 顶层调用,对所有类型的事件(消息、通知、元事件)进行过滤。 + + Args: + raw: OneBot 原始事件数据 + + Returns: + bool: True表示应该处理,False表示应该过滤 + """ + if not self.plugin: + return True + + plugin_config = self.plugin.config + if not plugin_config: + return True # 如果没有配置,默认处理所有事件 + + features_config = plugin_config.get("features", {}) + post_type = raw.get("post_type") + + # 获取用户信息(根据事件类型从不同字段获取) + user_id: str = "" + if post_type == "message": + sender_info = raw.get("sender", {}) + user_id = str(sender_info.get("user_id", "")) + elif post_type == "notice": + user_id = str(raw.get("user_id", "")) + else: + # 元事件或其他类型不需要过滤 + return True + + # 检查全局封禁用户列表 + ban_user_ids = [str(item) for item in features_config.get("ban_user_id", [])] + if user_id and user_id in ban_user_ids: + logger.debug(f"用户 {user_id} 在全局封禁列表中,事件被过滤") + return False + + # 检查是否屏蔽其他QQ机器人(仅对消息事件生效) + if post_type == "message" and features_config.get("ban_qq_bot", False): + sender_info = raw.get("sender", {}) + role = sender_info.get("role", "") + if role == "admin" or "bot" in str(sender_info).lower(): + logger.debug(f"检测到机器人消息 {user_id},事件被过滤") + return False + + # 获取消息类型(消息事件使用 message_type,通知事件根据 group_id 判断) + message_type = raw.get("message_type") + group_id = raw.get("group_id") + + # 如果是通知事件,根据是否有 group_id 判断是群通知还是私聊通知 + if post_type == "notice": + message_type = "group" if group_id else "private" + + # 群聊/群通知过滤 + if message_type == "group" and group_id: + group_id_str = str(group_id) + group_list_type = features_config.get("group_list_type", "blacklist") + group_list = [str(item) for item in features_config.get("group_list", [])] + + if group_list_type == "blacklist": + if group_id_str in group_list: + logger.debug(f"群聊 {group_id_str} 在黑名单中,事件被过滤") + return False + else: # whitelist + if group_id_str not in group_list: + logger.debug(f"群聊 {group_id_str} 不在白名单中,事件被过滤") + return False + + # 私聊/私聊通知过滤 + elif message_type == "private": + private_list_type = features_config.get("private_list_type", "blacklist") + private_list = [str(item) for item in features_config.get("private_list", [])] + + if private_list_type == "blacklist": + if user_id in private_list: + logger.debug(f"私聊用户 {user_id} 在黑名单中,事件被过滤") + return False + else: # whitelist + if user_id not in private_list: + logger.debug(f"私聊用户 {user_id} 不在白名单中,事件被过滤") + return False + + # 通过所有过滤条件 + return True + async def on_adapter_loaded(self) -> None: """适配器加载时的初始化""" logger.info("Napcat 适配器正在启动...") @@ -161,6 +248,8 @@ class NapcatAdapter(BaseAdapter): - notice 事件 → 通知(戳一戳、表情回复等) - meta_event 事件 → 元事件(心跳、生命周期) - API 响应 → 存入响应池 + + 注意:黑白名单等过滤机制在此方法最开始执行,确保所有类型的事件都能被过滤。 """ post_type = raw.get("post_type") @@ -171,6 +260,11 @@ class NapcatAdapter(BaseAdapter): future = self._response_pool[echo] if not future.done(): future.set_result(raw) + return None + + # 顶层过滤:黑白名单等过滤机制 + if not self._should_process_event(raw): + return None try: # 消息事件 diff --git a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py index 3babf85e6..6cef2fe40 100644 --- a/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py +++ b/src/plugins/built_in/napcat_adapter/src/handlers/to_core/message_handler.py @@ -39,79 +39,6 @@ class MessageHandler: """设置插件配置""" self.plugin_config = config - def _should_process_message(self, raw: Dict[str, Any]) -> bool: - """ - 检查消息是否应该被处理(黑白名单过滤) - - Args: - raw: OneBot 原始消息数据 - - Returns: - bool: True表示应该处理,False表示应该过滤 - """ - if not self.plugin_config: - return True # 如果没有配置,默认处理所有消息 - - features_config = self.plugin_config.get("features", {}) - - # 获取消息基本信息 - message_type = raw.get("message_type") - sender_info = raw.get("sender", {}) - user_id = str(sender_info.get("user_id", "")) - - # 检查全局封禁用户列表 - ban_user_ids = [str(item) for item in features_config.get("ban_user_id", [])] - if user_id in ban_user_ids: - logger.debug(f"用户 {user_id} 在全局封禁列表中,消息被过滤") - return False - - # 检查是否屏蔽其他QQ机器人 - if features_config.get("ban_qq_bot", False): - # 判断是否为机器人消息:通常通过sender中的role字段或其他标识 - role = sender_info.get("role", "") - if role == "admin" or "bot" in str(sender_info).lower(): - logger.debug(f"检测到机器人消息 {user_id},消息被过滤") - return False - - # 群聊消息处理 - if message_type == "group": - group_id = str(raw.get("group_id", "")) - - # 获取群聊配置 - group_list_type = features_config.get("group_list_type", "blacklist") - group_list = [str(item) for item in features_config.get("group_list", [])] - - if group_list_type == "blacklist": - # 黑名单模式:如果在黑名单中就过滤 - if group_id in group_list: - logger.debug(f"群聊 {group_id} 在黑名单中,消息被过滤") - return False - else: # whitelist - # 白名单模式:如果不在白名单中就过滤 - if group_id not in group_list: - logger.debug(f"群聊 {group_id} 不在白名单中,消息被过滤") - return False - - # 私聊消息处理 - elif message_type == "private": - # 获取私聊配置 - private_list_type = features_config.get("private_list_type", "blacklist") - private_list = [str(item) for item in features_config.get("private_list", [])] - - if private_list_type == "blacklist": - # 黑名单模式:如果在黑名单中就过滤 - if user_id in private_list: - logger.debug(f"私聊用户 {user_id} 在黑名单中,消息被过滤") - return False - else: # whitelist - # 白名单模式:如果不在白名单中就过滤 - if user_id not in private_list: - logger.debug(f"私聊用户 {user_id} 不在白名单中,消息被过滤") - return False - - # 通过所有过滤条件 - return True - async def handle_raw_message(self, raw: Dict[str, Any]): """ 处理原始消息并转换为 MessageEnvelope @@ -120,18 +47,17 @@ class MessageHandler: raw: OneBot 原始消息数据 Returns: - MessageEnvelope (dict) or None (if message is filtered) + MessageEnvelope (dict) or None + + Note: + 黑白名单过滤已移动到 NapcatAdapter.from_platform_message 顶层执行, + 确保所有类型的事件(消息、通知等)都能被统一过滤。 """ message_type = raw.get("message_type") message_id = str(raw.get("message_id", "")) message_time = time.time() - # 黑白名单过滤 - if not self._should_process_message(raw): - logger.debug(f"消息被黑白名单过滤丢弃: message_id={message_id}") - return None - msg_builder = MessageBuilder() # 构造用户信息 From 1dfa44b32b2dba28759251292b2f39b1168faac4 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 22:04:27 +0800 Subject: [PATCH 3/6] =?UTF-8?q?fix(config):=20=E6=9B=B4=E6=96=B0=E7=89=88?= =?UTF-8?q?=E6=9C=AC=E5=8F=B7=E8=87=B3=200.13.1-alpha.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/config/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/config/config.py b/src/config/config.py index b0a2c3346..5df56712e 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -65,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template") # 考虑到,实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码 # 对该字段的更新,请严格参照语义化版本规范:https://semver.org/lang/zh-CN/ -MMC_VERSION = "0.13.0" +MMC_VERSION = "0.13.1-alpha.1" # 全局配置变量 _CONFIG_INITIALIZED = False From 2e7b434537975415191c190fb4d778444f088631 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 22:40:12 +0800 Subject: [PATCH 4/6] =?UTF-8?q?refactor:=20=E4=BD=BF=E7=94=A8=E5=BC=82?= =?UTF-8?q?=E6=AD=A5=E7=94=9F=E6=88=90=E5=99=A8=E8=BF=81=E7=A7=BB=E5=88=B0?= =?UTF-8?q?=E4=BA=8B=E4=BB=B6=E9=A9=B1=E5=8A=A8=E6=A8=A1=E5=9E=8B=E4=BB=A5?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E8=81=8A=E5=A4=A9=E6=B5=81=E7=AE=A1=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 用异步生成器替换了无限循环任务,以处理聊天流事件。 引入了`ConversationTick`数据模型来表示会话事件。 - 更新了`StreamLoopManager`,以使用新的基于生成器的方法来管理聊天流。 - 在聊天流处理过程中增强了错误处理和日志记录功能。 - 改进了聊天流的生命周期管理,包括启动和停止方法。 - 删除了与之前的循环工作线程实现相关的遗留代码。 --- src/chat/message_manager/__init__.py | 13 +- .../message_manager/distribution_manager.py | 919 ++++++++---------- 2 files changed, 415 insertions(+), 517 deletions(-) diff --git a/src/chat/message_manager/__init__.py b/src/chat/message_manager/__init__.py index 7b67424f9..07b4e5795 100644 --- a/src/chat/message_manager/__init__.py +++ b/src/chat/message_manager/__init__.py @@ -1,14 +1,25 @@ """ 消息管理器模块 提供统一的消息管理、上下文管理和流循环调度功能 + +基于 Generator + Tick 的事件驱动模式 """ -from .distribution_manager import StreamLoopManager, stream_loop_manager +from .distribution_manager import ( + ConversationTick, + StreamLoopManager, + conversation_loop, + run_chat_stream, + stream_loop_manager, +) from .message_manager import MessageManager, message_manager __all__ = [ + "ConversationTick", "MessageManager", "StreamLoopManager", + "conversation_loop", "message_manager", + "run_chat_stream", "stream_loop_manager", ] diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 90c9929d1..be774081d 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -1,11 +1,18 @@ """ -流循环管理器 -为每个聊天流创建独立的无限循环任务,主动轮询处理消息 +流循环管理器 - 基于 Generator + Tick 的事件驱动模式 + +采用异步生成器替代无限循环任务,实现更简洁可控的消息处理流程。 + +核心概念: +- ConversationTick: 表示一次待处理的会话事件 +- conversation_loop: 异步生成器,按需产出 Tick 事件 +- run_chat_stream: 驱动器,消费 Tick 并调用 Chatter """ import asyncio import time -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, AsyncIterator, Callable, Awaitable from src.chat.chatter_manager import ChatterManager from src.chat.energy_system import energy_manager @@ -14,13 +21,221 @@ from src.config.config import global_config from src.chat.message_receive.chat_stream import get_chat_manager if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream from src.common.data_models.message_manager_data_model import StreamContext logger = get_logger("stream_loop_manager") +# ============================================================================ +# Tick 数据模型 +# ============================================================================ + + +@dataclass +class ConversationTick: + """ + 会话事件标记 - 表示一次待处理的会话事件 + + 这是一个轻量级的事件信号,不存储消息数据。 + 未读消息由 StreamContext 管理,能量值由 energy_manager 管理。 + """ + stream_id: str + tick_time: float = field(default_factory=time.time) + force_dispatch: bool = False # 是否为强制分发(未读消息超阈值) + tick_count: int = 0 # 当前流的 tick 计数 + + +# ============================================================================ +# 异步生成器 - 核心循环逻辑 +# ============================================================================ + + +async def conversation_loop( + stream_id: str, + get_context_func: Callable[[str], Awaitable["StreamContext | None"]], + calculate_interval_func: Callable[[str, bool], Awaitable[float]], + flush_cache_func: Callable[[str], Awaitable[None]], + check_force_dispatch_func: Callable[["StreamContext", int], bool], + is_running_func: Callable[[], bool], +) -> AsyncIterator[ConversationTick]: + """ + 会话循环生成器 - 按需产出 Tick 事件 + + 替代原有的无限循环任务,改为事件驱动的生成器模式。 + 只有调用 __anext__() 时才会执行,完全由消费者控制节奏。 + + Args: + stream_id: 流ID + get_context_func: 获取 StreamContext 的异步函数 + calculate_interval_func: 计算等待间隔的异步函数 + flush_cache_func: 刷新缓存消息的异步函数 + check_force_dispatch_func: 检查是否需要强制分发的函数 + is_running_func: 检查是否继续运行的函数 + + Yields: + ConversationTick: 会话事件 + """ + tick_count = 0 + last_interval = None + + while is_running_func(): + try: + # 1. 获取流上下文 + context = await get_context_func(stream_id) + if not context: + logger.warning(f" [生成器] stream={stream_id[:8]}, 无法获取流上下文") + await asyncio.sleep(10.0) + continue + + # 2. 刷新缓存消息到未读列表 + await flush_cache_func(stream_id) + + # 3. 检查是否有消息需要处理 + unread_messages = context.get_unread_messages() + unread_count = len(unread_messages) if unread_messages else 0 + + # 4. 检查是否需要强制分发 + force_dispatch = check_force_dispatch_func(context, unread_count) + + # 5. 如果有消息,产出 Tick + if unread_count > 0 or force_dispatch: + tick_count += 1 + yield ConversationTick( + stream_id=stream_id, + force_dispatch=force_dispatch, + tick_count=tick_count, + ) + + # 6. 计算并等待下次检查间隔 + has_messages = unread_count > 0 + interval = await calculate_interval_func(stream_id, has_messages) + + # 只在间隔发生变化时输出日志 + if last_interval is None or abs(interval - last_interval) > 0.01: + logger.debug(f"[生成器] stream={stream_id[:8]}, 等待间隔: {interval:.2f}s") + last_interval = interval + + await asyncio.sleep(interval) + + except asyncio.CancelledError: + logger.info(f" [生成器] stream={stream_id[:8]}, 被取消") + break + except Exception as e: + logger.error(f" [生成器] stream={stream_id[:8]}, 出错: {e}") + await asyncio.sleep(5.0) + + +# ============================================================================ +# 聊天流驱动器 +# ============================================================================ + + +async def run_chat_stream( + stream_id: str, + manager: "StreamLoopManager", +) -> None: + """ + 聊天流驱动器 - 消费 Tick 事件并调用 Chatter + + 替代原有的 _stream_loop_worker,结构更清晰。 + + Args: + stream_id: 流ID + manager: StreamLoopManager 实例 + """ + task_id = id(asyncio.current_task()) + logger.debug(f" [驱动器] stream={stream_id[:8]}, 任务ID={task_id}, 启动") + + try: + # 创建生成器 + tick_generator = conversation_loop( + stream_id=stream_id, + get_context_func=manager._get_stream_context, + calculate_interval_func=manager._calculate_interval, + flush_cache_func=manager._flush_cached_messages_to_unread, + check_force_dispatch_func=manager._needs_force_dispatch_for_context, + is_running_func=lambda: manager.is_running, + ) + + # 消费 Tick 事件 + async for tick in tick_generator: + try: + # 获取上下文 + context = await manager._get_stream_context(stream_id) + if not context: + continue + + # 并发保护:检查是否正在处理 + if context.is_chatter_processing: + if manager._recover_stale_chatter_state(stream_id, context): + logger.warning(f" [驱动器] stream={stream_id[:8]}, 处理标志残留已修复") + else: + logger.debug(f" [驱动器] stream={stream_id[:8]}, Chatter正在处理,跳过此Tick") + continue + + # 日志 + if tick.force_dispatch: + logger.info(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 强制分发") + else: + logger.debug(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 开始处理") + + # 更新能量值 + try: + await manager._update_stream_energy(stream_id, context) + except Exception as e: + logger.debug(f"更新能量失败: {e}") + + # 处理消息 + assert global_config is not None + try: + success = await asyncio.wait_for( + manager._process_stream_messages(stream_id, context), + global_config.chat.thinking_timeout + ) + except asyncio.TimeoutError: + logger.warning(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理超时") + success = False + + # 更新统计 + manager.stats["total_process_cycles"] += 1 + if success: + logger.debug(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理成功") + await asyncio.sleep(0.1) # 等待清理操作完成 + else: + manager.stats["total_failures"] += 1 + logger.debug(f" [驱动器] stream={stream_id[:8]}, Tick#{tick.tick_count}, 处理失败") + + except asyncio.CancelledError: + raise + except Exception as e: + logger.error(f" [驱动器] stream={stream_id[:8]}, 处理Tick时出错: {e}") + manager.stats["total_failures"] += 1 + + except asyncio.CancelledError: + logger.info(f" [驱动器] stream={stream_id[:8]}, 任务ID={task_id}, 被取消") + finally: + # 清理任务记录 + try: + context = await manager._get_stream_context(stream_id) + if context and context.stream_loop_task: + context.stream_loop_task = None + logger.debug(f" [驱动器] stream={stream_id[:8]}, 清理任务记录") + except Exception as e: + logger.debug(f"清理任务记录失败: {e}") + + +# ============================================================================ +# 流循环管理器 +# ============================================================================ + + class StreamLoopManager: - """流循环管理器 - 每个流一个独立的无限循环任务""" + """ + 流循环管理器 - 基于 Generator + Tick 的事件驱动模式 + + 管理所有聊天流的生命周期,为每个流创建独立的驱动器任务。 + """ def __init__(self, max_concurrent_streams: int | None = None): if global_config is None: @@ -50,21 +265,22 @@ class StreamLoopManager: # 状态控制 self.is_running = False - # 每个流的上一次间隔值(用于日志去重) - self._last_intervals: dict[str, float] = {} - - # 流循环启动锁:防止并发启动同一个流的多个循环任务 + # 流启动锁:防止并发启动同一个流的多个任务 self._stream_start_locks: dict[str, asyncio.Lock] = {} logger.info(f"流循环管理器初始化完成 (最大并发流数: {self.max_concurrent_streams})") + # ======================================================================== + # 生命周期管理 + # ======================================================================== + async def start(self) -> None: """启动流循环管理器""" if self.is_running: logger.warning("流循环管理器已经在运行") return - self.is_running = True + logger.info("流循环管理器已启动") async def stop(self) -> None: """停止流循环管理器""" @@ -75,13 +291,9 @@ class StreamLoopManager: # 取消所有流循环 try: - # 获取所有活跃的流 - from src.plugin_system.apis.chat_api import get_chat_manager - chat_manager = get_chat_manager() all_streams = chat_manager.get_all_streams() - # 创建任务列表以便并发取消 cancel_tasks = [] for chat_stream in all_streams.values(): context = chat_stream.context @@ -89,7 +301,6 @@ class StreamLoopManager: context.stream_loop_task.cancel() cancel_tasks.append((chat_stream.stream_id, context.stream_loop_task)) - # 并发等待所有任务取消 if cancel_tasks: logger.info(f"正在取消 {len(cancel_tasks)} 个流循环任务...") await asyncio.gather( @@ -103,13 +314,18 @@ class StreamLoopManager: logger.info("流循环管理器已停止") - async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: - """启动指定流的循环任务 - 优化版本使用自适应管理器 + # ======================================================================== + # 流循环控制 + # ======================================================================== + async def start_stream_loop(self, stream_id: str, force: bool = False) -> bool: + """ + 启动指定流的驱动器任务 + Args: stream_id: 流ID - force: 是否强制启动 - + force: 是否强制启动(会先取消现有任务) + Returns: bool: 是否成功启动 """ @@ -119,207 +335,83 @@ class StreamLoopManager: logger.warning(f"无法获取流上下文: {stream_id}") return False - # 快速路径:如果流已存在且不是强制启动,无需处理 + # 快速路径:如果流已存在且不是强制启动 if not force and context.stream_loop_task and not context.stream_loop_task.done(): - logger.debug(f"🔄 [流循环] stream={stream_id[:8]}, 循环已在运行,跳过启动") + logger.debug(f" [管理器] stream={stream_id[:8]}, 任务已在运行") return True - # 获取或创建该流的启动锁 + # 获取或创建启动锁 if stream_id not in self._stream_start_locks: self._stream_start_locks[stream_id] = asyncio.Lock() - lock = self._stream_start_locks[stream_id] - # 使用锁防止并发启动同一个流的多个循环任务 async with lock: - # 如果是强制启动且任务仍在运行,先取消旧任务 + # 强制启动时先取消旧任务 if force and context.stream_loop_task and not context.stream_loop_task.done(): - logger.warning(f"⚠️ [流循环] stream={stream_id[:8]}, 强制启动模式:先取消现有任务") + logger.warning(f" [管理器] stream={stream_id[:8]}, 强制启动:取消现有任务") old_task = context.stream_loop_task old_task.cancel() try: await asyncio.wait_for(old_task, timeout=2.0) - logger.debug(f"✅ [流循环] stream={stream_id[:8]}, 旧任务已结束") except (asyncio.TimeoutError, asyncio.CancelledError): - logger.debug(f"⏱️ [流循环] stream={stream_id[:8]}, 旧任务已取消或超时") + pass except Exception as e: - logger.warning(f"❌ [流循环] stream={stream_id[:8]}, 等待旧任务结束时出错: {e}") + logger.warning(f"等待旧任务结束时出错: {e}") - # 创建流循环任务 + # 创建新的驱动器任务 try: - # 检查是否有旧任务残留 - if context.stream_loop_task and not context.stream_loop_task.done(): - logger.error(f"🚨 [流循环] stream={stream_id[:8]}, 错误:旧任务仍在运行!这不应该发生!") - # 紧急取消 - context.stream_loop_task.cancel() - await asyncio.sleep(0.1) - - loop_task = asyncio.create_task(self._stream_loop_worker(stream_id), name=f"stream_loop_{stream_id}") - - # 将任务记录到 StreamContext 中 + loop_task = asyncio.create_task( + run_chat_stream(stream_id, self), + name=f"chat_stream_{stream_id}" + ) context.stream_loop_task = loop_task - # 更新统计信息 self.stats["active_streams"] += 1 self.stats["total_loops"] += 1 - logger.info(f"🚀 [流循环] stream={stream_id[:8]}, 启动新的流循环任务,任务ID: {id(loop_task)}") + logger.debug(f" [管理器] stream={stream_id[:8]}, 启动驱动器任务") return True except Exception as e: - logger.error(f"❌ [流循环] stream={stream_id[:8]}, 启动失败: {e}") + logger.error(f" [管理器] stream={stream_id[:8]}, 启动失败: {e}") return False async def stop_stream_loop(self, stream_id: str) -> bool: - """停止指定流的循环任务 - + """ + 停止指定流的驱动器任务 + Args: stream_id: 流ID - + Returns: bool: 是否成功停止 """ - # 获取流上下文 context = await self._get_stream_context(stream_id) if not context: - logger.debug(f"流 {stream_id} 上下文不存在,无需停止") return False - # 检查是否有 stream_loop_task if not context.stream_loop_task or context.stream_loop_task.done(): - logger.debug(f"流 {stream_id} 循环不存在或已结束,无需停止") return False task = context.stream_loop_task - if not task.done(): - task.cancel() - try: - # 设置取消超时,避免无限等待 - await asyncio.wait_for(task, timeout=5.0) - except asyncio.CancelledError: - logger.debug(f"流循环任务已取消: {stream_id}") - except asyncio.TimeoutError: - logger.warning(f"流循环任务取消超时: {stream_id}") - except Exception as e: - logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}") + task.cancel() + try: + await asyncio.wait_for(task, timeout=5.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass + except Exception as e: + logger.error(f"停止任务时出错: {e}") - # 清空 StreamContext 中的任务记录 context.stream_loop_task = None - logger.debug(f"停止流循环: {stream_id}") return True - async def _stream_loop_worker(self, stream_id: str) -> None: - """单个流的工作循环 - 优化版本 - - Args: - stream_id: 流ID - """ - task_id = id(asyncio.current_task()) - logger.info(f"🔄 [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 启动") - - try: - while self.is_running: - try: - # 1. 获取流上下文 - context = await self._get_stream_context(stream_id) - if not context: - logger.warning(f"⚠️ [流工作器] stream={stream_id[:8]}, 无法获取流上下文") - await asyncio.sleep(10.0) - continue - - # 2. 检查是否有消息需要处理 - await self._flush_cached_messages_to_unread(stream_id) - unread_count = self._get_unread_count(context) - force_dispatch = self._needs_force_dispatch_for_context(context, unread_count) - - has_messages = force_dispatch or await self._has_messages_to_process(context) - - if has_messages: - # 🔒 并发保护:如果 Chatter 正在处理中,跳过本轮 - # 这可能发生在:1) 打断后重启循环 2) 处理时间超过轮询间隔 - if context.is_chatter_processing: - if self._recover_stale_chatter_state(stream_id, context): - logger.warning(f"🔄 [流工作器] stream={stream_id[:8]}, 处理标志疑似残留,已尝试自动修复") - else: - logger.debug(f"🔒 [流工作器] stream={stream_id[:8]}, Chatter正在处理中,跳过本轮") - # 不打印"开始处理"日志,直接进入下一轮等待 - # 使用较短的等待时间,等待当前处理完成 - await asyncio.sleep(1.0) - continue - - if force_dispatch: - logger.info(f"⚡ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 未读消息 {unread_count} 条,触发强制分发") - else: - logger.info(f"📨 [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 开始处理消息") - - # 3. 在处理前更新能量值(用于下次间隔计算) - try: - asyncio.create_task(self._update_stream_energy(stream_id, context)) - except Exception as e: - logger.debug(f"更新流能量失败 {stream_id}: {e}") - - # 4. 激活chatter处理 - try: - success = await asyncio.wait_for(self._process_stream_messages(stream_id, context), global_config.chat.thinking_timeout) - except asyncio.TimeoutError: - logger.warning(f"⏱️ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理超时") - success = False - # 更新统计 - self.stats["total_process_cycles"] += 1 - if success: - logger.info(f"✅ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理成功") - - # 🔒 处理成功后,等待一小段时间确保清理操作完成 - # 这样可以避免在 chatter_manager 清除未读消息之前就进入下一轮循环 - await asyncio.sleep(0.1) - else: - self.stats["total_failures"] += 1 - logger.debug(f"❌ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 处理失败") - - # 5. 计算下次检查间隔 - interval = await self._calculate_interval(stream_id, has_messages) - - # 6. sleep等待下次检查 - # 只在间隔发生变化时输出日志,避免刷屏 - last_interval = self._last_intervals.get(stream_id) - if last_interval is None or abs(interval - last_interval) > 0.01: - logger.info(f"流 {stream_id} 等待周期变化: {interval:.2f}s") - self._last_intervals[stream_id] = interval - await asyncio.sleep(interval) - - except asyncio.CancelledError: - logger.info(f"🛑 [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 被取消") - break - except Exception as e: - logger.error(f"❌ [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 出错: {e}") - self.stats["total_failures"] += 1 - await asyncio.sleep(5.0) # 错误时等待5秒再重试 - - finally: - # 清理 StreamContext 中的任务记录 - try: - context = await self._get_stream_context(stream_id) - if context and context.stream_loop_task: - context.stream_loop_task = None - logger.info(f"🧹 [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 清理任务记录") - except Exception as e: - logger.debug(f"清理 StreamContext 任务记录失败: {e}") - - # 清理间隔记录 - self._last_intervals.pop(stream_id, None) - - logger.info(f"🏁 [流工作器] stream={stream_id[:8]}, 任务ID={task_id}, 循环结束") + # ======================================================================== + # 内部方法 - 上下文管理 + # ======================================================================== async def _get_stream_context(self, stream_id: str) -> "StreamContext | None": - """获取流上下文 - - Args: - stream_id: 流ID - - Returns: - Optional[StreamContext]: 流上下文,如果不存在返回None - """ + """获取流上下文""" try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) @@ -330,32 +422,35 @@ class StreamLoopManager: logger.error(f"获取流上下文失败 {stream_id}: {e}") return None - async def _has_messages_to_process(self, context: "StreamContext") -> bool: - """检查是否有消息需要处理 - - Args: - context: 流上下文 - - Returns: - bool: 是否有未读消息 - """ + async def _flush_cached_messages_to_unread(self, stream_id: str) -> list: + """将缓存消息刷新到未读消息列表""" try: - # 检查是否有未读消息 - if hasattr(context, "unread_messages") and context.unread_messages: - return True + context = await self._get_stream_context(stream_id) + if not context: + return [] - return False + if hasattr(context, "flush_cached_messages"): + cached_messages = context.flush_cached_messages() + if cached_messages: + logger.debug(f"刷新缓存消息: stream={stream_id[:8]}, 数量={len(cached_messages)}") + return cached_messages + return [] except Exception as e: - logger.error(f"检查消息状态失败: {e}") - return False + logger.warning(f"刷新缓存失败: {e}") + return [] + + # ======================================================================== + # 内部方法 - 消息处理 + # ======================================================================== async def _process_stream_messages(self, stream_id: str, context: "StreamContext") -> bool: - """处理流消息 - 支持子任务管理 - + """ + 处理流消息 + Args: stream_id: 流ID context: 流上下文 - + Returns: bool: 是否处理成功 """ @@ -363,338 +458,259 @@ class StreamLoopManager: logger.warning(f"Chatter管理器未设置: {stream_id}") return False - # 🔒 二次并发保护(防御性检查) - # 正常情况下不应该触发,如果触发说明有竞态条件 + # 二次并发保护 if context.is_chatter_processing: - logger.warning(f"🔒 [并发保护] stream={stream_id[:8]}, Chatter正在处理中(二次检查触发,可能存在竞态)") + logger.warning(f" [并发保护] stream={stream_id[:8]}, 二次检查触发") return False - # 设置处理状态为正在处理 self._set_stream_processing_status(stream_id, True) chatter_task = None try: start_time = time.time() - # 检查未读消息,如果为空则直接返回(优化:避免无效的 chatter 调用) + + # 检查未读消息 unread_messages = context.get_unread_messages() if not unread_messages: - logger.debug(f"流 {stream_id} 未读消息为空,跳过 chatter 处理") - return True # 返回 True 表示处理完成(虽然没有实际处理) - - # 🔇 静默群组检查:在静默群组中,只有提到 Bot 名字/别名才响应 - if await self._should_skip_for_mute_group(stream_id, unread_messages): - # 清空未读消息,不触发 chatter - from .message_manager import message_manager - await message_manager.clear_stream_unread_messages(stream_id) - logger.debug(f"🔇 流 {stream_id} 在静默列表中且未提及Bot,跳过处理") + logger.debug(f"未读消息为空,跳过处理: {stream_id}") return True - logger.debug(f"流 {stream_id} 有 {len(unread_messages)} 条未读消息,开始处理") + # 静默群组检查 + if await self._should_skip_for_mute_group(stream_id, unread_messages): + from .message_manager import message_manager + await message_manager.clear_stream_unread_messages(stream_id) + logger.debug(f" 静默群组跳过: {stream_id}") + return True - # 设置触发用户ID,以实现回复保护 + logger.debug(f"处理 {len(unread_messages)} 条未读消息: {stream_id}") + + # 设置触发用户ID last_message = context.get_last_message() if last_message: context.triggering_user_id = last_message.user_info.user_id - # 设置 Chatter 正在处理的标志 + # 设置处理标志 context.is_chatter_processing = True - logger.debug(f"设置 Chatter 处理标志: {stream_id}") - # 创建 chatter 处理任务,以便可以在打断时取消 + # 创建 chatter 任务 chatter_task = asyncio.create_task( self.chatter_manager.process_stream_context(stream_id, context), - name=f"chatter_process_{stream_id}" + name=f"chatter_{stream_id}" ) - - # 记录任务句柄,便于后续检测/自愈 context.processing_task = chatter_task - def _cleanup_processing_flag(task: asyncio.Task) -> None: + # 任务完成回调 + def _cleanup(task: asyncio.Task) -> None: try: context.processing_task = None if context.is_chatter_processing: context.is_chatter_processing = False self._set_stream_processing_status(stream_id, False) - logger.debug(f"🔄 [并发保护] stream={stream_id[:8]}, chatter任务结束自动清理处理标志") - except Exception as callback_error: - logger.debug(f"清理chatter处理标志失败: {callback_error}") + except Exception: + pass - chatter_task.add_done_callback(_cleanup_processing_flag) + chatter_task.add_done_callback(_cleanup) - # 等待 chatter 任务完成 + # 等待完成 results = await chatter_task success = results.get("success", False) if success: - process_time = time.time() - start_time - logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)") + logger.debug(f"处理成功: {stream_id} (耗时: {time.time() - start_time:.2f}s)") else: - logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") + logger.warning(f"处理失败: {stream_id} - {results.get('error_message', '未知错误')}") return success + except asyncio.CancelledError: if chatter_task and not chatter_task.done(): chatter_task.cancel() raise except Exception as e: - logger.error(f"流处理异常: {stream_id} - {e}") + logger.error(f"处理异常: {stream_id} - {e}") return False finally: - # 清除 Chatter 处理标志 context.is_chatter_processing = False context.processing_task = None - logger.debug(f"清除 Chatter 处理标志: {stream_id}") - - # 无论成功或失败,都要设置处理状态为未处理 self._set_stream_processing_status(stream_id, False) async def _should_skip_for_mute_group(self, stream_id: str, unread_messages: list) -> bool: - """检查是否应该因静默群组而跳过处理 - - 在静默群组中,只有当消息提及 Bot(@、回复、包含名字/别名)时才响应。 - - Args: - stream_id: 流ID - unread_messages: 未读消息列表 - - Returns: - bool: True 表示应该跳过,False 表示正常处理 - """ + """检查是否应该因静默群组而跳过处理""" if global_config is None: return False - - # 获取静默群组列表 + mute_group_list = getattr(global_config.message_receive, "mute_group_list", []) if not mute_group_list: return False - + try: - # 获取 chat_stream 来检查群组信息 chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) - + if not chat_stream or not chat_stream.group_info: - # 不是群聊,不适用静默规则 return False - + group_id = str(chat_stream.group_info.group_id) if group_id not in mute_group_list: - # 不在静默列表中 return False - - # 在静默列表中,检查是否有消息提及 Bot + + # 检查是否有消息提及 Bot bot_name = getattr(global_config.bot, "nickname", "") bot_aliases = getattr(global_config.bot, "alias_names", []) - bot_qq = str(getattr(global_config.bot, "qq_account", "")) - - # 构建需要检测的关键词列表 mention_keywords = [bot_name] + list(bot_aliases) if bot_name else list(bot_aliases) - mention_keywords = [k for k in mention_keywords if k] # 过滤空字符串 - + mention_keywords = [k for k in mention_keywords if k] + for msg in unread_messages: - # 检查是否被 @ 或回复 if getattr(msg, "is_at", False) or getattr(msg, "is_mentioned", False): - logger.debug(f"🔇 静默群组 {group_id}: 消息被@或回复,允许响应") return False - - # 检查消息内容是否包含 Bot 名字或别名 + content = getattr(msg, "processed_plain_text", "") or getattr(msg, "display_message", "") or "" for keyword in mention_keywords: if keyword and keyword in content: - logger.debug(f"🔇 静默群组 {group_id}: 消息包含关键词 '{keyword}',允许响应") return False - - # 没有任何消息提及 Bot - logger.debug(f"🔇 静默群组 {group_id}: {len(unread_messages)} 条消息均未提及Bot,跳过") + return True - + except Exception as e: - logger.warning(f"检查静默群组时出错: {stream_id}, error={e}") + logger.warning(f"检查静默群组出错: {e}") return False def _set_stream_processing_status(self, stream_id: str, is_processing: bool) -> None: """设置流的处理状态""" try: from .message_manager import message_manager - if message_manager.is_running: message_manager.set_stream_processing_status(stream_id, is_processing) - logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}") + except Exception: + pass - except ImportError: - logger.debug("MessageManager不可用,跳过状态设置") - except Exception as e: - logger.warning(f"设置流处理状态失败: stream={stream_id}, error={e}") - - async def _flush_cached_messages_to_unread(self, stream_id: str) -> list: - """将缓存消息刷新到未读消息列表""" + def _recover_stale_chatter_state(self, stream_id: str, context: "StreamContext") -> bool: + """检测并修复 Chatter 处理标志的假死状态""" try: - # 获取流上下文 - context = await self._get_stream_context(stream_id) - if not context: - logger.warning(f"无法获取流上下文: {stream_id}") - return [] + processing_task = getattr(context, "processing_task", None) - # 使用StreamContext的缓存刷新功能 - if hasattr(context, "flush_cached_messages"): - cached_messages = context.flush_cached_messages() - if cached_messages: - logger.debug(f"从StreamContext刷新缓存消息: stream={stream_id}, 数量={len(cached_messages)}") - return cached_messages - else: - logger.debug(f"StreamContext不支持缓存刷新: stream={stream_id}") - return [] + if processing_task is None: + context.is_chatter_processing = False + self._set_stream_processing_status(stream_id, False) + logger.warning(f" [自愈] stream={stream_id[:8]}, 无任务但标志为真") + return True - except Exception as e: - logger.warning(f"刷新StreamContext缓存失败: stream={stream_id}, error={e}") - return [] - async def _update_stream_energy(self, stream_id: str, context: Any) -> None: - """更新流的能量值 + if processing_task.done(): + context.is_chatter_processing = False + context.processing_task = None + self._set_stream_processing_status(stream_id, False) + logger.warning(f" [自愈] stream={stream_id[:8]}, 任务已结束但标志未清") + return True - Args: - stream_id: 流ID - context: 流上下文 (StreamContext) - """ - if global_config is None: - raise RuntimeError("Global config is not initialized") + return False + except Exception: + return False + # ======================================================================== + # 内部方法 - 能量与间隔计算 + # ======================================================================== + + async def _update_stream_energy(self, stream_id: str, context: "StreamContext") -> None: + """更新流的能量值""" try: - from src.chat.message_receive.chat_stream import get_chat_manager - - # 获取聊天流 chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) if not chat_stream: - logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") return - # 从 context 获取消息(包括未读和历史消息) - # 合并未读消息和历史消息 + assert global_config is not None + # 合并消息 all_messages = [] - - # 添加历史消息 history_messages = context.get_history_messages(limit=global_config.chat.max_context_size) all_messages.extend(history_messages) - - # 添加未读消息 - unread_messages = context.get_unread_messages() - all_messages.extend(unread_messages) - - # 按时间排序并限制数量 + all_messages.extend(context.get_unread_messages()) all_messages.sort(key=lambda m: m.time) messages = all_messages[-global_config.chat.max_context_size:] - # 获取用户ID - user_id = None - if context.triggering_user_id: - user_id = context.triggering_user_id + user_id = context.triggering_user_id - # 使用能量管理器计算并缓存能量值 energy = await energy_manager.calculate_focus_energy( stream_id=stream_id, messages=messages, user_id=user_id ) - # 同步更新到 ChatStream chat_stream._focus_energy = energy - - logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}") + logger.debug(f"更新能量: {stream_id[:8]} -> {energy:.3f}") except Exception as e: - logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False) + logger.warning(f"更新能量失败: {e}") async def _calculate_interval(self, stream_id: str, has_messages: bool) -> float: - """计算下次检查间隔 - - Args: - stream_id: 流ID - has_messages: 本次是否有消息处理 - - Returns: - float: 间隔时间(秒) - """ + """计算下次检查间隔""" if global_config is None: - raise RuntimeError("Global config is not initialized") + return 5.0 - # 私聊使用最小间隔,快速响应 + # 私聊快速响应 try: chat_manager = get_chat_manager() chat_stream = await chat_manager.get_stream(stream_id) if chat_stream and not chat_stream.group_info: - # 私聊:有消息时快速响应,空转时稍微等待 - min_interval = 0.5 if has_messages else 5.0 - logger.debug(f"流 {stream_id} 私聊模式,使用最小间隔: {min_interval:.2f}s") - return min_interval - except Exception as e: - logger.debug(f"检查流 {stream_id} 是否为私聊失败: {e}") + return 0.5 if has_messages else 5.0 + except Exception: + pass - # 基础间隔 base_interval = getattr(global_config.chat, "distribution_interval", 5.0) - # 如果没有消息,使用更长的间隔 if not has_messages: - return base_interval * 2.0 # 无消息时间隔加倍 + return base_interval * 2.0 - # 尝试使用能量管理器计算间隔 + # 基于能量计算间隔 try: - # 获取当前focus_energy focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0] - - # 使用能量管理器计算间隔 interval = energy_manager.get_distribution_interval(focus_energy) - - logger.debug(f"流 {stream_id} 动态间隔: {interval:.2f}s (能量: {focus_energy:.3f})") return interval - - except Exception as e: - logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})") + except Exception: return base_interval - def _recover_stale_chatter_state(self, stream_id: str, context: "StreamContext") -> bool: - """ - 检测并修复 Chatter 处理标志的假死状态。 + def _needs_force_dispatch_for_context(self, context: "StreamContext", unread_count: int | None = None) -> bool: + """检查是否需要强制分发""" + if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0: + return False - 返回 True 表示已发现并修复了异常状态;False 表示未发现异常。 - """ + if unread_count is None: + try: + unread_count = len(context.unread_messages) if context.unread_messages else 0 + except Exception: + return False + + return unread_count > self.force_dispatch_unread_threshold + + # ======================================================================== + # 辅助方法 + # ======================================================================== + + async def _wait_for_task_cancel(self, stream_id: str, task: asyncio.Task) -> None: + """等待任务取消完成""" try: - processing_task = getattr(context, "processing_task", None) - - # 标志为 True 但没有任务句柄,直接修复 - if processing_task is None: - context.is_chatter_processing = False - self._set_stream_processing_status(stream_id, False) - logger.warning(f"🛠️ [自愈] stream={stream_id[:8]}, 发现无任务但标志为真,已重置") - return True - - # 标志为 True 但任务已经结束/被取消 - if processing_task.done(): - context.is_chatter_processing = False - context.processing_task = None - self._set_stream_processing_status(stream_id, False) - logger.warning(f"🛠️ [自愈] stream={stream_id[:8]}, 任务已结束但标志未清,已重置") - return True - - return False + await asyncio.wait_for(task, timeout=5.0) + except (asyncio.CancelledError, asyncio.TimeoutError): + pass except Exception as e: - logger.debug(f"检测 Chatter 状态异常失败: stream={stream_id}, error={e}") - return False + logger.error(f"等待任务取消出错: {e}") + + def set_chatter_manager(self, chatter_manager: ChatterManager) -> None: + """设置 Chatter 管理器""" + self.chatter_manager = chatter_manager + logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}") + + # ======================================================================== + # 统计信息 + # ======================================================================== def get_queue_status(self) -> dict[str, Any]: - """获取队列状态 - - Returns: - Dict[str, Any]: 队列状态信息 - """ + """获取队列状态""" current_time = time.time() uptime = current_time - self.stats["start_time"] if self.is_running else 0 - # 从统计信息中获取活跃流数量 - active_streams = self.stats.get("active_streams", 0) - return { - "active_streams": active_streams, + "active_streams": self.stats.get("active_streams", 0), "total_loops": self.stats["total_loops"], "max_concurrent": self.max_concurrent_streams, "is_running": self.is_running, @@ -704,153 +720,24 @@ class StreamLoopManager: "stats": self.stats.copy(), } - def set_chatter_manager(self, chatter_manager: ChatterManager) -> None: - """设置chatter管理器 - - Args: - chatter_manager: chatter管理器实例 - """ - self.chatter_manager = chatter_manager - logger.debug(f"设置chatter管理器: {chatter_manager.__class__.__name__}") - - async def _should_force_dispatch_for_stream(self, stream_id: str) -> bool: - if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0: - return False - - try: - chat_manager = get_chat_manager() - chat_stream = await chat_manager.get_stream(stream_id) - if not chat_stream: - return False - - unread = getattr(chat_stream.context, "unread_messages", []) - return len(unread) > self.force_dispatch_unread_threshold - except Exception as e: - logger.debug(f"检查流 {stream_id} 是否需要强制分发失败: {e}") - return False - - def _get_unread_count(self, context: "StreamContext") -> int: - try: - unread_messages = context.unread_messages - if unread_messages is None: - return 0 - return len(unread_messages) - except Exception: - return 0 - - def _needs_force_dispatch_for_context(self, context: "StreamContext", unread_count: int | None = None) -> bool: - if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0: - return False - - count = unread_count if unread_count is not None else self._get_unread_count(context) - return count > self.force_dispatch_unread_threshold - def get_performance_summary(self) -> dict[str, Any]: - """获取性能摘要 - - Returns: - Dict[str, Any]: 性能摘要 - """ + """获取性能摘要""" current_time = time.time() uptime = current_time - self.stats["start_time"] - - # 计算吞吐量 - throughput = self.stats["total_process_cycles"] / max(1, uptime / 3600) # 每小时处理次数 - - # 从统计信息中获取活跃流数量 - active_streams = self.stats.get("active_streams", 0) + throughput = self.stats["total_process_cycles"] / max(1, uptime / 3600) return { "uptime_hours": uptime / 3600, - "active_streams": active_streams, + "active_streams": self.stats.get("active_streams", 0), "total_process_cycles": self.stats["total_process_cycles"], "total_failures": self.stats["total_failures"], "throughput_per_hour": throughput, "max_concurrent_streams": self.max_concurrent_streams, } - async def _refresh_focus_energy(self, stream_id: str) -> None: - """分发完成后基于历史消息刷新能量值""" - try: - chat_manager = get_chat_manager() - chat_stream = await chat_manager.get_stream(stream_id) - if not chat_stream: - logger.debug(f"刷新能量时未找到聊天流: {stream_id}") - return - await chat_stream.context.refresh_focus_energy_from_history() - logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量") - except Exception as e: - logger.warning(f"刷新聊天流 {stream_id} 能量失败: {e}") +# ============================================================================ +# 全局实例 +# ============================================================================ - async def _wait_for_task_cancel(self, stream_id: str, task: asyncio.Task) -> None: - """等待任务取消完成,带有超时控制 - - Args: - stream_id: 流ID - task: 要等待取消的任务 - """ - try: - await asyncio.wait_for(task, timeout=5.0) - logger.debug(f"流循环任务已正常结束: {stream_id}") - except asyncio.CancelledError: - logger.debug(f"流循环任务已取消: {stream_id}") - except asyncio.TimeoutError: - logger.warning(f"流循环任务取消超时: {stream_id}") - except Exception as e: - logger.error(f"等待流循环任务结束时出错: {stream_id} - {e}") - - async def _force_dispatch_stream(self, stream_id: str) -> None: - """强制分发流处理 - - 当流的未读消息超过阈值时,强制触发分发处理 - 这个方法主要用于突破并发限制时的紧急处理 - - 注意:此方法目前未被使用,相关功能已集成到 start_stream_loop 方法中 - - Args: - stream_id: 流ID - """ - logger.debug(f"强制分发流处理: {stream_id}") - - try: - # 获取流上下文 - context = await self._get_stream_context(stream_id) - if not context: - logger.warning(f"强制分发时未找到流上下文: {stream_id}") - return - - # 检查是否有现有的 stream_loop_task - if context.stream_loop_task and not context.stream_loop_task.done(): - logger.debug(f"发现现有流循环 {stream_id},将先取消再重新创建") - existing_task = context.stream_loop_task - existing_task.cancel() - # 创建异步任务来等待取消完成,并添加异常处理 - cancel_task = asyncio.create_task( - self._wait_for_task_cancel(stream_id, existing_task), name=f"cancel_existing_loop_{stream_id}" - ) - # 为取消任务添加异常处理,避免孤儿任务 - cancel_task.add_done_callback( - lambda task: logger.debug(f"取消任务完成: {stream_id}") - if not task.exception() - else logger.error(f"取消任务异常: {stream_id} - {task.exception()}") - ) - - # 检查未读消息数量 - unread_count = self._get_unread_count(context) - logger.info(f"流 {stream_id} 当前未读消息数: {unread_count}") - - # 使用 start_stream_loop 重新创建流循环任务 - success = await self.start_stream_loop(stream_id, force=True) - - if success: - logger.info(f"已创建强制分发流循环: {stream_id}") - else: - logger.warning(f"创建强制分发流循环失败: {stream_id}") - - except Exception as e: - logger.error(f"强制分发流处理失败 {stream_id}: {e}") - - -# 全局流循环管理器实例 stream_loop_manager = StreamLoopManager() From 63cb81aab699e4542f41f1ef40e3bf9ecc6c24aa Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 22:49:56 +0800 Subject: [PATCH 5/6] =?UTF-8?q?fix:=20=E6=9B=B4=E6=96=B0=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E6=96=87=E4=BB=B6=EF=BC=8C=E7=A6=81=E7=94=A8=E8=BA=AB=E4=BB=BD?= =?UTF-8?q?=E5=8E=8B=E7=BC=A9=E4=BB=A5=E6=8F=90=E9=AB=98=E5=9B=9E=E5=A4=8D?= =?UTF-8?q?=E6=80=A7=E8=83=BD=20refactor:=20=E7=A7=BB=E9=99=A4=E6=97=A5?= =?UTF-8?q?=E5=BF=97=E4=BF=A1=E6=81=AF=E4=BB=A5=E7=AE=80=E5=8C=96=E5=9B=9E?= =?UTF-8?q?=E5=A4=8D=E5=92=8C=E5=93=8D=E5=BA=94=E5=8A=A8=E4=BD=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/built_in/affinity_flow_chatter/actions/reply.py | 4 +--- .../tools/chat_stream_impression_tool.py | 3 +-- template/bot_config_template.toml | 4 ++-- 3 files changed, 4 insertions(+), 7 deletions(-) diff --git a/src/plugins/built_in/affinity_flow_chatter/actions/reply.py b/src/plugins/built_in/affinity_flow_chatter/actions/reply.py index 74311f501..3ab0fe909 100644 --- a/src/plugins/built_in/affinity_flow_chatter/actions/reply.py +++ b/src/plugins/built_in/affinity_flow_chatter/actions/reply.py @@ -96,7 +96,6 @@ class ReplyAction(BaseAction): # 发送回复 reply_text = await self._send_response(response_set) - logger.info(f"{self.log_prefix} reply 动作执行成功") return True, reply_text except asyncio.CancelledError: @@ -218,8 +217,7 @@ class RespondAction(BaseAction): # 发送回复(respond 默认不引用) reply_text = await self._send_response(response_set) - - logger.info(f"{self.log_prefix} respond 动作执行成功") + return True, reply_text except asyncio.CancelledError: diff --git a/src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py b/src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py index 936f160b8..65544a73b 100644 --- a/src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py +++ b/src/plugins/built_in/affinity_flow_chatter/tools/chat_stream_impression_tool.py @@ -126,7 +126,6 @@ class ChatStreamImpressionTool(BaseTool): updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}") result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates) - logger.info(f"聊天流印象更新成功: {stream_id}") return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text} @@ -214,7 +213,7 @@ class ChatStreamImpressionTool(BaseTool): await cache.delete(generate_cache_key("stream_impression", stream_id)) await cache.delete(generate_cache_key("chat_stream", stream_id)) - logger.info(f"聊天流印象已更新到数据库: {stream_id}") + logger.debug(f"聊天流印象已更新到数据库: {stream_id}") else: error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象" logger.error(error_msg) diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index ca9346f65..eec2fbd60 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -131,7 +131,7 @@ safety_guidelines = [ ] compress_personality = false # 是否压缩人格,压缩后会精简人格信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果人设不长,可以关闭 -compress_identity = true # 是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭 +compress_identity = false # 是否压缩身份,压缩后会精简身份信息,节省token消耗并提高回复性能,但是会丢失一些信息,如果不长,可以关闭 [expression] # 表达学习配置 @@ -283,7 +283,7 @@ path_expansion_path_score_weight = 0.50 # 路径分数在最终评分中的权 path_expansion_importance_weight = 0.30 # 重要性在最终评分中的权重 path_expansion_recency_weight = 0.20 # 时效性在最终评分中的权重 -# 🆕 路径扩展 - 记忆去重配置 +# 路径扩展 - 记忆去重配置 enable_memory_deduplication = true # 启用检索结果去重(合并相似记忆) memory_deduplication_threshold = 0.85 # 记忆相似度阈值(0.85表示85%相似即合并) From 06a45b363922b5790e9e61f4b152d5bb1d671940 Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 4 Dec 2025 23:30:43 +0800 Subject: [PATCH 6/6] =?UTF-8?q?refactor:=20=E7=A7=BB=E9=99=A4=E5=AF=B9=20M?= =?UTF-8?q?ySQL=20=E7=9A=84=E6=94=AF=E6=8C=81=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E6=96=87=E6=A1=A3=E5=92=8C=E9=85=8D=E7=BD=AE=E4=BB=A5=E4=BB=85?= =?UTF-8?q?=E6=94=AF=E6=8C=81=20SQLite=20=E5=92=8C=20PostgreSQL?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .github/copilot-instructions.md | 4 +- README.md | 6 +- bot.py | 2 +- pyproject.toml | 5 +- requirements.txt | 2 - scripts/migrate_database.py | 109 ++++++------------ .../message_manager/batch_database_writer.py | 14 --- src/chat/message_receive/chat_stream.py | 6 - src/common/database/core/__init__.py | 1 - src/common/database/core/dialect_adapter.py | 32 +---- src/common/database/core/engine.py | 54 +-------- src/common/database/core/migration.py | 3 - src/common/database/core/models.py | 9 +- src/common/database/core/session.py | 5 +- .../database/optimization/connection_pool.py | 4 +- src/config/official_configs.py | 19 +-- src/plugin_system/utils/dependency_alias.py | 1 - template/bot_config_template.toml | 25 +--- 18 files changed, 54 insertions(+), 247 deletions(-) diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index 819189b09..50d156157 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -157,7 +157,7 @@ python __main__.py # 备用入口 **调试技巧**: - 检查 `logs/app_*.jsonl` 结构化日志 - 使用 `get_errors()` 工具查看编译错误 -- 数据库问题:查看 `data/MaiBot.db`(SQLite)或 MySQL 连接 +- 数据库问题:查看 `data/MaiBot.db`(SQLite)或 PostgreSQL 连接 ## 📋 关键约定与模式 @@ -165,7 +165,7 @@ python __main__.py # 备用入口 **全局配置**: `src/config/config.py` 的 `global_config` 单例 - 通过 TOML 文件驱动(`config/bot_config.toml`) - 支持环境变量覆盖(`.env`) -- 数据库类型切换:`database.database_type = "sqlite" | "mysql"` +- 数据库类型切换:`database.database_type = "sqlite" | "postgresql"` ### 事件系统 **Event Manager** (`src/plugin_system/core/event_manager.py`): diff --git a/README.md b/README.md index a4f532895..f4ae5e26c 100644 --- a/README.md +++ b/README.md @@ -75,7 +75,7 @@ ### 🚀 拓展功能 - 🧠 **AFC 智能对话** - 基于亲和力流,实现兴趣感知和动态关系构建 -- 🔄 **数据库切换** - 支持 SQLite 与 MySQL 自由切换,采用 SQLAlchemy 2.0 重新构建 +- 🔄 **数据库切换** - 支持 SQLite 与 PostgreSQL 自由切换,采用 SQLAlchemy 2.0 重新构建 - 🛡️ **反注入集成** - 内置一整套回复前注入过滤系统,为人格保驾护航 - 🎥 **视频分析** - 支持多种视频识别模式,拓展原版视觉 - 📅 **日程系统** - 让MoFox规划每一天 @@ -109,7 +109,7 @@ | 服务 | 描述 | | ------------ | ------------------------------------------ | | 🤖 QQ 协议端 | [NapCatQQ](https://github.com/NapNeko/NapCatQQ) 或其他兼容协议端 | -| 🗃️ 数据库 | SQLite(默认)或 MySQL(可选) | +| 🗃️ 数据库 | SQLite(默认)或 PostgreSQL(可选) | | 🔧 管理工具 | Chat2DB(可选,用于数据库可视化管理) | --- @@ -133,7 +133,7 @@ 1. 📝 **核心配置**:编辑 `config/bot_config.toml`,设置 LLM API Key、Bot 名称等基础参数。 2. 🤖 **协议端配置**:确保使用 [NapCatQQ](https://github.com/NapNeko/NapCatQQ) 或兼容协议端,建立稳定通信。 -3. 🗃️ **数据库配置**:选择 SQLite(默认)或配置 MySQL 数据库连接。 +3. 🗃️ **数据库配置**:选择 SQLite(默认)或配置 PostgreSQL 数据库连接。 4. 🔌 **插件配置**:在 `config/plugins/` 目录中启用或配置所需插件。 diff --git a/bot.py b/bot.py index b3c653096..f752d81ab 100644 --- a/bot.py +++ b/bot.py @@ -21,7 +21,7 @@ logger = get_logger("main") install(extra_lines=3) # 常量定义 -SUPPORTED_DATABASES = ["sqlite", "mysql", "postgresql"] +SUPPORTED_DATABASES = ["sqlite", "postgresql"] SHUTDOWN_TIMEOUT = 10.0 EULA_CHECK_INTERVAL = 2 MAX_EULA_CHECK_ATTEMPTS = 30 diff --git a/pyproject.toml b/pyproject.toml index 6f8dc92a7..be298b743 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,6 @@ dependencies = [ "pydantic>=2.12.3", "pygments>=2.19.2", "pymongo>=4.13.2", - "pymysql>=1.1.1", "pypinyin>=0.54.0", "PyYAML>=6.0", "python-dateutil>=2.9.0.post0", @@ -74,15 +73,13 @@ dependencies = [ "uvicorn>=0.35.0", "watchdog>=6.0.0", "websockets>=15.0.1", - "aiomysql>=0.2.0", "aiosqlite>=0.21.0", "inkfox>=0.1.1", "rjieba>=0.1.13", "fastmcp>=2.13.0", "mofox-wire", "jinja2>=3.1.0", - "psycopg2-binary", - "PyMySQL" + "psycopg2-binary" ] [[tool.uv.index]] diff --git a/requirements.txt b/requirements.txt index e91f9f380..cb640d6a6 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,8 @@ aiosqlite aiofiles -aiomysql asyncpg psycopg[binary] psycopg2-binary -PyMySQL APScheduler aiohttp aiohttp-cors diff --git a/scripts/migrate_database.py b/scripts/migrate_database.py index b5e5c30c6..9f4bf71d3 100644 --- a/scripts/migrate_database.py +++ b/scripts/migrate_database.py @@ -2,14 +2,12 @@ """数据库迁移脚本 支持在不同数据库之间迁移数据: -- SQLite <-> MySQL - SQLite <-> PostgreSQL -- MySQL <-> PostgreSQL 使用方法: python scripts/migrate_database.py --help python scripts/migrate_database.py --source sqlite --target postgresql - python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000 + python scripts/migrate_database.py --source postgresql --target sqlite --batch-size 5000 # 交互式向导模式(推荐) python scripts/migrate_database.py @@ -25,7 +23,7 @@ 实现细节: - 使用 SQLAlchemy 进行数据库连接和元数据管理 - 采用流式迁移,避免一次性加载过多数据 -- 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移 +- 支持 SQLite、PostgreSQL 之间的互相迁移 - 批量插入失败时自动降级为逐行插入,最大程度保留数据 """ @@ -124,7 +122,7 @@ def get_database_config_from_toml(db_type: str) -> dict | None: """从 bot_config.toml 中读取数据库配置 Args: - db_type: 数据库类型,支持 "sqlite"、"mysql"、"postgresql" + db_type: 数据库类型,支持 "sqlite"、"postgresql" Returns: dict: 数据库配置字典,如果对应配置不存在则返回 None @@ -148,28 +146,6 @@ def get_database_config_from_toml(db_type: str) -> dict | None: sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path) return {"path": sqlite_path} - elif db_type == "mysql": - return { - "host": db_config.get("mysql_host") - or config_data.get("mysql_host") - or "localhost", - "port": db_config.get("mysql_port") - or config_data.get("mysql_port") - or 3306, - "database": db_config.get("mysql_database") - or config_data.get("mysql_database") - or "maibot", - "user": db_config.get("mysql_user") - or config_data.get("mysql_user") - or "root", - "password": db_config.get("mysql_password") - or config_data.get("mysql_password") - or "", - "charset": db_config.get("mysql_charset") - or config_data.get("mysql_charset") - or "utf8mb4", - } - elif db_type == "postgresql": return { "host": db_config.get("postgresql_host") @@ -257,7 +233,7 @@ def create_engine_by_type(db_type: str, config: dict) -> Engine: """根据数据库类型创建对应的 SQLAlchemy Engine Args: - db_type: 数据库类型,支持 sqlite/mysql/postgresql + db_type: 数据库类型,支持 sqlite/postgresql config: 配置字典 Returns: @@ -266,15 +242,6 @@ def create_engine_by_type(db_type: str, config: dict) -> Engine: db_type = db_type.lower() if db_type == "sqlite": return create_sqlite_engine(config["path"]) - elif db_type == "mysql": - return create_mysql_engine( - host=config["host"], - port=config["port"], - database=config["database"], - user=config["user"], - password=config["password"], - charset=config.get("charset", "utf8mb4"), - ) elif db_type == "postgresql": return create_postgresql_engine( host=config["host"], @@ -512,7 +479,7 @@ def migrate_table_data( source_table: 源表对象 target_table: 目标表对象 batch_size: 每批次处理大小 - target_dialect: 目标数据库方言 (sqlite/mysql/postgresql) + target_dialect: 目标数据库方言 (sqlite/postgresql) row_limit: 最大迁移行数限制,None 表示不限制 Returns: @@ -738,7 +705,7 @@ class DatabaseMigrator: def _validate_database_types(self): """验证数据库类型""" - supported_types = {"sqlite", "mysql", "postgresql"} + supported_types = {"sqlite", "postgresql"} if self.source_type not in supported_types: raise ValueError(f"不支持的源数据库类型: {self.source_type}") if self.target_type not in supported_types: @@ -995,7 +962,7 @@ class DatabaseMigrator: def parse_args(): """解析命令行参数""" parser = argparse.ArgumentParser( - description="数据库迁移工具 - 在 SQLite、MySQL、PostgreSQL 之间迁移数据", + description="数据库迁移工具 - 在 SQLite、PostgreSQL 之间迁移数据", formatter_class=argparse.RawDescriptionHelpFormatter, epilog="""示例: # 从 SQLite 迁移到 PostgreSQL @@ -1008,15 +975,16 @@ def parse_args(): --target-user postgres \ --target-password your_password - # 从 SQLite 迁移到 MySQL + # 从 PostgreSQL 迁移到 SQLite python scripts/migrate_database.py \ - --source sqlite \ - --target mysql \ - --target-host localhost \ - --target-port 3306 \ - --target-database maibot \ - --target-user root \ - --target-password your_password + --source postgresql \ + --source-host localhost \ + --source-port 5432 \ + --source-database maibot \ + --source-user postgres \ + --source-password your_password \ + --target sqlite \ + --target-path data/MaiBot_backup.db # 使用交互式向导模式(推荐) python scripts/migrate_database.py @@ -1028,13 +996,13 @@ def parse_args(): parser.add_argument( "--source", type=str, - choices=["sqlite", "mysql", "postgresql"], + choices=["sqlite", "postgresql"], help="源数据库类型(不指定时,在交互模式中选择)", ) parser.add_argument( "--target", type=str, - choices=["sqlite", "mysql", "postgresql"], + choices=["sqlite", "postgresql"], help="目标数据库类型(不指定时,在交互模式中选择)", ) parser.add_argument( @@ -1053,8 +1021,8 @@ def parse_args(): # 源数据库参数(可选,默认从 bot_config.toml 读取) source_group = parser.add_argument_group("源数据库配置(可选,默认从 bot_config.toml 读取)") source_group.add_argument("--source-path", type=str, help="SQLite 数据库路径") - source_group.add_argument("--source-host", type=str, help="MySQL/PostgreSQL 主机") - source_group.add_argument("--source-port", type=int, help="MySQL/PostgreSQL 端口") + source_group.add_argument("--source-host", type=str, help="PostgreSQL 主机") + source_group.add_argument("--source-port", type=int, help="PostgreSQL 端口") source_group.add_argument("--source-database", type=str, help="数据库名") source_group.add_argument("--source-user", type=str, help="用户名") source_group.add_argument("--source-password", type=str, help="密码") @@ -1062,13 +1030,12 @@ def parse_args(): # 目标数据库参数 target_group = parser.add_argument_group("目标数据库配置") target_group.add_argument("--target-path", type=str, help="SQLite 数据库路径") - target_group.add_argument("--target-host", type=str, help="MySQL/PostgreSQL 主机") - target_group.add_argument("--target-port", type=int, help="MySQL/PostgreSQL 端口") + target_group.add_argument("--target-host", type=str, help="PostgreSQL 主机") + target_group.add_argument("--target-port", type=int, help="PostgreSQL 端口") target_group.add_argument("--target-database", type=str, help="数据库名") target_group.add_argument("--target-user", type=str, help="用户名") target_group.add_argument("--target-password", type=str, help="密码") target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema") - target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集") # 跳过表参数 parser.add_argument( @@ -1113,24 +1080,20 @@ def build_config_from_args(args, prefix: str, db_type: str) -> dict | None: return {"path": path} return None - elif db_type in ("mysql", "postgresql"): + elif db_type == "postgresql": host = getattr(args, f"{prefix}_host", None) if not host: return None config = { "host": host, - "port": getattr(args, f"{prefix}_port") or (3306 if db_type == "mysql" else 5432), + "port": getattr(args, f"{prefix}_port") or 5432, "database": getattr(args, f"{prefix}_database") or "maibot", - "user": getattr(args, f"{prefix}_user") or ("root" if db_type == "mysql" else "postgres"), + "user": getattr(args, f"{prefix}_user") or "postgres", "password": getattr(args, f"{prefix}_password") or "", + "schema": getattr(args, f"{prefix}_schema", "public"), } - if db_type == "mysql": - config["charset"] = getattr(args, f"{prefix}_charset", "utf8mb4") - elif db_type == "postgresql": - config["schema"] = getattr(args, f"{prefix}_schema", "public") - return config return None @@ -1201,14 +1164,14 @@ def interactive_setup() -> dict: print("只需回答几个问题,我会帮你构造迁移配置。") print("=" * 60) - db_types = ["sqlite", "mysql", "postgresql"] + db_types = ["sqlite", "postgresql"] # 选择源数据库 source_type = _ask_choice("请选择【源数据库类型】:", db_types, default_index=0) # 选择目标数据库(不能与源相同) while True: - default_idx = 2 if len(db_types) >= 3 else 0 + default_idx = 1 if len(db_types) >= 2 else 0 target_type = _ask_choice("请选择【目标数据库类型】:", db_types, default_index=default_idx) if target_type != source_type: break @@ -1231,8 +1194,8 @@ def interactive_setup() -> dict: source_path = _ask_str("源 SQLite 文件路径", default="data/MaiBot.db") source_config = {"path": source_path} else: - port_default = 3306 if source_type == "mysql" else 5432 - user_default = "root" if source_type == "mysql" else "postgres" + port_default = 5432 + user_default = "postgres" host = _ask_str("源数据库 host", default="localhost") port = _ask_int("源数据库 port", default=port_default) database = _ask_str("源数据库名", default="maibot") @@ -1245,9 +1208,7 @@ def interactive_setup() -> dict: "user": user, "password": password, } - if source_type == "mysql": - source_config["charset"] = _ask_str("源数据库字符集", default="utf8mb4") - elif source_type == "postgresql": + if source_type == "postgresql": source_config["schema"] = _ask_str("源数据库 schema", default="public") # 目标数据库配置(必须显式确认) @@ -1260,8 +1221,8 @@ def interactive_setup() -> dict: ) target_config = {"path": target_path} else: - port_default = 3306 if target_type == "mysql" else 5432 - user_default = "root" if target_type == "mysql" else "postgres" + port_default = 5432 + user_default = "postgres" host = _ask_str("目标数据库 host", default="localhost") port = _ask_int("目标数据库 port", default=port_default) database = _ask_str("目标数据库名", default="maibot") @@ -1275,9 +1236,7 @@ def interactive_setup() -> dict: "user": user, "password": password, } - if target_type == "mysql": - target_config["charset"] = _ask_str("目标数据库字符集", default="utf8mb4") - elif target_type == "postgresql": + if target_type == "postgresql": target_config["schema"] = _ask_str("目标数据库 schema", default="public") print() diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index f5b9e1c18..74128d8d9 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -234,13 +234,6 @@ class BatchDatabaseWriter: stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - elif global_config.database.database_type == "mysql": - from sqlalchemy.dialects.mysql import insert as mysql_insert - - stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_duplicate_key_update( - **{key: value for key, value in update_data.items() if key != "stream_id"} - ) elif global_config.database.database_type == "postgresql": from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -268,13 +261,6 @@ class BatchDatabaseWriter: stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data) - elif global_config.database.database_type == "mysql": - from sqlalchemy.dialects.mysql import insert as mysql_insert - - stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data) - stmt = stmt.on_duplicate_key_update( - **{key: value for key, value in update_data.items() if key != "stream_id"} - ) elif global_config.database.database_type == "postgresql": from sqlalchemy.dialects.postgresql import insert as pg_insert diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 0d647687b..b3d61a331 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -3,7 +3,6 @@ import hashlib import time from rich.traceback import install -from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert @@ -665,11 +664,6 @@ class ChatManager: if global_config.database.database_type == "sqlite": stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) - elif global_config.database.database_type == "mysql": - stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) - stmt = stmt.on_duplicate_key_update( - **{key: value for key, value in fields_to_save.items() if key != "stream_id"} - ) elif global_config.database.database_type == "postgresql": stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) # PostgreSQL 需要使用 constraint 参数或正确的 index_elements diff --git a/src/common/database/core/__init__.py b/src/common/database/core/__init__.py index 2e457f37a..09f7f5c56 100644 --- a/src/common/database/core/__init__.py +++ b/src/common/database/core/__init__.py @@ -9,7 +9,6 @@ 支持的数据库: - SQLite (默认) -- MySQL - PostgreSQL """ diff --git a/src/common/database/core/dialect_adapter.py b/src/common/database/core/dialect_adapter.py index e99eb47ae..09156f230 100644 --- a/src/common/database/core/dialect_adapter.py +++ b/src/common/database/core/dialect_adapter.py @@ -2,7 +2,6 @@ 提供跨数据库兼容性支持,处理不同数据库之间的差异: - SQLite: 轻量级本地数据库 -- MySQL: 高性能关系型数据库 - PostgreSQL: 功能丰富的开源数据库 主要职责: @@ -23,7 +22,6 @@ class DatabaseDialect(Enum): """数据库方言枚举""" SQLITE = "sqlite" - MYSQL = "mysql" POSTGRESQL = "postgresql" @@ -68,20 +66,6 @@ DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = { } }, ), - DatabaseDialect.MYSQL: DialectConfig( - dialect=DatabaseDialect.MYSQL, - ping_query="SELECT 1", - supports_returning=False, # MySQL 8.0.21+ 有限支持 - supports_native_json=True, # MySQL 5.7+ - supports_arrays=False, - requires_length_for_index=True, # MySQL 索引需要指定长度 - default_string_length=255, - isolation_level="READ COMMITTED", - engine_kwargs={ - "pool_pre_ping": True, - "pool_recycle": 3600, - }, - ), DatabaseDialect.POSTGRESQL: DialectConfig( dialect=DatabaseDialect.POSTGRESQL, ping_query="SELECT 1", @@ -113,13 +97,13 @@ class DialectAdapter: """初始化适配器 Args: - db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql") + db_type: 数据库类型字符串 ("sqlite", "postgresql") """ try: cls._current_dialect = DatabaseDialect(db_type.lower()) cls._config = DIALECT_CONFIGS[cls._current_dialect] except ValueError: - raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql") + raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, postgresql") @classmethod def get_dialect(cls) -> DatabaseDialect: @@ -153,15 +137,10 @@ class DialectAdapter: """ config = cls.get_config() - # MySQL 索引列需要指定长度 - if config.requires_length_for_index and indexed: - return String(max_length) - # SQLite 和 PostgreSQL 可以使用 Text if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL): return Text() if not indexed else String(max_length) - # MySQL 使用 VARCHAR return String(max_length) @classmethod @@ -189,11 +168,6 @@ class DialectAdapter: """是否为 SQLite""" return cls.get_dialect() == DatabaseDialect.SQLITE - @classmethod - def is_mysql(cls) -> bool: - """是否为 MySQL""" - return cls.get_dialect() == DatabaseDialect.MYSQL - @classmethod def is_postgresql(cls) -> bool: """是否为 PostgreSQL""" @@ -211,7 +185,7 @@ def get_indexed_string_field(max_length: int = 255) -> TypeEngine: 这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型 Args: - max_length: 最大长度(对于 MySQL 是必需的) + max_length: 最大长度 Returns: SQLAlchemy 类型 diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py index 064178595..d28f307cb 100644 --- a/src/common/database/core/engine.py +++ b/src/common/database/core/engine.py @@ -4,7 +4,6 @@ 支持的数据库类型: - SQLite: 轻量级本地数据库,使用 aiosqlite 驱动 -- MySQL: 高性能关系型数据库,使用 aiomysql 驱动 - PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动 """ @@ -66,9 +65,7 @@ async def get_engine() -> AsyncEngine: logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...") # 根据数据库类型构建URL和引擎参数 - if db_type == "mysql": - url, engine_kwargs = _build_mysql_config(config) - elif db_type == "postgresql": + if db_type == "postgresql": url, engine_kwargs = _build_postgresql_config(config) else: url, engine_kwargs = _build_sqlite_config(config) @@ -123,55 +120,6 @@ def _build_sqlite_config(config) -> tuple[str, dict]: return url, engine_kwargs -def _build_mysql_config(config) -> tuple[str, dict]: - """构建 MySQL 配置 - - Args: - config: 数据库配置对象 - - Returns: - (url, engine_kwargs) 元组 - """ - encoded_user = quote_plus(config.mysql_user) - encoded_password = quote_plus(config.mysql_password) - - if config.mysql_unix_socket: - # Unix socket连接 - encoded_socket = quote_plus(config.mysql_unix_socket) - url = ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@/{config.mysql_database}" - f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" - ) - else: - # TCP连接 - url = ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - f"?charset={config.mysql_charset}" - ) - - engine_kwargs = { - "echo": False, - "future": True, - "pool_size": config.connection_pool_size, - "max_overflow": config.connection_pool_size * 2, - "pool_timeout": config.connection_timeout, - "pool_recycle": 3600, - "pool_pre_ping": True, - "connect_args": { - "autocommit": config.mysql_autocommit, - "charset": config.mysql_charset, - "connect_timeout": config.connection_timeout, - }, - } - - logger.info( - f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - ) - return url, engine_kwargs - - def _build_postgresql_config(config) -> tuple[str, dict]: """构建 PostgreSQL 配置 diff --git a/src/common/database/core/migration.py b/src/common/database/core/migration.py index c1408c791..fff355fae 100644 --- a/src/common/database/core/migration.py +++ b/src/common/database/core/migration.py @@ -119,9 +119,6 @@ async def check_and_migrate_database(existing_engine=None): ): # SQLite 将布尔值存储为 0 或 1 default_value = "1" if default_arg else "0" - elif dialect.name == "mysql" and isinstance(default_arg, bool): - # MySQL 也使用 1/0 表示布尔值 - default_value = "1" if default_arg else "0" elif isinstance(default_arg, bool): # PostgreSQL 使用 TRUE/FALSE default_value = "TRUE" if default_arg else "FALSE" diff --git a/src/common/database/core/models.py b/src/common/database/core/models.py index 175173639..6125b4b02 100644 --- a/src/common/database/core/models.py +++ b/src/common/database/core/models.py @@ -5,7 +5,6 @@ 支持的数据库类型: - SQLite: 使用 Text 类型 -- MySQL: 使用 VARCHAR(max_length) 用于索引字段 - PostgreSQL: 使用 Text 类型(PostgreSQL 的 Text 类型性能与 VARCHAR 相当) 所有模型使用统一的类型注解风格: @@ -31,12 +30,11 @@ def get_string_field(max_length=255, **kwargs): 根据数据库类型返回合适的字符串字段类型 对于需要索引的字段: - - MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度 - PostgreSQL: 可以使用 Text,但为了兼容性使用 VARCHAR - SQLite: 可以使用 Text,无长度限制 Args: - max_length: 最大长度(对于 MySQL 是必需的) + max_length: 最大长度 **kwargs: 传递给 String/Text 的额外参数 Returns: @@ -47,11 +45,8 @@ def get_string_field(max_length=255, **kwargs): assert global_config is not None db_type = global_config.database.database_type - # MySQL 索引需要指定长度的 VARCHAR - if db_type == "mysql": - return String(max_length, **kwargs) # PostgreSQL 可以使用 Text,但为了跨数据库迁移兼容性,使用 VARCHAR - elif db_type == "postgresql": + if db_type == "postgresql": return String(max_length, **kwargs) # SQLite 使用 Text(无长度限制) else: diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py index 751bb51bf..e3df1a03e 100644 --- a/src/common/database/core/session.py +++ b/src/common/database/core/session.py @@ -4,7 +4,6 @@ 支持的数据库类型: - SQLite: 设置 PRAGMA 参数优化并发 -- MySQL: 无特殊会话设置 - PostgreSQL: 可选设置 schema 搜索路径 """ @@ -79,7 +78,6 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None: schema = global_config.database.postgresql_schema if schema and schema != "public": await session.execute(text(f"SET search_path TO {schema}")) - # MySQL 通常不需要会话级别的特殊设置 except Exception: # 复用连接时设置可能已存在,忽略错误 pass @@ -93,7 +91,6 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: 支持的数据库: - SQLite: 自动设置 busy_timeout 和外键约束 - - MySQL: 直接使用,无特殊设置 - PostgreSQL: 支持自定义 schema 使用示例: @@ -132,7 +129,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]: - 正常退出时自动提交事务 - 发生异常时自动回滚事务 - 如果用户代码已手动调用 commit/rollback,再次调用是安全的 - - 适用于所有数据库类型(SQLite, MySQL, PostgreSQL) + - 适用于所有数据库类型(SQLite, PostgreSQL) Yields: AsyncSession: SQLAlchemy异步会话对象 diff --git a/src/common/database/optimization/connection_pool.py b/src/common/database/optimization/connection_pool.py index ed7d3e5ef..a00335af3 100644 --- a/src/common/database/optimization/connection_pool.py +++ b/src/common/database/optimization/connection_pool.py @@ -128,7 +128,7 @@ class ConnectionPoolManager: - 正常退出时自动提交事务 - 发生异常时自动回滚事务 - 如果用户代码已手动调用 commit/rollback,再次调用是安全的(空操作) - - 支持所有数据库类型:SQLite、MySQL、PostgreSQL + - 支持所有数据库类型:SQLite、PostgreSQL """ connection_info = None @@ -158,7 +158,7 @@ class ConnectionPoolManager: yield connection_info.session # 🔧 正常退出时提交事务 - # 这对所有数据库(SQLite、MySQL、PostgreSQL)都很重要 + # 这对所有数据库(SQLite、PostgreSQL)都很重要 # 因为 SQLAlchemy 默认使用事务模式,不会自动提交 # 注意:如果用户代码已调用 commit(),这里的 commit() 是安全的空操作 if connection_info and connection_info.session: diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 072356db7..162c9c39a 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -16,26 +16,9 @@ from src.config.config_base import ValidatedConfigBase class DatabaseConfig(ValidatedConfigBase): """数据库配置类""" - database_type: Literal["sqlite", "mysql", "postgresql"] = Field(default="sqlite", description="数据库类型") + database_type: Literal["sqlite", "postgresql"] = Field(default="sqlite", description="数据库类型") sqlite_path: str = Field(default="data/MaiBot.db", description="SQLite数据库文件路径") - # MySQL 配置 - mysql_host: str = Field(default="localhost", description="MySQL服务器地址") - mysql_port: int = Field(default=3306, ge=1, le=65535, description="MySQL服务器端口") - mysql_database: str = Field(default="maibot", description="MySQL数据库名") - mysql_user: str = Field(default="root", description="MySQL用户名") - mysql_password: str = Field(default="", description="MySQL密码") - mysql_charset: str = Field(default="utf8mb4", description="MySQL字符集") - mysql_unix_socket: str = Field(default="", description="MySQL Unix套接字路径") - mysql_ssl_mode: Literal["DISABLED", "PREFERRED", "REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"] = Field( - default="DISABLED", description="SSL模式" - ) - mysql_ssl_ca: str = Field(default="", description="SSL CA证书路径") - mysql_ssl_cert: str = Field(default="", description="SSL客户端证书路径") - mysql_ssl_key: str = Field(default="", description="SSL密钥路径") - mysql_autocommit: bool = Field(default=True, description="自动提交事务") - mysql_sql_mode: str = Field(default="TRADITIONAL", description="SQL模式") - # PostgreSQL 配置 postgresql_host: str = Field(default="localhost", description="PostgreSQL服务器地址") postgresql_port: int = Field(default=5432, ge=1, le=65535, description="PostgreSQL服务器端口") diff --git a/src/plugin_system/utils/dependency_alias.py b/src/plugin_system/utils/dependency_alias.py index a7e478d76..7319f274e 100644 --- a/src/plugin_system/utils/dependency_alias.py +++ b/src/plugin_system/utils/dependency_alias.py @@ -61,7 +61,6 @@ INSTALL_NAME_TO_IMPORT_NAME = { "passlib": "passlib", # 密码哈希库 "bcrypt": "bcrypt", # Bcrypt密码哈希 # ============== 数据库 (Database) ============== - "mysql-connector-python": "mysql.connector", # MySQL官方驱动 "psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制) "pymongo": "pymongo", # MongoDB驱动 "redis": "redis", # Redis客户端 diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index eec2fbd60..3593bfe9a 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.9.5" +version = "7.9.6" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -12,30 +12,11 @@ version = "7.9.5" #----以上是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- [database]# 数据库配置 -database_type = "sqlite" # 数据库类型,支持 "sqlite"、"mysql" 或 "postgresql" +database_type = "sqlite" # 数据库类型,支持 "sqlite" 或 "postgresql" # SQLite 配置(当 database_type = "sqlite" 时使用) sqlite_path = "data/MaiBot.db" # SQLite数据库文件路径 -# MySQL 配置(当 database_type = "mysql" 时使用) -mysql_host = "localhost" # MySQL服务器地址 -mysql_port = 3306 # MySQL服务器端口 -mysql_database = "maibot" # MySQL数据库名 -mysql_user = "root" # MySQL用户名 -mysql_password = "" # MySQL密码 -mysql_charset = "utf8mb4" # MySQL字符集 -mysql_unix_socket = "" # MySQL Unix套接字路径(可选,用于本地连接,优先于host/port) - -# MySQL SSL 配置 -mysql_ssl_mode = "DISABLED" # SSL模式: DISABLED, PREFERRED, REQUIRED, VERIFY_CA, VERIFY_IDENTITY -mysql_ssl_ca = "" # SSL CA证书路径 -mysql_ssl_cert = "" # SSL客户端证书路径 -mysql_ssl_key = "" # SSL客户端密钥路径 - -# MySQL 高级配置 -mysql_autocommit = true # 自动提交事务 -mysql_sql_mode = "TRADITIONAL" # SQL模式 - # PostgreSQL 配置(当 database_type = "postgresql" 时使用) postgresql_host = "localhost" # PostgreSQL服务器地址 postgresql_port = 5432 # PostgreSQL服务器端口 @@ -50,7 +31,7 @@ postgresql_ssl_ca = "" # SSL CA证书路径 postgresql_ssl_cert = "" # SSL客户端证书路径 postgresql_ssl_key = "" # SSL客户端密钥路径 -# 连接池配置(MySQL 和 PostgreSQL 有效) +# 连接池配置(PostgreSQL 有效) connection_pool_size = 10 # 连接池大小 connection_timeout = 10 # 连接超时时间(秒)