增强聊天管理器和数据库API,添加自动注册和异步清理功能,优化模型转换为字典的逻辑

This commit is contained in:
Windpicker-owo
2025-11-20 16:48:18 +08:00
parent 03c80a08fb
commit 6ecf5a36f2
5 changed files with 215 additions and 74 deletions

View File

@@ -18,6 +18,7 @@ class ChatterManager:
self.action_manager = action_manager self.action_manager = action_manager
self.chatter_classes: dict[ChatType, list[type]] = {} self.chatter_classes: dict[ChatType, list[type]] = {}
self.instances: dict[str, BaseChatter] = {} self.instances: dict[str, BaseChatter] = {}
self._auto_registered = False
# 管理器统计 # 管理器统计
self.stats = { self.stats = {
@@ -40,6 +41,12 @@ class ChatterManager:
except Exception as e: except Exception as e:
logger.warning(f"自动注册chatter组件时发生错误: {e}") logger.warning(f"自动注册chatter组件时发生错误: {e}")
def _ensure_chatter_registry(self):
"""确保聊天处理器注册表已初始化"""
if not self.chatter_classes and not self._auto_registered:
self._auto_register_from_component_registry()
self._auto_registered = True
def register_chatter(self, chatter_class: type): def register_chatter(self, chatter_class: type):
"""注册聊天处理器类""" """注册聊天处理器类"""
for chat_type in chatter_class.chat_types: for chat_type in chatter_class.chat_types:
@@ -84,73 +91,97 @@ class ChatterManager:
del self.instances[stream_id] del self.instances[stream_id]
logger.info(f"清理不活跃聊天流实例: {stream_id}") logger.info(f"清理不活跃聊天流实例: {stream_id}")
def _schedule_unread_cleanup(self, stream_id: str):
"""异步清理未读消息计数"""
try:
from src.chat.message_manager.message_manager import message_manager
except Exception as import_error:
logger.error("加载 message_manager 失败", stream_id=stream_id, error=import_error)
return
async def _clear_unread():
try:
await message_manager.clear_stream_unread_messages(stream_id)
logger.debug("清理未读消息完成", stream_id=stream_id)
except Exception as clear_error:
logger.error("清理未读消息失败", stream_id=stream_id, error=clear_error)
try:
asyncio.create_task(_clear_unread(), name=f"clear-unread-{stream_id}")
except RuntimeError as runtime_error:
logger.error("schedule unread cleanup failed", stream_id=stream_id, error=runtime_error)
async def process_stream_context(self, stream_id: str, context: "StreamContext") -> dict: async def process_stream_context(self, stream_id: str, context: "StreamContext") -> dict:
"""处理流上下文""" """处理流上下文"""
chat_type = context.chat_type chat_type = context.chat_type
logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}") chat_type_value = chat_type.value
if not self.chatter_classes: logger.debug("处理流上下文", stream_id=stream_id, chat_type=chat_type_value)
self._auto_register_from_component_registry()
self._ensure_chatter_registry()
# 获取适合该聊天类型的chatter
chatter_class = self.get_chatter_class(chat_type) chatter_class = self.get_chatter_class(chat_type)
if not chatter_class: if not chatter_class:
# 如果没有找到精确匹配尝试查找支持ALL类型的chatter
from src.plugin_system.base.component_types import ChatType from src.plugin_system.base.component_types import ChatType
all_chatter_class = self.get_chatter_class(ChatType.ALL) all_chatter_class = self.get_chatter_class(ChatType.ALL)
if all_chatter_class: if all_chatter_class:
chatter_class = all_chatter_class chatter_class = all_chatter_class
logger.info(f"{stream_id} 使用通用chatter (类型: {chat_type.value})") logger.info(
"回退到通用聊天处理器",
stream_id=stream_id,
requested_type=chat_type_value,
fallback=ChatType.ALL.value,
)
else: else:
raise ValueError(f"No chatter registered for chat type {chat_type}") raise ValueError(f"No chatter registered for chat type {chat_type}")
if stream_id not in self.instances: stream_instance = self.instances.get(stream_id)
self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager) if stream_instance is None:
logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})") stream_instance = chatter_class(stream_id=stream_id, action_manager=self.action_manager)
self.instances[stream_id] = stream_instance
logger.info(
"创建聊天处理器实例",
stream_id=stream_id,
chatter_class=chatter_class.__name__,
chat_type=chat_type_value,
)
self.stats["streams_processed"] += 1 self.stats["streams_processed"] += 1
try: try:
result = await self.instances[stream_id].execute(context) result = await stream_instance.execute(context)
# 检查执行结果是否真正成功
success = result.get("success", False) success = result.get("success", False)
if success: if success:
self.stats["successful_executions"] += 1 self.stats["successful_executions"] += 1
self._schedule_unread_cleanup(stream_id)
# 只有真正成功时才清空未读消息
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.clear_stream_unread_messages(stream_id)
logger.debug(f"{stream_id} 处理成功,已清空未读消息")
except Exception as clear_e:
logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
else: else:
self.stats["failed_executions"] += 1 self.stats["failed_executions"] += 1
logger.warning(f"{stream_id} 处理失败,不清空未读消息") logger.warning("聊天处理器执行失败", stream_id=stream_id)
# 记录处理结果
actions_count = result.get("actions_count", 0) actions_count = result.get("actions_count", 0)
logger.debug(f"{stream_id} 处理完成: 成功={success}, 动作数={actions_count}") logger.debug(
"聊天处理器执行完成",
stream_id=stream_id,
success=success,
actions_count=actions_count,
)
return result return result
except asyncio.CancelledError: except asyncio.CancelledError:
self.stats["failed_executions"] += 1 self.stats["failed_executions"] += 1
logger.info(f" {stream_id} 处理被取消") logger.info("流处理被取消", stream_id=stream_id)
context.triggering_user_id = None # 清除触发用户ID context.triggering_user_id = None
# 确保清理 processing_message_id 以防止重复回复检测失效
context.processing_message_id = None context.processing_message_id = None
raise raise
except Exception as e: except Exception as e: # noqa: BLE001
self.stats["failed_executions"] += 1 self.stats["failed_executions"] += 1
logger.error(f"处理流 {stream_id} 时发生错误: {e}") logger.error("处理流时出错", stream_id=stream_id, error=e)
context.triggering_user_id = None # 清除触发用户ID context.triggering_user_id = None
# 确保清理 processing_message_id
context.processing_message_id = None context.processing_message_id = None
raise raise
finally: finally:
# 清除触发用户ID所有情况下都需要
context.triggering_user_id = None context.triggering_user_id = None
def get_stats(self) -> dict[str, Any]: def get_stats(self) -> dict[str, Any]:
"""获取管理器统计信息""" """获取管理器统计信息"""
stats = self.stats.copy() stats = self.stats.copy()

View File

@@ -6,6 +6,9 @@
- 智能预加载:关联数据自动预加载 - 智能预加载:关联数据自动预加载
""" """
import operator
from collections.abc import Callable
from functools import lru_cache
from typing import Any, TypeVar from typing import Any, TypeVar
from sqlalchemy import delete, func, select, update from sqlalchemy import delete, func, select, update
@@ -25,6 +28,43 @@ logger = get_logger("database.crud")
T = TypeVar("T", bound=Base) T = TypeVar("T", bound=Base)
@lru_cache(maxsize=256)
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
"""获取模型的列名称列表"""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
"""获取模型的有效字段集合"""
return frozenset(_get_model_column_names(model))
@lru_cache(maxsize=256)
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
"""为模型准备attrgetter用于批量获取属性值"""
column_names = _get_model_column_names(model)
if not column_names:
return lambda _: ()
if len(column_names) == 1:
attr_name = column_names[0]
def _single(instance: Base) -> tuple[Any, ...]:
return (getattr(instance, attr_name),)
return _single
getter = operator.attrgetter(*column_names)
def _multi(instance: Base) -> tuple[Any, ...]:
values = getter(instance)
return values if isinstance(values, tuple) else (values,)
return _multi
def _model_to_dict(instance: Base) -> dict[str, Any]: def _model_to_dict(instance: Base) -> dict[str, Any]:
"""将 SQLAlchemy 模型实例转换为字典 """将 SQLAlchemy 模型实例转换为字典
@@ -32,16 +72,27 @@ def _model_to_dict(instance: Base) -> dict[str, Any]:
instance: SQLAlchemy 模型实例 instance: SQLAlchemy 模型实例
Returns: Returns:
字典表示,包含所有列的 字典表示的模型实例的字段
""" """
result = {} if instance is None:
for column in instance.__table__.columns: return {}
try:
result[column.name] = getattr(instance, column.name) model = type(instance)
except Exception as e: column_names = _get_model_column_names(model)
logger.warning(f"无法访问字段 {column.name}: {e}") fetch_values = _get_model_value_fetcher(model)
result[column.name] = None
return result try:
values = fetch_values(instance)
return dict(zip(column_names, values))
except Exception as exc:
logger.warning(f"无法转换模型 {model.__name__}: {exc}")
fallback = {}
for column in column_names:
try:
fallback[column] = getattr(instance, column)
except Exception:
fallback[column] = None
return fallback
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T: def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
@@ -55,8 +106,9 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
模型实例 (detached, 所有字段已加载) 模型实例 (detached, 所有字段已加载)
""" """
instance = model_class() instance = model_class()
valid_fields = _get_model_field_set(model_class)
for key, value in data.items(): for key, value in data.items():
if hasattr(instance, key): if key in valid_fields:
setattr(instance, key, value) setattr(instance, key, value)
return instance return instance

View File

@@ -183,11 +183,14 @@ class QueryBuilder(Generic[T]):
self._use_cache = False self._use_cache = False
return self return self
async def all(self) -> list[T]: async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
"""获取所有结果 """获取所有结果
Args:
as_dict: 为True时返回字典格式
Returns: Returns:
模型实例列表 模型实例列表或字典列表
""" """
cache_key = ":".join(self._cache_key_parts) + ":all" cache_key = ":".join(self._cache_key_parts) + ":all"
@@ -197,27 +200,33 @@ class QueryBuilder(Generic[T]):
cached_dicts = await cache.get(cache_key) cached_dicts = await cache.get(cache_key)
if cached_dicts is not None: if cached_dicts is not None:
logger.debug(f"缓存命中: {cache_key}") logger.debug(f"缓存命中: {cache_key}")
# 从字典列表恢复对象列表 dict_rows = [dict(row) for row in cached_dicts]
return [_dict_to_model(self.model, d) for d in cached_dicts] if as_dict:
return dict_rows
return [_dict_to_model(self.model, row) for row in dict_rows]
# 从数据库查询 # 从数据库查询
async with get_db_session() as session: async with get_db_session() as session:
result = await session.execute(self._stmt) result = await session.execute(self._stmt)
instances = list(result.scalars().all()) instances = list(result.scalars().all())
# 在 session 内部转换为字典列表,此时所有字段都可安全访问 # 在 session 内部转换为字典列表,此时所有字段都可安全访问
instances_dicts = [_model_to_dict(inst) for inst in instances] instances_dicts = [_model_to_dict(inst) for inst in instances]
# 写入缓存
if self._use_cache: if self._use_cache:
cache = await get_cache() cache = await get_cache()
await cache.set(cache_key, instances_dicts) cache_payload = [dict(row) for row in instances_dicts]
await cache.set(cache_key, cache_payload)
# 从字典列表重建对象列表返回detached状态所有字段已加载 if as_dict:
return [_dict_to_model(self.model, d) for d in instances_dicts] return instances_dicts
return [_dict_to_model(self.model, row) for row in instances_dicts]
async def first(self) -> T | None: async def first(self, *, as_dict: bool = False) -> T | dict[str, Any] | None:
"""获取第一结果 """获取第一结果
Args:
as_dict: 为True时返回字典格式
Returns: Returns:
模型实例或None 模型实例或None
@@ -230,8 +239,10 @@ class QueryBuilder(Generic[T]):
cached_dict = await cache.get(cache_key) cached_dict = await cache.get(cache_key)
if cached_dict is not None: if cached_dict is not None:
logger.debug(f"缓存命中: {cache_key}") logger.debug(f"缓存命中: {cache_key}")
# 从字典恢复对象 row = dict(cached_dict)
return _dict_to_model(self.model, cached_dict) if as_dict:
return row
return _dict_to_model(self.model, row)
# 从数据库查询 # 从数据库查询
async with get_db_session() as session: async with get_db_session() as session:
@@ -239,15 +250,16 @@ class QueryBuilder(Generic[T]):
instance = result.scalars().first() instance = result.scalars().first()
if instance is not None: if instance is not None:
# 在 session 内部转换为字典,此时所有字段都可安全访问 # 在 session 内部转换为字典,此时所有字段都可安全访问
instance_dict = _model_to_dict(instance) instance_dict = _model_to_dict(instance)
# 写入缓存 # 写入缓存
if self._use_cache: if self._use_cache:
cache = await get_cache() cache = await get_cache()
await cache.set(cache_key, instance_dict) await cache.set(cache_key, dict(instance_dict))
# 从字典重建对象返回detached状态所有字段已加载 if as_dict:
return instance_dict
return _dict_to_model(self.model, instance_dict) return _dict_to_model(self.model, instance_dict)
return None return None

View File

@@ -13,6 +13,7 @@ from src.common.database.api import (
from src.common.database.api import ( from src.common.database.api import (
store_action_info as new_store_action_info, store_action_info as new_store_action_info,
) )
from src.common.database.api.crud import _model_to_dict as _crud_model_to_dict
from src.common.database.core.models import ( from src.common.database.core.models import (
ActionRecords, ActionRecords,
AntiInjectionStats, AntiInjectionStats,
@@ -123,21 +124,19 @@ async def build_filters(model_class, filters: dict[str, Any]):
def _model_to_dict(instance) -> dict[str, Any]: def _model_to_dict(instance) -> dict[str, Any]:
"""将模型实例转换为字典 """数据库模型实例转换为字典兼容旧API
Args: Args:
instance: 模型实例 instance: 数据库模型实例
Returns: Returns:
字典表示 字典表示
""" """
if instance is None: if instance is None:
return None return None
return _crud_model_to_dict(instance)
result = {}
for column in instance.__table__.columns:
result[column.name] = getattr(instance, column.name)
return result
async def db_query( async def db_query(
@@ -211,11 +210,9 @@ async def db_query(
# 执行查询 # 执行查询
if single_result: if single_result:
result = await query_builder.first() return await query_builder.first(as_dict=True)
return _model_to_dict(result)
else: return await query_builder.all(as_dict=True)
results = await query_builder.all()
return [_model_to_dict(r) for r in results]
elif query_type == "create": elif query_type == "create":
if not data: if not data:

View File

@@ -1,13 +1,15 @@
# 使用基于时间戳的文件处理器,简单的轮转份数限制 # 使用基于时间戳的文件处理器,简单的轮转份数限制
import logging import logging
from logging.handlers import QueueHandler, QueueListener
import tarfile import tarfile
import threading import threading
import time import time
from collections.abc import Callable from collections.abc import Callable, Sequence
from datetime import datetime, timedelta from datetime import datetime, timedelta
from pathlib import Path from pathlib import Path
from queue import SimpleQueue
import orjson import orjson
import structlog import structlog
import tomlkit import tomlkit
@@ -27,6 +29,11 @@ _console_handler: logging.Handler | None = None
_LOGGER_META_LOCK = threading.Lock() _LOGGER_META_LOCK = threading.Lock()
_LOGGER_META: dict[str, dict[str, str | None]] = {} _LOGGER_META: dict[str, dict[str, str | None]] = {}
# 日志格式化器
_log_queue: SimpleQueue[logging.LogRecord] | None = None
_queue_handler: QueueHandler | None = None
_queue_listener: QueueListener | None = None
def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None): def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None):
"""注册/更新 logger 元数据。 """注册/更新 logger 元数据。
@@ -90,6 +97,44 @@ def get_console_handler():
return _console_handler return _console_handler
def _start_queue_logging(handlers: Sequence[logging.Handler]) -> QueueHandler | None:
"""为日志处理器启动异步队列;无处理器时返回 None"""
global _log_queue, _queue_handler, _queue_listener
if _queue_listener is not None:
_queue_listener.stop()
_queue_listener = None
if not handlers:
return None
_log_queue = SimpleQueue()
_queue_handler = StructlogQueueHandler(_log_queue)
_queue_listener = QueueListener(_log_queue, *handlers, respect_handler_level=True)
_queue_listener.start()
return _queue_handler
def _stop_queue_logging():
"""停止异步日志队列"""
global _log_queue, _queue_handler, _queue_listener
if _queue_listener is not None:
_queue_listener.stop()
_queue_listener = None
_log_queue = None
_queue_handler = None
class StructlogQueueHandler(QueueHandler):
"""Queue handler that keeps structlog event dicts intact."""
def prepare(self, record):
# Keep the original LogRecord so processor formatters can access the event dict.
return record
class TimestampedFileHandler(logging.Handler): class TimestampedFileHandler(logging.Handler):
"""基于时间戳的文件处理器,带简单大小轮转 + 旧文件压缩/保留策略。 """基于时间戳的文件处理器,带简单大小轮转 + 旧文件压缩/保留策略。
@@ -221,6 +266,8 @@ def close_handlers():
"""安全关闭所有handler""" """安全关闭所有handler"""
global _file_handler, _console_handler global _file_handler, _console_handler
_stop_queue_logging()
if _file_handler: if _file_handler:
_file_handler.close() _file_handler.close()
_file_handler = None _file_handler = None
@@ -1037,15 +1084,17 @@ def _immediate_setup():
# 使用单例handler避免重复创建 # 使用单例handler避免重复创建
file_handler_local = get_file_handler() file_handler_local = get_file_handler()
console_handler_local = get_console_handler() console_handler_local = get_console_handler()
active_handlers = [h for h in (file_handler_local, console_handler_local) if h is not None]
for h in (file_handler_local, console_handler_local):
if h is not None:
root_logger.addHandler(h)
# 设置格式化器 # 设置格式化器
if file_handler_local is not None: if file_handler_local is not None:
file_handler_local.setFormatter(file_formatter) file_handler_local.setFormatter(file_formatter)
console_handler_local.setFormatter(console_formatter) if console_handler_local is not None:
console_handler_local.setFormatter(console_formatter)
queue_handler = _start_queue_logging(active_handlers)
if queue_handler is not None:
root_logger.addHandler(queue_handler)
# 清理重复的handler # 清理重复的handler
remove_duplicate_handlers() remove_duplicate_handlers()