增强聊天管理器和数据库API,添加自动注册和异步清理功能,优化模型转换为字典的逻辑
This commit is contained in:
@@ -6,6 +6,9 @@
|
||||
- 智能预加载:关联数据自动预加载
|
||||
"""
|
||||
|
||||
import operator
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, TypeVar
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
@@ -25,6 +28,43 @@ logger = get_logger("database.crud")
|
||||
T = TypeVar("T", bound=Base)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
|
||||
"""获取模型的列名称列表"""
|
||||
return tuple(column.name for column in model.__table__.columns)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
|
||||
"""获取模型的有效字段集合"""
|
||||
return frozenset(_get_model_column_names(model))
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
|
||||
"""为模型准备attrgetter,用于批量获取属性值"""
|
||||
column_names = _get_model_column_names(model)
|
||||
|
||||
if not column_names:
|
||||
return lambda _: ()
|
||||
|
||||
if len(column_names) == 1:
|
||||
attr_name = column_names[0]
|
||||
|
||||
def _single(instance: Base) -> tuple[Any, ...]:
|
||||
return (getattr(instance, attr_name),)
|
||||
|
||||
return _single
|
||||
|
||||
getter = operator.attrgetter(*column_names)
|
||||
|
||||
def _multi(instance: Base) -> tuple[Any, ...]:
|
||||
values = getter(instance)
|
||||
return values if isinstance(values, tuple) else (values,)
|
||||
|
||||
return _multi
|
||||
|
||||
|
||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
"""将 SQLAlchemy 模型实例转换为字典
|
||||
|
||||
@@ -32,16 +72,27 @@ def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
instance: SQLAlchemy 模型实例
|
||||
|
||||
Returns:
|
||||
字典表示,包含所有列的值
|
||||
字典表示的模型实例的字段值
|
||||
"""
|
||||
result = {}
|
||||
for column in instance.__table__.columns:
|
||||
try:
|
||||
result[column.name] = getattr(instance, column.name)
|
||||
except Exception as e:
|
||||
logger.warning(f"无法访问字段 {column.name}: {e}")
|
||||
result[column.name] = None
|
||||
return result
|
||||
if instance is None:
|
||||
return {}
|
||||
|
||||
model = type(instance)
|
||||
column_names = _get_model_column_names(model)
|
||||
fetch_values = _get_model_value_fetcher(model)
|
||||
|
||||
try:
|
||||
values = fetch_values(instance)
|
||||
return dict(zip(column_names, values))
|
||||
except Exception as exc:
|
||||
logger.warning(f"无法转换模型 {model.__name__}: {exc}")
|
||||
fallback = {}
|
||||
for column in column_names:
|
||||
try:
|
||||
fallback[column] = getattr(instance, column)
|
||||
except Exception:
|
||||
fallback[column] = None
|
||||
return fallback
|
||||
|
||||
|
||||
def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
||||
@@ -55,8 +106,9 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
||||
模型实例 (detached, 所有字段已加载)
|
||||
"""
|
||||
instance = model_class()
|
||||
valid_fields = _get_model_field_set(model_class)
|
||||
for key, value in data.items():
|
||||
if hasattr(instance, key):
|
||||
if key in valid_fields:
|
||||
setattr(instance, key, value)
|
||||
return instance
|
||||
|
||||
|
||||
@@ -183,11 +183,14 @@ class QueryBuilder(Generic[T]):
|
||||
self._use_cache = False
|
||||
return self
|
||||
|
||||
async def all(self) -> list[T]:
|
||||
async def all(self, *, as_dict: bool = False) -> list[T] | list[dict[str, Any]]:
|
||||
"""获取所有结果
|
||||
|
||||
Args:
|
||||
as_dict: 为True时返回字典格式
|
||||
|
||||
Returns:
|
||||
模型实例列表
|
||||
模型实例列表或字典列表
|
||||
"""
|
||||
cache_key = ":".join(self._cache_key_parts) + ":all"
|
||||
|
||||
@@ -197,27 +200,33 @@ class QueryBuilder(Generic[T]):
|
||||
cached_dicts = await cache.get(cache_key)
|
||||
if cached_dicts is not None:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典列表恢复对象列表
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
||||
dict_rows = [dict(row) for row in cached_dicts]
|
||||
if as_dict:
|
||||
return dict_rows
|
||||
return [_dict_to_model(self.model, row) for row in dict_rows]
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
result = await session.execute(self._stmt)
|
||||
instances = list(result.scalars().all())
|
||||
|
||||
# ✅ 在 session 内部转换为字典列表,此时所有字段都可安全访问
|
||||
# 在 session 内部转换为字典列表,此时所有字段都可安全访问
|
||||
instances_dicts = [_model_to_dict(inst) for inst in instances]
|
||||
|
||||
# 写入缓存
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
cache_payload = [dict(row) for row in instances_dicts]
|
||||
await cache.set(cache_key, cache_payload)
|
||||
|
||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
||||
if as_dict:
|
||||
return instances_dicts
|
||||
return [_dict_to_model(self.model, row) for row in instances_dicts]
|
||||
|
||||
async def first(self) -> T | None:
|
||||
"""获取第一个结果
|
||||
async def first(self, *, as_dict: bool = False) -> T | dict[str, Any] | None:
|
||||
"""获取第一条结果
|
||||
|
||||
Args:
|
||||
as_dict: 为True时返回字典格式
|
||||
|
||||
Returns:
|
||||
模型实例或None
|
||||
@@ -230,8 +239,10 @@ class QueryBuilder(Generic[T]):
|
||||
cached_dict = await cache.get(cache_key)
|
||||
if cached_dict is not None:
|
||||
logger.debug(f"缓存命中: {cache_key}")
|
||||
# 从字典恢复对象
|
||||
return _dict_to_model(self.model, cached_dict)
|
||||
row = dict(cached_dict)
|
||||
if as_dict:
|
||||
return row
|
||||
return _dict_to_model(self.model, row)
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
@@ -239,15 +250,16 @@ class QueryBuilder(Generic[T]):
|
||||
instance = result.scalars().first()
|
||||
|
||||
if instance is not None:
|
||||
# ✅ 在 session 内部转换为字典,此时所有字段都可安全访问
|
||||
# 在 session 内部转换为字典,此时所有字段都可安全访问
|
||||
instance_dict = _model_to_dict(instance)
|
||||
|
||||
# 写入缓存
|
||||
if self._use_cache:
|
||||
cache = await get_cache()
|
||||
await cache.set(cache_key, instance_dict)
|
||||
await cache.set(cache_key, dict(instance_dict))
|
||||
|
||||
# 从字典重建对象返回(detached状态,所有字段已加载)
|
||||
if as_dict:
|
||||
return instance_dict
|
||||
return _dict_to_model(self.model, instance_dict)
|
||||
|
||||
return None
|
||||
|
||||
@@ -13,6 +13,7 @@ from src.common.database.api import (
|
||||
from src.common.database.api import (
|
||||
store_action_info as new_store_action_info,
|
||||
)
|
||||
from src.common.database.api.crud import _model_to_dict as _crud_model_to_dict
|
||||
from src.common.database.core.models import (
|
||||
ActionRecords,
|
||||
AntiInjectionStats,
|
||||
@@ -123,21 +124,19 @@ async def build_filters(model_class, filters: dict[str, Any]):
|
||||
|
||||
|
||||
def _model_to_dict(instance) -> dict[str, Any]:
|
||||
"""将模型实例转换为字典
|
||||
"""将数据库模型实例转换为字典(兼容旧API
|
||||
|
||||
Args:
|
||||
instance: 模型实例
|
||||
instance: 数据库模型实例
|
||||
|
||||
Returns:
|
||||
字典表示
|
||||
"""
|
||||
if instance is None:
|
||||
return None
|
||||
return _crud_model_to_dict(instance)
|
||||
|
||||
|
||||
result = {}
|
||||
for column in instance.__table__.columns:
|
||||
result[column.name] = getattr(instance, column.name)
|
||||
return result
|
||||
|
||||
|
||||
async def db_query(
|
||||
@@ -211,11 +210,9 @@ async def db_query(
|
||||
|
||||
# 执行查询
|
||||
if single_result:
|
||||
result = await query_builder.first()
|
||||
return _model_to_dict(result)
|
||||
else:
|
||||
results = await query_builder.all()
|
||||
return [_model_to_dict(r) for r in results]
|
||||
return await query_builder.first(as_dict=True)
|
||||
|
||||
return await query_builder.all(as_dict=True)
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
|
||||
Reference in New Issue
Block a user