rufffffff
This commit is contained in:
@@ -14,14 +14,14 @@ from .adapter import (
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# 从 core 重新导出的函数
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
# 兼容层适配器
|
||||
"MODEL_MAPPING",
|
||||
"build_filters",
|
||||
"db_get",
|
||||
"db_query",
|
||||
"db_save",
|
||||
"db_get",
|
||||
# 从 core 重新导出的函数
|
||||
"get_db_session",
|
||||
"get_engine",
|
||||
"store_action_info",
|
||||
]
|
||||
|
||||
@@ -4,15 +4,13 @@
|
||||
保持原有函数签名和行为不变
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import and_, asc, desc, select
|
||||
from typing import Any
|
||||
|
||||
from src.common.database.api import (
|
||||
CRUDBase,
|
||||
QueryBuilder,
|
||||
)
|
||||
from src.common.database.api import (
|
||||
store_action_info as new_store_action_info,
|
||||
)
|
||||
from src.common.database.core.models import (
|
||||
@@ -34,15 +32,14 @@ from src.common.database.core.models import (
|
||||
Messages,
|
||||
MonthlyPlan,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
PermissionNodes,
|
||||
PersonInfo,
|
||||
Schedule,
|
||||
ThinkingLog,
|
||||
UserPermissions,
|
||||
UserRelationships,
|
||||
Videos,
|
||||
)
|
||||
from src.common.database.core.session import get_db_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.compatibility")
|
||||
@@ -82,11 +79,11 @@ _crud_instances = {name: CRUDBase(model) for name, model in MODEL_MAPPING.items(
|
||||
|
||||
async def build_filters(model_class, filters: dict[str, Any]):
|
||||
"""构建查询过滤条件(兼容MongoDB风格操作符)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件字典
|
||||
|
||||
|
||||
Returns:
|
||||
条件列表
|
||||
"""
|
||||
@@ -127,16 +124,16 @@ async def build_filters(model_class, filters: dict[str, Any]):
|
||||
|
||||
def _model_to_dict(instance) -> dict[str, Any]:
|
||||
"""将模型实例转换为字典
|
||||
|
||||
|
||||
Args:
|
||||
instance: 模型实例
|
||||
|
||||
|
||||
Returns:
|
||||
字典表示
|
||||
"""
|
||||
if instance is None:
|
||||
return None
|
||||
|
||||
|
||||
result = {}
|
||||
for column in instance.__table__.columns:
|
||||
result[column.name] = getattr(instance, column.name)
|
||||
@@ -145,15 +142,15 @@ def _model_to_dict(instance) -> dict[str, Any]:
|
||||
|
||||
async def db_query(
|
||||
model_class,
|
||||
data: Optional[dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[list[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
data: dict[str, Any] | None = None,
|
||||
query_type: str | None = "get",
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[str] | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""执行异步数据库查询操作(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 用于创建或更新的数据字典
|
||||
@@ -162,7 +159,7 @@ async def db_query(
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
|
||||
Returns:
|
||||
根据查询类型返回相应结果
|
||||
"""
|
||||
@@ -179,7 +176,7 @@ async def db_query(
|
||||
if query_type == "get":
|
||||
# 使用QueryBuilder
|
||||
query_builder = QueryBuilder(model_class)
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
# 将MongoDB风格过滤器转换为QueryBuilder格式
|
||||
@@ -202,15 +199,15 @@ async def db_query(
|
||||
query_builder = query_builder.filter(**{f"{field_name}__nin": op_value})
|
||||
else:
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
query_builder = query_builder.order_by(*order_by)
|
||||
|
||||
|
||||
# 应用限制
|
||||
if limit:
|
||||
query_builder = query_builder.limit(limit)
|
||||
|
||||
|
||||
# 执行查询
|
||||
if single_result:
|
||||
result = await query_builder.first()
|
||||
@@ -223,7 +220,7 @@ async def db_query(
|
||||
if not data:
|
||||
logger.error("创建操作需要提供data参数")
|
||||
return None
|
||||
|
||||
|
||||
instance = await crud.create(data)
|
||||
return _model_to_dict(instance)
|
||||
|
||||
@@ -231,17 +228,17 @@ async def db_query(
|
||||
if not filters or not data:
|
||||
logger.error("更新操作需要提供filters和data参数")
|
||||
return None
|
||||
|
||||
|
||||
# 先查找记录
|
||||
query_builder = QueryBuilder(model_class)
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
instance = await query_builder.first()
|
||||
if not instance:
|
||||
logger.warning(f"未找到匹配的记录: {filters}")
|
||||
return None
|
||||
|
||||
|
||||
# 更新记录
|
||||
updated = await crud.update(instance.id, data)
|
||||
return _model_to_dict(updated)
|
||||
@@ -250,29 +247,29 @@ async def db_query(
|
||||
if not filters:
|
||||
logger.error("删除操作需要提供filters参数")
|
||||
return None
|
||||
|
||||
|
||||
# 先查找记录
|
||||
query_builder = QueryBuilder(model_class)
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
instance = await query_builder.first()
|
||||
if not instance:
|
||||
logger.warning(f"未找到匹配的记录: {filters}")
|
||||
return None
|
||||
|
||||
|
||||
# 删除记录
|
||||
success = await crud.delete(instance.id)
|
||||
return {"deleted": success}
|
||||
|
||||
elif query_type == "count":
|
||||
query_builder = QueryBuilder(model_class)
|
||||
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
for field_name, value in filters.items():
|
||||
query_builder = query_builder.filter(**{field_name: value})
|
||||
|
||||
|
||||
count = await query_builder.count()
|
||||
return {"count": count}
|
||||
|
||||
@@ -286,15 +283,15 @@ async def db_save(
|
||||
data: dict[str, Any],
|
||||
key_field: str,
|
||||
key_value: Any,
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""保存或更新记录(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 数据字典
|
||||
key_field: 主键字段名
|
||||
key_value: 主键值
|
||||
|
||||
|
||||
Returns:
|
||||
保存的记录数据或None
|
||||
"""
|
||||
@@ -303,15 +300,15 @@ async def db_save(
|
||||
crud = _crud_instances.get(model_name)
|
||||
if not crud:
|
||||
crud = CRUDBase(model_class)
|
||||
|
||||
|
||||
# 使用get_or_create (返回tuple[T, bool])
|
||||
instance, created = await crud.get_or_create(
|
||||
defaults=data,
|
||||
**{key_field: key_value},
|
||||
)
|
||||
|
||||
|
||||
return _model_to_dict(instance)
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存数据库记录出错: {e}", exc_info=True)
|
||||
return None
|
||||
@@ -319,20 +316,20 @@ async def db_save(
|
||||
|
||||
async def db_get(
|
||||
model_class,
|
||||
filters: Optional[dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: str | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""从数据库获取记录(兼容旧API)
|
||||
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
|
||||
Returns:
|
||||
记录数据或None
|
||||
"""
|
||||
@@ -353,11 +350,11 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: dict | None = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""存储动作信息到数据库(兼容旧API)
|
||||
|
||||
|
||||
直接使用新的specialized API
|
||||
"""
|
||||
return await new_store_action_info(
|
||||
|
||||
Reference in New Issue
Block a user