rufffffff

This commit is contained in:
明天好像没什么
2025-11-01 21:10:01 +08:00
parent 08a9a2c2e8
commit cb97b2d8d3
50 changed files with 742 additions and 759 deletions

View File

@@ -14,14 +14,14 @@ from .adapter import (
)
__all__ = [
# 从 core 重新导出的函数
"get_db_session",
"get_engine",
# 兼容层适配器
"MODEL_MAPPING",
"build_filters",
"db_get",
"db_query",
"db_save",
"db_get",
# 从 core 重新导出的函数
"get_db_session",
"get_engine",
"store_action_info",
]

View File

@@ -4,15 +4,13 @@
保持原有函数签名和行为不变
"""
import time
from typing import Any, Optional
import orjson
from sqlalchemy import and_, asc, desc, select
from typing import Any
from src.common.database.api import (
CRUDBase,
QueryBuilder,
)
from src.common.database.api import (
store_action_info as new_store_action_info,
)
from src.common.database.core.models import (
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
Messages,
MonthlyPlan,
OnlineTime,
PersonInfo,
PermissionNodes,
PersonInfo,
Schedule,
ThinkingLog,
UserPermissions,
UserRelationships,
Videos,
)
from src.common.database.core.session import get_db_session
from src.common.logger import get_logger
logger = get_logger("database.compatibility")
@@ -82,11 +79,11 @@ _crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items(
async def build_filters(model_class, filters: dict[str, Any]):
"""构建查询过滤条件兼容MongoDB风格操作符
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件字典
Returns:
条件列表
"""
@@ -127,16 +124,16 @@ async def build_filters(model_class, filters: dict[str, Any]):
def _model_to_dict(instance) -> dict[str, Any]:
"""将模型实例转换为字典
Args:
instance: 模型实例
Returns:
字典表示
"""
if instance is None:
return None
result = {}
for column in instance.__table__.columns:
result[column.name] = getattr(instance, column.name)
@@ -145,15 +142,15 @@ def _model_to_dict(instance) -> dict[str, Any]:
async def db_query(
model_class,
data: Optional[dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[list[str]] = None,
single_result: Optional[bool] = False,
data: dict[str, Any] | None = None,
query_type: str | None = "get",
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: list[str] | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""执行异步数据库查询操作兼容旧API
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
@@ -162,7 +159,7 @@ async def db_query(
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
@@ -179,7 +176,7 @@ async def db_query(
if query_type == "get":
# 使用QueryBuilder
query_builder = QueryBuilder(model_class)
# 应用过滤条件
if filters:
# 将MongoDB风格过滤器转换为QueryBuilder格式
@@ -202,15 +199,15 @@ async def db_query(
query_builder = query_builder.filter(**{f"{field_name}__nin": op_value})
else:
query_builder = query_builder.filter(**{field_name: value})
# 应用排序
if order_by:
query_builder = query_builder.order_by(*order_by)
# 应用限制
if limit:
query_builder = query_builder.limit(limit)
# 执行查询
if single_result:
result = await query_builder.first()
@@ -223,7 +220,7 @@ async def db_query(
if not data:
logger.error("创建操作需要提供data参数")
return None
instance = await crud.create(data)
return _model_to_dict(instance)
@@ -231,17 +228,17 @@ async def db_query(
if not filters or not data:
logger.error("更新操作需要提供filters和data参数")
return None
# 先查找记录
query_builder = QueryBuilder(model_class)
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
instance = await query_builder.first()
if not instance:
logger.warning(f"未找到匹配的记录: {filters}")
return None
# 更新记录
updated = await crud.update(instance.id, data)
return _model_to_dict(updated)
@@ -250,29 +247,29 @@ async def db_query(
if not filters:
logger.error("删除操作需要提供filters参数")
return None
# 先查找记录
query_builder = QueryBuilder(model_class)
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
instance = await query_builder.first()
if not instance:
logger.warning(f"未找到匹配的记录: {filters}")
return None
# 删除记录
success = await crud.delete(instance.id)
return {"deleted": success}
elif query_type == "count":
query_builder = QueryBuilder(model_class)
# 应用过滤条件
if filters:
for field_name, value in filters.items():
query_builder = query_builder.filter(**{field_name: value})
count = await query_builder.count()
return {"count": count}
@@ -286,15 +283,15 @@ async def db_save(
data: dict[str, Any],
key_field: str,
key_value: Any,
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""保存或更新记录兼容旧API
Args:
model_class: SQLAlchemy模型类
data: 数据字典
key_field: 主键字段名
key_value: 主键值
Returns:
保存的记录数据或None
"""
@@ -303,15 +300,15 @@ async def db_save(
crud = _crud_instances.get(model_name)
if not crud:
crud = CRUDBase(model_class)
# 使用get_or_create (返回tuple[T, bool])
instance, created = await crud.get_or_create(
defaults=data,
**{key_field: key_value},
)
return _model_to_dict(instance)
except Exception as e:
logger.error(f"保存数据库记录出错: {e}", exc_info=True)
return None
@@ -319,20 +316,20 @@ async def db_save(
async def db_get(
model_class,
filters: Optional[dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: str | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""从数据库获取记录兼容旧API
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
@@ -353,11 +350,11 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: dict | None = None,
action_name: str = "",
) -> Optional[dict[str, Any]]:
) -> dict[str, Any] | None:
"""存储动作信息到数据库兼容旧API
直接使用新的specialized API
"""
return await new_store_action_info(