From 6ecf5a36f2e480c9f3d20bb50c11f56b6e8ca3ec Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Thu, 20 Nov 2025 16:48:18 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=BC=BA=E8=81=8A=E5=A4=A9=E7=AE=A1?= =?UTF-8?q?=E7=90=86=E5=99=A8=E5=92=8C=E6=95=B0=E6=8D=AE=E5=BA=93API?= =?UTF-8?q?=EF=BC=8C=E6=B7=BB=E5=8A=A0=E8=87=AA=E5=8A=A8=E6=B3=A8=E5=86=8C?= =?UTF-8?q?=E5=92=8C=E5=BC=82=E6=AD=A5=E6=B8=85=E7=90=86=E5=8A=9F=E8=83=BD?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E6=A8=A1=E5=9E=8B=E8=BD=AC=E6=8D=A2?= =?UTF-8?q?=E4=B8=BA=E5=AD=97=E5=85=B8=E7=9A=84=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/chat/chatter_manager.py | 93 +++++++++++++------- src/common/database/api/crud.py | 72 ++++++++++++--- src/common/database/api/query.py | 44 +++++---- src/common/database/compatibility/adapter.py | 19 ++-- src/common/logger.py | 61 +++++++++++-- 5 files changed, 215 insertions(+), 74 deletions(-) diff --git a/src/chat/chatter_manager.py b/src/chat/chatter_manager.py index 36f2dd2e9..a4405358b 100644 --- a/src/chat/chatter_manager.py +++ b/src/chat/chatter_manager.py @@ -18,6 +18,7 @@ class ChatterManager: self.action_manager = action_manager self.chatter_classes: dict[ChatType, list[type]] = {} self.instances: dict[str, BaseChatter] = {} + self._auto_registered = False # 管理器统计 self.stats = { @@ -40,6 +41,12 @@ class ChatterManager: except Exception as e: logger.warning(f"自动注册chatter组件时发生错误: {e}") + def _ensure_chatter_registry(self): + """确保聊天处理器注册表已初始化""" + if not self.chatter_classes and not self._auto_registered: + self._auto_register_from_component_registry() + self._auto_registered = True + def register_chatter(self, chatter_class: type): """注册聊天处理器类""" for chat_type in chatter_class.chat_types: @@ -84,73 +91,97 @@ class ChatterManager: del self.instances[stream_id] logger.info(f"清理不活跃聊天流实例: {stream_id}") + def _schedule_unread_cleanup(self, stream_id: str): + """异步清理未读消息计数""" + try: + from src.chat.message_manager.message_manager import message_manager + except Exception as import_error: + logger.error("加载 message_manager 失败", stream_id=stream_id, error=import_error) + return + + async def _clear_unread(): + try: + await message_manager.clear_stream_unread_messages(stream_id) + logger.debug("清理未读消息完成", stream_id=stream_id) + except Exception as clear_error: + logger.error("清理未读消息失败", stream_id=stream_id, error=clear_error) + + try: + asyncio.create_task(_clear_unread(), name=f"clear-unread-{stream_id}") + except RuntimeError as runtime_error: + logger.error("schedule unread cleanup failed", stream_id=stream_id, error=runtime_error) + async def process_stream_context(self, stream_id: str, context: "StreamContext") -> dict: """处理流上下文""" chat_type = context.chat_type - logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}") - if not self.chatter_classes: - self._auto_register_from_component_registry() + chat_type_value = chat_type.value + logger.debug("处理流上下文", stream_id=stream_id, chat_type=chat_type_value) + + self._ensure_chatter_registry() - # 获取适合该聊天类型的chatter chatter_class = self.get_chatter_class(chat_type) if not chatter_class: - # 如果没有找到精确匹配,尝试查找支持ALL类型的chatter from src.plugin_system.base.component_types import ChatType all_chatter_class = self.get_chatter_class(ChatType.ALL) if all_chatter_class: chatter_class = all_chatter_class - logger.info(f"流 {stream_id} 使用通用chatter (类型: {chat_type.value})") + logger.info( + "回退到通用聊天处理器", + stream_id=stream_id, + requested_type=chat_type_value, + fallback=ChatType.ALL.value, + ) else: raise ValueError(f"No chatter registered for chat type {chat_type}") - if stream_id not in self.instances: - self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager) - logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})") + stream_instance = self.instances.get(stream_id) + if stream_instance is None: + stream_instance = chatter_class(stream_id=stream_id, action_manager=self.action_manager) + self.instances[stream_id] = stream_instance + logger.info( + "创建聊天处理器实例", + stream_id=stream_id, + chatter_class=chatter_class.__name__, + chat_type=chat_type_value, + ) self.stats["streams_processed"] += 1 try: - result = await self.instances[stream_id].execute(context) - - # 检查执行结果是否真正成功 + result = await stream_instance.execute(context) success = result.get("success", False) if success: self.stats["successful_executions"] += 1 - - # 只有真正成功时才清空未读消息 - try: - from src.chat.message_manager.message_manager import message_manager - await message_manager.clear_stream_unread_messages(stream_id) - logger.debug(f"流 {stream_id} 处理成功,已清空未读消息") - except Exception as clear_e: - logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}") + self._schedule_unread_cleanup(stream_id) else: self.stats["failed_executions"] += 1 - logger.warning(f"流 {stream_id} 处理失败,不清空未读消息") + logger.warning("聊天处理器执行失败", stream_id=stream_id) - # 记录处理结果 actions_count = result.get("actions_count", 0) - logger.debug(f"流 {stream_id} 处理完成: 成功={success}, 动作数={actions_count}") + logger.debug( + "聊天处理器执行完成", + stream_id=stream_id, + success=success, + actions_count=actions_count, + ) return result except asyncio.CancelledError: self.stats["failed_executions"] += 1 - logger.info(f"流 {stream_id} 处理被取消") - context.triggering_user_id = None # 清除触发用户ID - # 确保清理 processing_message_id 以防止重复回复检测失效 + logger.info("流处理被取消", stream_id=stream_id) + context.triggering_user_id = None context.processing_message_id = None raise - except Exception as e: + except Exception as e: # noqa: BLE001 self.stats["failed_executions"] += 1 - logger.error(f"处理流 {stream_id} 时发生错误: {e}") - context.triggering_user_id = None # 清除触发用户ID - # 确保清理 processing_message_id + logger.error("处理流时出错", stream_id=stream_id, error=e) + context.triggering_user_id = None context.processing_message_id = None raise finally: - # 清除触发用户ID(所有情况下都需要) context.triggering_user_id = None + def get_stats(self) -> dict[str, Any]: """获取管理器统计信息""" stats = self.stats.copy() diff --git a/src/common/database/api/crud.py b/src/common/database/api/crud.py index 69af46562..1c9b1aef9 100644 --- a/src/common/database/api/crud.py +++ b/src/common/database/api/crud.py @@ -6,6 +6,9 @@ - 智能预加载:关联数据自动预加载 """ +import operator +from collections.abc import Callable +from functools import lru_cache from typing import Any, TypeVar from sqlalchemy import delete, func, select, update @@ -25,6 +28,43 @@ logger = get_logger("database.crud") T = TypeVar("T", bound=Base) +@lru_cache(maxsize=256) +def _get_model_column_names(model: type[Base]) -> tuple[str, ...]: + """获取模型的列名称列表""" + return tuple(column.name for column in model.__table__.columns) + + +@lru_cache(maxsize=256) +def _get_model_field_set(model: type[Base]) -> frozenset[str]: + """获取模型的有效字段集合""" + return frozenset(_get_model_column_names(model)) + + +@lru_cache(maxsize=256) +def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]: + """为模型准备attrgetter,用于批量获取属性值""" + column_names = _get_model_column_names(model) + + if not column_names: + return lambda _: () + + if len(column_names) == 1: + attr_name = column_names[0] + + def _single(instance: Base) -> tuple[Any, ...]: + return (getattr(instance, attr_name),) + + return _single + + getter = operator.attrgetter(*column_names) + + def _multi(instance: Base) -> tuple[Any, ...]: + values = getter(instance) + return values if isinstance(values, tuple) else (values,) + + return _multi + + def _model_to_dict(instance: Base) -> dict[str, Any]: """将 SQLAlchemy 模型实例转换为字典 @@ -32,16 +72,27 @@ def _model_to_dict(instance: Base) -> dict[str, Any]: instance: SQLAlchemy 模型实例 Returns: - 字典表示,包含所有列的值 + 字典表示的模型实例的字段值 """ - result = {} - for column in instance.__table__.columns: - try: - result[column.name] = getattr(instance, column.name) - except Exception as e: - logger.warning(f"无法访问字段 {column.name}: {e}") - result[column.name] = None - return result + if instance is None: + return {} + + model = type(instance) + column_names = _get_model_column_names(model) + fetch_values = _get_model_value_fetcher(model) + + try: + values = fetch_values(instance) + return dict(zip(column_names, values)) + except Exception as exc: + logger.warning(f"无法转换模型 {model.__name__}: {exc}") + fallback = {} + for column in column_names: + try: + fallback[column] = getattr(instance, column) + except Exception: + fallback[column] = None + return fallback def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T: @@ -55,8 +106,9 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T: 模型实例 (detached, 所有字段已加载) """ instance = model_class() + valid_fields = _get_model_field_set(model_class) for key, value in data.items(): - if hasattr(instance, key): + if key in valid_fields: setattr(instance, key, value) return instance diff --git a/src/common/database/api/query.py b/src/common/database/api/query.py index 8d7bab1b1..6815820ef 100644 --- a/src/common/database/api/query.py +++ b/src/common/database/api/query.py @@ -183,11 +183,14 @@ class QueryBuilder(Generic[T]): self._use_cache = False return self - async def all(self) -> list[T]: + async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]: """获取所有结果 + Args: + as_dict: 为True时返回字典格式 + Returns: - 模型实例列表 + 模型实例列表或字典列表 """ cache_key = ":".join(self._cache_key_parts) + ":all" @@ -197,27 +200,33 @@ class QueryBuilder(Generic[T]): cached_dicts = await cache.get(cache_key) if cached_dicts is not None: logger.debug(f"缓存命中: {cache_key}") - # 从字典列表恢复对象列表 - return [_dict_to_model(self.model, d) for d in cached_dicts] + dict_rows = [dict(row) for row in cached_dicts] + if as_dict: + return dict_rows + return [_dict_to_model(self.model, row) for row in dict_rows] # 从数据库查询 async with get_db_session() as session: result = await session.execute(self._stmt) instances = list(result.scalars().all()) - # ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问 + # 在 session 内部转换为字典列表,此时所有字段都可安全访问 instances_dicts = [_model_to_dict(inst) for inst in instances] - # 写入缓存 if self._use_cache: cache = await get_cache() - await cache.set(cache_key, instances_dicts) + cache_payload = [dict(row) for row in instances_dicts] + await cache.set(cache_key, cache_payload) - # 从字典列表重建对象列表返回(detached状态,所有字段已加载) - return [_dict_to_model(self.model, d) for d in instances_dicts] + if as_dict: + return instances_dicts + return [_dict_to_model(self.model, row) for row in instances_dicts] - async def first(self) -> T | None: - """获取第一个结果 + async def first(self, *, as_dict: bool = False) -> T | dict[str, Any] | None: + """获取第一条结果 + + Args: + as_dict: 为True时返回字典格式 Returns: 模型实例或None @@ -230,8 +239,10 @@ class QueryBuilder(Generic[T]): cached_dict = await cache.get(cache_key) if cached_dict is not None: logger.debug(f"缓存命中: {cache_key}") - # 从字典恢复对象 - return _dict_to_model(self.model, cached_dict) + row = dict(cached_dict) + if as_dict: + return row + return _dict_to_model(self.model, row) # 从数据库查询 async with get_db_session() as session: @@ -239,15 +250,16 @@ class QueryBuilder(Generic[T]): instance = result.scalars().first() if instance is not None: - # ✅ 在 session 内部转换为字典,此时所有字段都可安全访问 + # 在 session 内部转换为字典,此时所有字段都可安全访问 instance_dict = _model_to_dict(instance) # 写入缓存 if self._use_cache: cache = await get_cache() - await cache.set(cache_key, instance_dict) + await cache.set(cache_key, dict(instance_dict)) - # 从字典重建对象返回(detached状态,所有字段已加载) + if as_dict: + return instance_dict return _dict_to_model(self.model, instance_dict) return None diff --git a/src/common/database/compatibility/adapter.py b/src/common/database/compatibility/adapter.py index a4bd8f51a..c102704d0 100644 --- a/src/common/database/compatibility/adapter.py +++ b/src/common/database/compatibility/adapter.py @@ -13,6 +13,7 @@ from src.common.database.api import ( from src.common.database.api import ( store_action_info as new_store_action_info, ) +from src.common.database.api.crud import _model_to_dict as _crud_model_to_dict from src.common.database.core.models import ( ActionRecords, AntiInjectionStats, @@ -123,21 +124,19 @@ async def build_filters(model_class, filters: dict[str, Any]): def _model_to_dict(instance) -> dict[str, Any]: - """将模型实例转换为字典 + """将数据库模型实例转换为字典(兼容旧API Args: - instance: 模型实例 + instance: 数据库模型实例 Returns: 字典表示 """ if instance is None: return None + return _crud_model_to_dict(instance) + - result = {} - for column in instance.__table__.columns: - result[column.name] = getattr(instance, column.name) - return result async def db_query( @@ -211,11 +210,9 @@ async def db_query( # 执行查询 if single_result: - result = await query_builder.first() - return _model_to_dict(result) - else: - results = await query_builder.all() - return [_model_to_dict(r) for r in results] + return await query_builder.first(as_dict=True) + + return await query_builder.all(as_dict=True) elif query_type == "create": if not data: diff --git a/src/common/logger.py b/src/common/logger.py index 3eff08044..0e4c50fa3 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -1,13 +1,15 @@ # 使用基于时间戳的文件处理器,简单的轮转份数限制 import logging +from logging.handlers import QueueHandler, QueueListener import tarfile import threading import time -from collections.abc import Callable +from collections.abc import Callable, Sequence from datetime import datetime, timedelta from pathlib import Path +from queue import SimpleQueue import orjson import structlog import tomlkit @@ -27,6 +29,11 @@ _console_handler: logging.Handler | None = None _LOGGER_META_LOCK = threading.Lock() _LOGGER_META: dict[str, dict[str, str | None]] = {} +# 日志格式化器 +_log_queue: SimpleQueue[logging.LogRecord] | None = None +_queue_handler: QueueHandler | None = None +_queue_listener: QueueListener | None = None + def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None): """注册/更新 logger 元数据。 @@ -90,6 +97,44 @@ def get_console_handler(): return _console_handler +def _start_queue_logging(handlers: Sequence[logging.Handler]) -> QueueHandler | None: + """为日志处理器启动异步队列;无处理器时返回 None""" + global _log_queue, _queue_handler, _queue_listener + + if _queue_listener is not None: + _queue_listener.stop() + _queue_listener = None + + if not handlers: + return None + + _log_queue = SimpleQueue() + _queue_handler = StructlogQueueHandler(_log_queue) + _queue_listener = QueueListener(_log_queue, *handlers, respect_handler_level=True) + _queue_listener.start() + return _queue_handler + + +def _stop_queue_logging(): + """停止异步日志队列""" + global _log_queue, _queue_handler, _queue_listener + + if _queue_listener is not None: + _queue_listener.stop() + _queue_listener = None + + _log_queue = None + _queue_handler = None + + +class StructlogQueueHandler(QueueHandler): + """Queue handler that keeps structlog event dicts intact.""" + + def prepare(self, record): + # Keep the original LogRecord so processor formatters can access the event dict. + return record + + class TimestampedFileHandler(logging.Handler): """基于时间戳的文件处理器,带简单大小轮转 + 旧文件压缩/保留策略。 @@ -221,6 +266,8 @@ def close_handlers(): """安全关闭所有handler""" global _file_handler, _console_handler + _stop_queue_logging() + if _file_handler: _file_handler.close() _file_handler = None @@ -1037,15 +1084,17 @@ def _immediate_setup(): # 使用单例handler避免重复创建 file_handler_local = get_file_handler() console_handler_local = get_console_handler() - - for h in (file_handler_local, console_handler_local): - if h is not None: - root_logger.addHandler(h) + active_handlers = [h for h in (file_handler_local, console_handler_local) if h is not None] # 设置格式化器 if file_handler_local is not None: file_handler_local.setFormatter(file_formatter) - console_handler_local.setFormatter(console_formatter) + if console_handler_local is not None: + console_handler_local.setFormatter(console_formatter) + + queue_handler = _start_queue_logging(active_handlers) + if queue_handler is not None: + root_logger.addHandler(queue_handler) # 清理重复的handler remove_duplicate_handlers()