re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
parent ecb02cae31
commit 7923eafef3
263 changed files with 3103 additions and 3123 deletions

View File

@@ -1,10 +1,11 @@
import os
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_db_session
from src.common.database.sqlalchemy_models import get_db_session, get_engine
from src.common.logger import get_logger
install(extra_lines=3)

View File

@@ -6,31 +6,31 @@
import time
import traceback
from typing import Dict, List, Any, Union, Optional
from typing import Any
from sqlalchemy import desc, asc, func, and_, select
from sqlalchemy import and_, asc, desc, func, select
from sqlalchemy.exc import SQLAlchemyError
from src.common.database.sqlalchemy_models import (
get_db_session,
Messages,
ActionRecords,
PersonInfo,
ChatStreams,
LLMUsage,
Emoji,
Images,
ImageDescriptions,
OnlineTime,
Memory,
Expression,
ThinkingLog,
GraphNodes,
GraphEdges,
Schedule,
MaiZoneScheduleStatus,
CacheEntries,
ChatStreams,
Emoji,
Expression,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
LLMUsage,
MaiZoneScheduleStatus,
Memory,
Messages,
OnlineTime,
PersonInfo,
Schedule,
ThinkingLog,
UserRelationships,
get_db_session,
)
from src.common.logger import get_logger
@@ -59,7 +59,7 @@ MODEL_MAPPING = {
}
async def build_filters(model_class, filters: Dict[str, Any]):
async def build_filters(model_class, filters: dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
@@ -98,13 +98,13 @@ async def build_filters(model_class, filters: Dict[str, Any]):
async def db_query(
model_class,
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
data: dict[str, Any] | None = None,
query_type: str | None = "get",
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: list[str] | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""执行异步数据库查询操作
Args:
@@ -263,8 +263,8 @@ async def db_query(
async def db_save(
model_class, data: Dict[str, Any], key_field: Optional[str] = None, key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None
) -> dict[str, Any] | None:
"""异步保存数据到数据库(创建或更新)
Args:
@@ -325,11 +325,11 @@ async def db_save(
async def db_get(
model_class,
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
filters: dict[str, Any] | None = None,
limit: int | None = None,
order_by: str | None = None,
single_result: bool | None = False,
) -> list[dict[str, Any]] | dict[str, Any] | None:
"""异步从数据库获取记录
Args:
@@ -359,9 +359,9 @@ async def store_action_info(
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_data: dict | None = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
) -> dict[str, Any] | None:
"""异步存储动作信息到数据库
Args:

View File

@@ -4,10 +4,10 @@
提供统一的异步数据库初始化接口
"""
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
from src.common.logger import get_logger
logger = get_logger("sqlalchemy_init")
@@ -71,7 +71,7 @@ async def create_all_tables() -> bool:
return False
async def get_database_info() -> Optional[dict]:
async def get_database_info() -> dict | None:
"""
异步获取数据库信息

View File

@@ -6,11 +6,12 @@
import datetime
import os
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Optional, Any, Dict, AsyncGenerator
from typing import Any
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, DateTime, text
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
@@ -423,7 +424,7 @@ class Expression(Base):
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
type: Mapped[str] = mapped_column(Text, nullable=False)
create_date: Mapped[Optional[float]] = mapped_column(Float, nullable=True)
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
@@ -710,7 +711,7 @@ async def initialize_database():
config = global_config.database
# 配置引擎参数
engine_kwargs: Dict[str, Any] = {
engine_kwargs: dict[str, Any] = {
"echo": False, # 生产环境关闭SQL日志
"future": True,
}
@@ -759,12 +760,12 @@ async def initialize_database():
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[Optional[AsyncSession], None]:
async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
"""
异步数据库会话上下文管理器。
在初始化失败时会yield None调用方需要检查会话是否为None。
"""
session: Optional[AsyncSession] = None
session: AsyncSession | None = None
SessionLocal = None
try:
_, SessionLocal = await initialize_database()