re-style: 格式化代码
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy相关导入
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_db_session
|
||||
from src.common.database.sqlalchemy_models import get_db_session, get_engine
|
||||
from src.common.logger import get_logger
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
@@ -6,31 +6,31 @@
|
||||
|
||||
import time
|
||||
import traceback
|
||||
from typing import Dict, List, Any, Union, Optional
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import desc, asc, func, and_, select
|
||||
from sqlalchemy import and_, asc, desc, func, select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
get_db_session,
|
||||
Messages,
|
||||
ActionRecords,
|
||||
PersonInfo,
|
||||
ChatStreams,
|
||||
LLMUsage,
|
||||
Emoji,
|
||||
Images,
|
||||
ImageDescriptions,
|
||||
OnlineTime,
|
||||
Memory,
|
||||
Expression,
|
||||
ThinkingLog,
|
||||
GraphNodes,
|
||||
GraphEdges,
|
||||
Schedule,
|
||||
MaiZoneScheduleStatus,
|
||||
CacheEntries,
|
||||
ChatStreams,
|
||||
Emoji,
|
||||
Expression,
|
||||
GraphEdges,
|
||||
GraphNodes,
|
||||
ImageDescriptions,
|
||||
Images,
|
||||
LLMUsage,
|
||||
MaiZoneScheduleStatus,
|
||||
Memory,
|
||||
Messages,
|
||||
OnlineTime,
|
||||
PersonInfo,
|
||||
Schedule,
|
||||
ThinkingLog,
|
||||
UserRelationships,
|
||||
get_db_session,
|
||||
)
|
||||
from src.common.logger import get_logger
|
||||
|
||||
@@ -59,7 +59,7 @@ MODEL_MAPPING = {
|
||||
}
|
||||
|
||||
|
||||
async def build_filters(model_class, filters: Dict[str, Any]):
|
||||
async def build_filters(model_class, filters: dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
|
||||
@@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]):
|
||||
|
||||
async def db_query(
|
||||
model_class,
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
data: dict[str, Any] | None = None,
|
||||
query_type: str | None = "get",
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: list[str] | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""执行异步数据库查询操作
|
||||
|
||||
Args:
|
||||
@@ -263,8 +263,8 @@ async def db_query(
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""异步保存数据到数据库(创建或更新)
|
||||
|
||||
Args:
|
||||
@@ -325,11 +325,11 @@ async def db_save(
|
||||
|
||||
async def db_get(
|
||||
model_class,
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
filters: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
order_by: str | None = None,
|
||||
single_result: bool | None = False,
|
||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
||||
"""异步从数据库获取记录
|
||||
|
||||
Args:
|
||||
@@ -359,9 +359,9 @@ async def store_action_info(
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_data: dict | None = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
) -> dict[str, Any] | None:
|
||||
"""异步存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
|
||||
@@ -4,10 +4,10 @@
|
||||
提供统一的异步数据库初始化接口
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
@@ -71,7 +71,7 @@ async def create_all_tables() -> bool:
|
||||
return False
|
||||
|
||||
|
||||
async def get_database_info() -> Optional[dict]:
|
||||
async def get_database_info() -> dict | None:
|
||||
"""
|
||||
异步获取数据库信息
|
||||
|
||||
|
||||
@@ -6,11 +6,12 @@
|
||||
import datetime
|
||||
import os
|
||||
import time
|
||||
from collections.abc import AsyncGenerator
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import Optional, Any, Dict, AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
|
||||
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
@@ -423,7 +424,7 @@ class Expression(Base):
|
||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
|
||||
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
||||
|
||||
@@ -710,7 +711,7 @@ async def initialize_database():
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs: Dict[str, Any] = {
|
||||
engine_kwargs: dict[str, Any] = {
|
||||
"echo": False, # 生产环境关闭SQL日志
|
||||
"future": True,
|
||||
}
|
||||
@@ -759,12 +760,12 @@ async def initialize_database():
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
|
||||
async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
|
||||
"""
|
||||
异步数据库会话上下文管理器。
|
||||
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
||||
"""
|
||||
session: Optional[AsyncSession] = None
|
||||
session: AsyncSession | None = None
|
||||
SessionLocal = None
|
||||
try:
|
||||
_, SessionLocal = await initialize_database()
|
||||
|
||||
Reference in New Issue
Block a user