Enhance the chat manager and database API: add auto-registration and asynchronous cleanup, and optimize model-to-dict conversion logic

Windpicker-owo
2025-11-20 16:48:18 +08:00
parent 03c80a08fb
commit 6ecf5a36f2
5 changed files with 215 additions and 74 deletions

View File

@@ -18,6 +18,7 @@ class ChatterManager:
self.action_manager = action_manager
self.chatter_classes: dict[ChatType, list[type]] = {}
self.instances: dict[str, BaseChatter] = {}
self._auto_registered = False
# Manager statistics
self.stats = {
@@ -40,6 +41,12 @@ class ChatterManager:
except Exception as e:
logger.warning(f"自动注册chatter组件时发生错误: {e}")
def _ensure_chatter_registry(self):
"""确保聊天处理器注册表已初始化"""
if not self.chatter_classes and not self._auto_registered:
self._auto_register_from_component_registry()
self._auto_registered = True
def register_chatter(self, chatter_class: type):
"""注册聊天处理器类"""
for chat_type in chatter_class.chat_types:
@@ -84,73 +91,97 @@ class ChatterManager:
del self.instances[stream_id]
logger.info(f"清理不活跃聊天流实例: {stream_id}")
def _schedule_unread_cleanup(self, stream_id: str):
"""异步清理未读消息计数"""
try:
from src.chat.message_manager.message_manager import message_manager
except Exception as import_error:
logger.error("加载 message_manager 失败", stream_id=stream_id, error=import_error)
return
async def _clear_unread():
try:
await message_manager.clear_stream_unread_messages(stream_id)
logger.debug("清理未读消息完成", stream_id=stream_id)
except Exception as clear_error:
logger.error("清理未读消息失败", stream_id=stream_id, error=clear_error)
try:
asyncio.create_task(_clear_unread(), name=f"clear-unread-{stream_id}")
except RuntimeError as runtime_error:
logger.error("schedule unread cleanup failed", stream_id=stream_id, error=runtime_error)
async def process_stream_context(self, stream_id: str, context: "StreamContext") -> dict:
"""处理流上下文"""
chat_type = context.chat_type
logger.debug(f"处理流 {stream_id},聊天类型: {chat_type.value}")
if not self.chatter_classes:
self._auto_register_from_component_registry()
chat_type_value = chat_type.value
logger.debug("处理流上下文", stream_id=stream_id, chat_type=chat_type_value)
self._ensure_chatter_registry()
# Get a chatter suited to this chat type
chatter_class = self.get_chatter_class(chat_type)
if not chatter_class:
# If there is no exact match, fall back to a chatter that supports the ALL type
from src.plugin_system.base.component_types import ChatType
all_chatter_class = self.get_chatter_class(ChatType.ALL)
if all_chatter_class:
chatter_class = all_chatter_class
logger.info(f"{stream_id} 使用通用chatter (类型: {chat_type.value})")
logger.info(
"回退到通用聊天处理器",
stream_id=stream_id,
requested_type=chat_type_value,
fallback=ChatType.ALL.value,
)
else:
raise ValueError(f"No chatter registered for chat type {chat_type}")
if stream_id not in self.instances:
self.instances[stream_id] = chatter_class(stream_id=stream_id, action_manager=self.action_manager)
logger.info(f"创建新的聊天流实例: {stream_id} 使用 {chatter_class.__name__} (类型: {chat_type.value})")
stream_instance = self.instances.get(stream_id)
if stream_instance is None:
stream_instance = chatter_class(stream_id=stream_id, action_manager=self.action_manager)
self.instances[stream_id] = stream_instance
logger.info(
"创建聊天处理器实例",
stream_id=stream_id,
chatter_class=chatter_class.__name__,
chat_type=chat_type_value,
)
self.stats["streams_processed"] += 1
try:
result = await self.instances[stream_id].execute(context)
# Check whether the execution actually succeeded
result = await stream_instance.execute(context)
success = result.get("success", False)
if success:
self.stats["successful_executions"] += 1
# Clear unread messages only on genuine success
try:
from src.chat.message_manager.message_manager import message_manager
await message_manager.clear_stream_unread_messages(stream_id)
logger.debug(f"{stream_id} 处理成功,已清空未读消息")
except Exception as clear_e:
logger.error(f"清除流 {stream_id} 未读消息时发生错误: {clear_e}")
self._schedule_unread_cleanup(stream_id)
else:
self.stats["failed_executions"] += 1
logger.warning(f"{stream_id} 处理失败,不清空未读消息")
logger.warning("聊天处理器执行失败", stream_id=stream_id)
# Record the processing result
actions_count = result.get("actions_count", 0)
logger.debug(f"{stream_id} 处理完成: 成功={success}, 动作数={actions_count}")
logger.debug(
"聊天处理器执行完成",
stream_id=stream_id,
success=success,
actions_count=actions_count,
)
return result
except asyncio.CancelledError:
self.stats["failed_executions"] += 1
logger.info(f" {stream_id} 处理被取消")
context.triggering_user_id = None # 清除触发用户ID
# 确保清理 processing_message_id 以防止重复回复检测失效
logger.info("流处理被取消", stream_id=stream_id)
context.triggering_user_id = None
context.processing_message_id = None
raise
except Exception as e:
except Exception as e: # noqa: BLE001
self.stats["failed_executions"] += 1
logger.error(f"处理流 {stream_id} 时发生错误: {e}")
context.triggering_user_id = None # 清除触发用户ID
# 确保清理 processing_message_id
logger.error("处理流时出错", stream_id=stream_id, error=e)
context.triggering_user_id = None
context.processing_message_id = None
raise
finally:
# Clear the triggering user ID (required in every case)
context.triggering_user_id = None
def get_stats(self) -> dict[str, Any]:
"""获取管理器统计信息"""
stats = self.stats.copy()
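The two behavioural changes in this file are that registration now happens lazily behind a one-shot guard, and lookups fall back to ChatType.ALL when no exact handler is registered. A condensed, runnable sketch of that lookup logic (all names here are illustrative stand-ins, not the real component registry):

```python
from enum import Enum

class ChatType(Enum):
    PRIVATE = "private"
    GROUP = "group"
    ALL = "all"

class ChatterRegistry:
    def __init__(self) -> None:
        self._classes: dict[ChatType, list[type]] = {}
        self._auto_registered = False

    def _ensure_registry(self) -> None:
        # One-shot guard: auto-registration runs at most once, even if it
        # registers nothing, so a broken component registry cannot trigger
        # a rescan on every single message.
        if not self._classes and not self._auto_registered:
            self._auto_register()
            self._auto_registered = True

    def _auto_register(self) -> None:
        # Stand-in for scanning the component registry.
        self._classes.setdefault(ChatType.ALL, []).append(object)

    def get_chatter_class(self, chat_type: ChatType) -> type | None:
        self._ensure_registry()
        # Exact match first, then the generic ALL bucket as a fallback.
        for key in (chat_type, ChatType.ALL):
            if self._classes.get(key):
                return self._classes[key][0]
        return None

registry = ChatterRegistry()
print(registry.get_chatter_class(ChatType.GROUP))  # falls back to the ALL handler
```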

View File

@@ -6,6 +6,9 @@
- Smart preloading: related data is preloaded automatically
"""
import operator
from collections.abc import Callable
from functools import lru_cache
from typing import Any, TypeVar
from sqlalchemy import delete, func, select, update
@@ -25,6 +28,43 @@ logger = get_logger("database.crud")
T = TypeVar("T", bound=Base)
@lru_cache(maxsize=256)
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
"""获取模型的列名称列表"""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
"""获取模型的有效字段集合"""
return frozenset(_get_model_column_names(model))
@lru_cache(maxsize=256)
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
"""为模型准备attrgetter用于批量获取属性值"""
column_names = _get_model_column_names(model)
if not column_names:
return lambda _: ()
if len(column_names) == 1:
attr_name = column_names[0]
def _single(instance: Base) -> tuple[Any, ...]:
return (getattr(instance, attr_name),)
return _single
getter = operator.attrgetter(*column_names)
def _multi(instance: Base) -> tuple[Any, ...]:
values = getter(instance)
return values if isinstance(values, tuple) else (values,)
return _multi
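The single-column special case exists because operator.attrgetter returns a bare value for one name but a tuple for several. A quick standalone demonstration:

```python
import operator

class Row:
    id = 1
    name = "alice"

single = operator.attrgetter("id")
multi = operator.attrgetter("id", "name")

print(single(Row()))  # 1 -- a bare value, not a 1-tuple
print(multi(Row()))   # (1, 'alice') -- already a tuple
```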
def _model_to_dict(instance: Base) -> dict[str, Any]:
"""将 SQLAlchemy 模型实例转换为字典
@@ -32,16 +72,27 @@ def _model_to_dict(instance: Base) -> dict[str, Any]:
instance: SQLAlchemy model instance
Returns:
Dictionary representation containing all column values
Dictionary of the model instance's fields
"""
result = {}
for column in instance.__table__.columns:
if instance is None:
return {}
model = type(instance)
column_names = _get_model_column_names(model)
fetch_values = _get_model_value_fetcher(model)
try:
result[column.name] = getattr(instance, column.name)
except Exception as e:
logger.warning(f"无法访问字段 {column.name}: {e}")
result[column.name] = None
return result
values = fetch_values(instance)
return dict(zip(column_names, values))
except Exception as exc:
logger.warning(f"无法转换模型 {model.__name__}: {exc}")
fallback = {}
for column in column_names:
try:
fallback[column] = getattr(instance, column)
except Exception:
fallback[column] = None
return fallback
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
@@ -55,8 +106,9 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
Model instance (detached, all fields loaded)
"""
instance = model_class()
valid_fields = _get_model_field_set(model_class)
for key, value in data.items():
if hasattr(instance, key):
if key in valid_fields:
setattr(instance, key, value)
return instance
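Together these helpers let rows be cached as plain dicts and rehydrated into detached instances later, with lru_cache amortising the per-model column introspection. A minimal roundtrip illustrating the same idea (toy User model, simplified helpers without the caching and error handling):

```python
from sqlalchemy import Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class User(Base):
    __tablename__ = "users"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String(64))

def model_to_dict(instance: Base) -> dict:
    # Same idea as _model_to_dict: iterate the mapped columns, not __dict__.
    return {c.name: getattr(instance, c.name) for c in instance.__table__.columns}

def dict_to_model(model_class: type, data: dict):
    instance = model_class()
    valid = {c.name for c in model_class.__table__.columns}
    for key, value in data.items():
        if key in valid:  # silently skip stale/unknown cached keys
            setattr(instance, key, value)
    return instance

cached = model_to_dict(User(id=1, name="alice"))  # {'id': 1, 'name': 'alice'}
restored = dict_to_model(User, cached)            # detached User, fields loaded
print(cached, restored.name)
```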

View File

@@ -183,11 +183,14 @@ class QueryBuilder(Generic[T]):
self._use_cache = False
return self
async def all(self) -> list[T]:
async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
"""获取所有结果
Args:
as_dict: 为True时返回字典格式
Returns:
模型实例列表
模型实例列表或字典列表
"""
cache_key = ":".join(self._cache_key_parts) + ":all"
@@ -197,27 +200,33 @@ class QueryBuilder(Generic[T]):
cached_dicts = await cache.get(cache_key)
if cached_dicts is not None:
logger.debug(f"缓存命中: {cache_key}")
# 从字典列表恢复对象列表
return [_dict_to_model(self.model, d) for d in cached_dicts]
dict_rows = [dict(row) for row in cached_dicts]
if as_dict:
return dict_rows
return [_dict_to_model(self.model, row) for row in dict_rows]
# Query the database
async with get_db_session() as session:
result = await session.execute(self._stmt)
instances = list(result.scalars().all())
# Convert to a list of dicts inside the session, while every field is still safely accessible
instances_dicts = [_model_to_dict(inst) for inst in instances]
# Write to the cache
if self._use_cache:
cache = await get_cache()
await cache.set(cache_key, instances_dicts)
cache_payload = [dict(row) for row in instances_dicts]
await cache.set(cache_key, cache_payload)
# Rebuild the object list from the dicts and return it (detached, all fields loaded)
return [_dict_to_model(self.model, d) for d in instances_dicts]
if as_dict:
return instances_dicts
return [_dict_to_model(self.model, row) for row in instances_dicts]
async def first(self) -> T | None:
"""获取第一结果
async def first(self, *, as_dict: bool = False) -> T | dict[str, Any] | None:
"""获取第一结果
Args:
as_dict: return the row as a dict when True
Returns:
Model instance or None
@@ -230,8 +239,10 @@ class QueryBuilder(Generic[T]):
cached_dict = await cache.get(cache_key)
if cached_dict is not None:
logger.debug(f"缓存命中: {cache_key}")
# 从字典恢复对象
return _dict_to_model(self.model, cached_dict)
row = dict(cached_dict)
if as_dict:
return row
return _dict_to_model(self.model, row)
# Query the database
async with get_db_session() as session:
@@ -239,15 +250,16 @@ class QueryBuilder(Generic[T]):
instance = result.scalars().first()
if instance is not None:
# Convert to a dict inside the session, while every field is still safely accessible
instance_dict = _model_to_dict(instance)
# Write to the cache
if self._use_cache:
cache = await get_cache()
await cache.set(cache_key, instance_dict)
await cache.set(cache_key, dict(instance_dict))
# Rebuild the object from the dict and return it (detached, all fields loaded)
if as_dict:
return instance_dict
return _dict_to_model(self.model, instance_dict)
return None
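The as_dict flag means callers that only need raw rows (like the legacy db_query wrapper in the next file) skip the dict-to-model rebuild on both the cache and database paths. A self-contained sketch of the dual-return pattern (MiniQuery and SimpleRow are illustrative stand-ins, not the real QueryBuilder):

```python
import asyncio
from typing import Any

class SimpleRow:
    """Illustrative stand-in for a rebuilt model instance."""
    def __init__(self, data: dict[str, Any]) -> None:
        self.__dict__.update(data)

class MiniQuery:
    """Illustrative stand-in for QueryBuilder's dual-return results."""
    def __init__(self, rows: list[dict[str, Any]]) -> None:
        self._rows = rows  # stands in for the cached / DB-fetched dict rows

    async def all(self, *, as_dict: bool = False) -> list[Any]:
        rows = [dict(r) for r in self._rows]  # defensive copies, as on the cache path
        return rows if as_dict else [SimpleRow(r) for r in rows]

    async def first(self, *, as_dict: bool = False) -> Any | None:
        rows = await self.all(as_dict=True)
        if not rows:
            return None
        return rows[0] if as_dict else SimpleRow(rows[0])

async def main() -> None:
    q = MiniQuery([{"id": 1, "name": "alice"}])
    print(await q.first(as_dict=True))  # {'id': 1, 'name': 'alice'}
    print((await q.first()).name)       # alice

asyncio.run(main())
```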

View File

@@ -13,6 +13,7 @@ from src.common.database.api import (
from src.common.database.api import (
store_action_info as new_store_action_info,
)
from src.common.database.api.crud import _model_to_dict as _crud_model_to_dict
from src.common.database.core.models import (
ActionRecords,
AntiInjectionStats,
@@ -123,21 +124,19 @@ async def build_filters(model_class, filters: dict[str, Any]):
def _model_to_dict(instance) -> dict[str, Any]:
"""将模型实例转换为字典
"""数据库模型实例转换为字典兼容旧API
Args:
instance: model instance
instance: database model instance
Returns:
Dictionary representation
"""
if instance is None:
return None
return _crud_model_to_dict(instance)
result = {}
for column in instance.__table__.columns:
result[column.name] = getattr(instance, column.name)
return result
async def db_query(
@@ -211,11 +210,9 @@ async def db_query(
# Execute the query
if single_result:
result = await query_builder.first()
return _model_to_dict(result)
else:
results = await query_builder.all()
return [_model_to_dict(r) for r in results]
return await query_builder.first(as_dict=True)
return await query_builder.all(as_dict=True)
elif query_type == "create":
if not data:

View File

@@ -1,13 +1,15 @@
# Timestamp-based file handler with a simple cap on the number of rotated copies
import logging
from logging.handlers import QueueHandler, QueueListener
import tarfile
import threading
import time
from collections.abc import Callable
from collections.abc import Callable, Sequence
from datetime import datetime, timedelta
from pathlib import Path
from queue import SimpleQueue
import orjson
import structlog
import tomlkit
@@ -27,6 +29,11 @@ _console_handler: logging.Handler | None = None
_LOGGER_META_LOCK = threading.Lock()
_LOGGER_META: dict[str, dict[str, str | None]] = {}
# Log formatters
_log_queue: SimpleQueue[logging.LogRecord] | None = None
_queue_handler: QueueHandler | None = None
_queue_listener: QueueListener | None = None
def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None):
"""注册/更新 logger 元数据。
@@ -90,6 +97,44 @@ def get_console_handler():
return _console_handler
def _start_queue_logging(handlers: Sequence[logging.Handler]) -> QueueHandler | None:
"""为日志处理器启动异步队列;无处理器时返回 None"""
global _log_queue, _queue_handler, _queue_listener
if _queue_listener is not None:
_queue_listener.stop()
_queue_listener = None
if not handlers:
return None
_log_queue = SimpleQueue()
_queue_handler = StructlogQueueHandler(_log_queue)
_queue_listener = QueueListener(_log_queue, *handlers, respect_handler_level=True)
_queue_listener.start()
return _queue_handler
def _stop_queue_logging():
"""停止异步日志队列"""
global _log_queue, _queue_handler, _queue_listener
if _queue_listener is not None:
_queue_listener.stop()
_queue_listener = None
_log_queue = None
_queue_handler = None
class StructlogQueueHandler(QueueHandler):
"""Queue handler that keeps structlog event dicts intact."""
def prepare(self, record):
# Keep the original LogRecord so processor formatters can access the event dict.
return record
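The override matters because the stock QueueHandler.prepare() formats the record before enqueueing it, replacing record.msg (where structlog's event dict travels) with a rendered string; the listener-side processor formatter would then only ever see plain text. A small demonstration of the difference, using standard library calls only:

```python
import logging
from logging.handlers import QueueHandler
from queue import SimpleQueue

q: SimpleQueue = SimpleQueue()

# structlog-style record: the event dict travels as record.msg.
record = logging.LogRecord(
    name="demo", level=logging.INFO, pathname=__file__, lineno=1,
    msg={"event": "login", "user": "alice"}, args=None, exc_info=None,
)

# Stock prepare() renders the record: msg becomes a plain string.
prepared = QueueHandler(q).prepare(record)
print(type(prepared.msg))  # <class 'str'> -- the event dict is gone

class PassthroughQueueHandler(QueueHandler):
    def prepare(self, record: logging.LogRecord) -> logging.LogRecord:
        return record  # hand the record over untouched

kept = PassthroughQueueHandler(q).prepare(record)
print(type(kept.msg))  # <class 'dict'> -- listener-side formatter still sees it
```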
class TimestampedFileHandler(logging.Handler):
"""基于时间戳的文件处理器,带简单大小轮转 + 旧文件压缩/保留策略。
@@ -221,6 +266,8 @@ def close_handlers():
"""安全关闭所有handler"""
global _file_handler, _console_handler
_stop_queue_logging()
if _file_handler:
_file_handler.close()
_file_handler = None
@@ -1037,16 +1084,18 @@ def _immediate_setup():
# Use singleton handlers to avoid creating duplicates
file_handler_local = get_file_handler()
console_handler_local = get_console_handler()
for h in (file_handler_local, console_handler_local):
if h is not None:
root_logger.addHandler(h)
active_handlers = [h for h in (file_handler_local, console_handler_local) if h is not None]
# Set the formatters
if file_handler_local is not None:
file_handler_local.setFormatter(file_formatter)
if console_handler_local is not None:
console_handler_local.setFormatter(console_formatter)
queue_handler = _start_queue_logging(active_handlers)
if queue_handler is not None:
root_logger.addHandler(queue_handler)
# Remove duplicate handlers
remove_duplicate_handlers()
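After this change only the lightweight queue handler hangs off the root logger, while the formatter-bearing file/console handlers are serviced by the listener's worker thread. A minimal wiring sketch of the same topology (hypothetical formatter string, standard library only):

```python
import logging
from logging.handlers import QueueHandler, QueueListener
from queue import SimpleQueue

# Sink handlers keep their formatters but are owned by the listener thread;
# the root logger only gets the lightweight queue handler, so emit() on the
# hot path never blocks on disk or terminal I/O.
console = logging.StreamHandler()
console.setFormatter(logging.Formatter("%(levelname)s %(message)s"))

q: SimpleQueue = SimpleQueue()
listener = QueueListener(q, console, respect_handler_level=True)
listener.start()

root = logging.getLogger()
root.setLevel(logging.INFO)
root.addHandler(QueueHandler(q))

root.info("rendered on the listener thread")
listener.stop()  # stop the listener before closing handlers so the queue drains
```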