Merge branch 'dev' into feature/kfc

拾风 committed 2025-12-01 16:06:47 +08:00 (committed by GitHub)
87 changed files with 6181 additions and 2355 deletions

View File

@@ -9,9 +9,10 @@
import operator
from collections.abc import Callable
from functools import lru_cache
-from typing import Any, TypeVar
+from typing import Any, Generic, TypeVar
from sqlalchemy import delete, func, select, update
from sqlalchemy.engine import CursorResult, Result
from src.common.database.core.models import Base
from src.common.database.core.session import get_db_session
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
logger = get_logger("database.crud")
T = TypeVar("T", bound=Base)
T = TypeVar("T", bound=Any)
@lru_cache(maxsize=256)
-def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
+def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
"""Return the model's column names as a tuple."""
return tuple(column.name for column in model.__table__.columns)
@lru_cache(maxsize=256)
-def _get_model_field_set(model: type[Base]) -> frozenset[str]:
+def _get_model_field_set(model: type[Any]) -> frozenset[str]:
"""Return the set of the model's valid field names."""
return frozenset(_get_model_column_names(model))
@lru_cache(maxsize=256)
-def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
+def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
"""Prepare an attrgetter for the model so attribute values can be fetched in bulk."""
column_names = _get_model_column_names(model)
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
if len(column_names) == 1:
attr_name = column_names[0]
-def _single(instance: Base) -> tuple[Any, ...]:
+def _single(instance: Any) -> tuple[Any, ...]:
return (getattr(instance, attr_name),)
return _single
getter = operator.attrgetter(*column_names)
-def _multi(instance: Base) -> tuple[Any, ...]:
+def _multi(instance: Any) -> tuple[Any, ...]:
values = getter(instance)
return values if isinstance(values, tuple) else (values,)
return _multi
-def _model_to_dict(instance: Base) -> dict[str, Any]:
+def _model_to_dict(instance: Any) -> dict[str, Any]:
"""Convert a SQLAlchemy model instance into a dict
Args:
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
return instance
-class CRUDBase:
+class CRUDBase(Generic[T]):
"""Base CRUD operations class
Provides generic create/read/update/delete operations, with caching and batching built in
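For reference, a minimal sketch of what parameterizing the base class buys: with CRUDBase(Generic[T]) the model type flows through to results, so type checkers can infer concrete instances. Names below are illustrative, not the project's actual API.

from typing import Any, Generic, TypeVar

T = TypeVar("T", bound=Any)

class CRUDBaseSketch(Generic[T]):
    # Hypothetical sketch: binding the model type into the class lets
    # type checkers infer concrete model types on returned instances.
    def __init__(self, model: type[T]) -> None:
        self.model = model
        self.model_name = model.__name__

    def rebuild(self, data: dict[str, Any]) -> T:
        # Mirrors the _dict_to_model round trip: build a detached instance.
        return self.model(**data)

class User:
    def __init__(self, id: int, name: str) -> None:
        self.id, self.name = id, name

user_crud = CRUDBaseSketch(User)
user = user_crud.rebuild({"id": 1, "name": "alice"})  # inferred as User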
@@ -246,7 +247,7 @@ class CRUDBase:
cached_dicts = await cache.get(cache_key)
if cached_dicts is not None:
# Rebuild the object list from the cached dicts
-return [_dict_to_model(self.model, d) for d in cached_dicts]
+return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
# Query the database
async with get_db_session() as session:
@@ -275,7 +276,7 @@ class CRUDBase:
await cache.set(cache_key, instances_dicts)
# Rebuild the object list from the dicts and return it (detached state, all fields loaded)
-return [_dict_to_model(self.model, d) for d in instances_dicts]
+return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
async def create(
self,
@@ -417,7 +418,7 @@ class CRUDBase:
async with get_db_session() as session:
stmt = delete(self.model).where(self.model.id == id)
result = await session.execute(stmt)
-success = result.rowcount > 0
+success = result.rowcount > 0 # type: ignore
# Note: commit runs automatically when the get_db_session context manager exits
# Invalidate the cache
@@ -452,7 +453,7 @@ class CRUDBase:
stmt = stmt.where(getattr(self.model, key) == value)
result = await session.execute(stmt)
-return result.scalar()
+return int(result.scalar() or 0)
async def exists(
self,
@@ -546,7 +547,7 @@ class CRUDBase:
.values(**obj_in)
)
result = await session.execute(stmt)
-count += result.rowcount
+count += result.rowcount # type: ignore
# Invalidate the cache
cache_key = f"{self.model_name}:id:{id}"

View File

@@ -20,7 +20,7 @@ from src.common.logger import get_logger
logger = get_logger("database.query")
T = TypeVar("T", bound="Base")
T = TypeVar("T", bound=Any)
class QueryBuilder(Generic[T]):
@@ -327,7 +327,7 @@ class QueryBuilder(Generic[T]):
items = await self.all()
-return items, total
+return items, total # type: ignore
class AggregateQuery:
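The call sites in the next file chain filter/order_by/limit before awaiting .all(); a skeleton of that fluent style, assuming Python 3.11+ for typing.Self (illustrative only, not the real QueryBuilder):

import asyncio
from typing import Any, Generic, Self, TypeVar

T = TypeVar("T", bound=Any)

class QueryBuilderSketch(Generic[T]):
    def __init__(self, model: type[T]) -> None:
        self.model = model
        self._filters: dict[str, Any] = {}
        self._order: str | None = None
        self._limit: int | None = None

    def filter(self, **conditions: Any) -> Self:
        self._filters.update(conditions)
        return self

    def order_by(self, field: str) -> Self:
        self._order = field  # a leading "-" means descending, as used below
        return self

    def limit(self, n: int) -> Self:
        self._limit = n
        return self

    async def all(self) -> list[T]:
        return []  # stand-in for the actual SELECT

async def demo() -> None:
    class ActionRecords: ...
    q = QueryBuilderSketch(ActionRecords)
    print(await q.filter(chat_id="c1").order_by("-time").limit(10).all())

asyncio.run(demo())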

View File

@@ -122,7 +122,7 @@ async def get_recent_actions(
A list of action records
"""
query = QueryBuilder(ActionRecords)
-return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all()
+return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() # type: ignore
# ===== Messages business API =====
@@ -148,7 +148,7 @@ async def get_chat_history(
.limit(limit)
.offset(offset)
.all()
-)
+) # type: ignore
async def get_message_count(stream_id: str) -> int:
@@ -292,7 +292,7 @@ async def get_active_streams(
if platform:
query = query.filter(platform=platform)
-return await query.order_by("-last_message_time").limit(limit).all()
+return await query.order_by("-last_message_time").limit(limit).all() # type: ignore
# ===== LLMUsage business API =====
@@ -390,7 +390,7 @@ async def get_usage_statistics(
# Aggregate statistics
total_input = await query.sum("input_tokens")
total_output = await query.sum("output_tokens")
-total_count = await query.filter().count() if hasattr(query, "count") else 0
+total_count = await getattr(query.filter(), "count")() if hasattr(query, "count") else 0
return {
"total_input_tokens": int(total_input),

View File

@@ -123,7 +123,7 @@ async def build_filters(model_class, filters: dict[str, Any]):
return conditions
-def _model_to_dict(instance) -> dict[str, Any]:
+def _model_to_dict(instance) -> dict[str, Any] | None:
"""Convert a database model instance into a dict (compatible with the legacy API)
Args:
@@ -238,7 +238,7 @@ async def db_query(
return None
# Update the record
-updated = await crud.update(instance.id, data)
+updated = await crud.update(instance.id, data) # type: ignore
return _model_to_dict(updated)
elif query_type == "delete":
@@ -257,7 +257,7 @@ async def db_query(
return None
# Delete the record
-success = await crud.delete(instance.id)
+success = await crud.delete(instance.id) # type: ignore
return {"deleted": success}
elif query_type == "count":

View File

@@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine:
if _engine_lock is None:
_engine_lock = asyncio.Lock()
+assert _engine_lock is not None
# Protect initialization with the lock
async with _engine_lock:
# Double-checked locking pattern
@@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine:
try:
from src.config.config import global_config
+assert global_config is not None
config = global_config.database
db_type = config.database_type
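The added assert slots into the double-checked locking flow above: the lock is created lazily, narrowed for the type checker, and engine state is re-checked under the lock. A runnable sketch of the pattern (stand-in names):

import asyncio

_engine: object | None = None
_engine_lock: asyncio.Lock | None = None

async def get_engine_sketch() -> object:
    global _engine, _engine_lock
    if _engine is not None:  # fast path, no lock
        return _engine
    if _engine_lock is None:
        _engine_lock = asyncio.Lock()
    assert _engine_lock is not None  # narrows Optional for the type checker
    async with _engine_lock:
        if _engine is None:  # double check: another task may have initialized
            _engine = object()  # stand-in for create_async_engine(...)
    return _engine

asyncio.run(get_engine_sketch())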

View File

@@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs):
"""
from src.config.config import global_config
+assert global_config is not None
db_type = global_config.database.database_type
# MySQL indexes require a VARCHAR with an explicit length

View File

@@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
# e.g. the schema search path can be set here
from src.config.config import global_config
+assert global_config is not None
schema = global_config.database.postgresql_schema
if schema and schema != "public":
await session.execute(text(f"SET search_path TO {schema}"))
@@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
# Fetch the database type and apply dialect-specific settings
from src.config.config import global_config
+assert global_config is not None
await _apply_session_settings(session, global_config.database.database_type)
yield session
@@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
# Apply database-specific settings
from src.config.config import global_config
+assert global_config is not None
await _apply_session_settings(session, global_config.database.database_type)
yield session
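Both session factories funnel through _apply_session_settings; a sketch of that dialect hook as the hunks above suggest (the schema value is hypothetical, the project reads it from global_config):

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

async def apply_session_settings_sketch(session: AsyncSession, db_type: str) -> None:
    if db_type == "postgresql":
        schema = "app"  # hypothetical; sourced from config.postgresql_schema
        if schema and schema != "public":
            # Scope name resolution to the configured schema for this session.
            await session.execute(text(f"SET search_path TO {schema}"))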

View File

@@ -13,7 +13,7 @@ from collections import defaultdict, deque
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import IntEnum
-from typing import Any, TypeVar
+from typing import Any
from sqlalchemy import delete, insert, select, update
@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart
logger = get_logger("batch_scheduler")
T = TypeVar("T")
class Priority(IntEnum):
"""操作优先级"""
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:
# Execute the update without committing
result = await session.execute(stmt)
-results.append((op, result.rowcount))
+results.append((op, result.rowcount)) # type: ignore
# Note: commit is handled automatically by the get_db_session_direct context manager
@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:
# Execute the delete without committing
result = await session.execute(stmt)
-results.append((op, result.rowcount))
+results.append((op, result.rowcount)) # type: ignore
# Note: commit is handled automatically by the get_db_session_direct context manager
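Both branches execute statements without committing and let the session context manager commit once at the end; a condensed sketch of that flush loop (the op shape and model attributes are assumptions):

from typing import Any

from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession

async def flush_update_ops(session: AsyncSession, ops: list[dict[str, Any]]) -> list[int]:
    counts: list[int] = []
    for op in ops:
        model = op["model"]  # assumed: a mapped class with an id column
        stmt = update(model).where(model.id == op["id"]).values(**op["values"])
        result = await session.execute(stmt)
        # rowcount is typed loosely on async results, hence the ignores above.
        counts.append(result.rowcount)
        # No commit here: the surrounding context manager commits once.
    return counts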

View File

@@ -398,47 +398,48 @@ class MultiLevelCache:
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
# Use timeouts to avoid deadlocks
-try:
-    l1_stats, l2_stats = await asyncio.gather(
-        asyncio.wait_for(l1_stats_task, timeout=1.0),
-        asyncio.wait_for(l2_stats_task, timeout=1.0),
-        return_exceptions=True
-    )
-except asyncio.TimeoutError:
-    logger.warning("Fetching cache stats timed out, falling back to basic stats")
-    l1_stats = await self.l1_cache.get_stats()
-    l2_stats = await self.l2_cache.get_stats()
+results = await asyncio.gather(
+    asyncio.wait_for(l1_stats_task, timeout=1.0),
+    asyncio.wait_for(l2_stats_task, timeout=1.0),
+    return_exceptions=True
+)
+l1_stats = results[0]
+l2_stats = results[1]
# Handle failures
-if isinstance(l1_stats, Exception):
+if isinstance(l1_stats, BaseException):
    logger.error(f"Failed to fetch L1 stats: {l1_stats}")
    l1_stats = CacheStats()
-if isinstance(l2_stats, Exception):
+if isinstance(l2_stats, BaseException):
    logger.error(f"Failed to fetch L2 stats: {l2_stats}")
    l2_stats = CacheStats()
+assert isinstance(l1_stats, CacheStats)
+assert isinstance(l2_stats, CacheStats)
# 🔧 Fix: fetch the key sets in parallel to avoid nested locking
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
-try:
-    l1_keys, l2_keys = await asyncio.gather(
-        asyncio.wait_for(l1_keys_task, timeout=1.0),
-        asyncio.wait_for(l2_keys_task, timeout=1.0),
-        return_exceptions=True
-    )
-except asyncio.TimeoutError:
-    logger.warning("Fetching cache keys timed out, using defaults")
-    l1_keys, l2_keys = set(), set()
+results = await asyncio.gather(
+    asyncio.wait_for(l1_keys_task, timeout=1.0),
+    asyncio.wait_for(l2_keys_task, timeout=1.0),
+    return_exceptions=True
+)
+l1_keys = results[0]
+l2_keys = results[1]
# Handle failures
-if isinstance(l1_keys, Exception):
+if isinstance(l1_keys, BaseException):
    logger.warning(f"Failed to fetch L1 keys: {l1_keys}")
    l1_keys = set()
-if isinstance(l2_keys, Exception):
+if isinstance(l2_keys, BaseException):
    logger.warning(f"Failed to fetch L2 keys: {l2_keys}")
    l2_keys = set()
+assert isinstance(l1_keys, set)
+assert isinstance(l2_keys, set)
# Compute shared and exclusive keys
shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys
@@ -448,24 +449,25 @@ class MultiLevelCache:
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
-try:
-    l1_size, l2_size = await asyncio.gather(
-        asyncio.wait_for(l1_size_task, timeout=1.0),
-        asyncio.wait_for(l2_size_task, timeout=1.0),
-        return_exceptions=True
-    )
-except asyncio.TimeoutError:
-    logger.warning("Memory calculation timed out, using stats values")
-    l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
+results = await asyncio.gather(
+    asyncio.wait_for(l1_size_task, timeout=1.0),
+    asyncio.wait_for(l2_size_task, timeout=1.0),
+    return_exceptions=True
+)
+l1_size = results[0]
+l2_size = results[1]
# Handle failures
-if isinstance(l1_size, Exception):
+if isinstance(l1_size, BaseException):
    logger.warning(f"L1 memory calculation failed: {l1_size}")
    l1_size = l1_stats.total_size
-if isinstance(l2_size, Exception):
+if isinstance(l2_size, BaseException):
    logger.warning(f"L2 memory calculation failed: {l2_size}")
    l2_size = l2_stats.total_size
+assert isinstance(l1_size, int)
+assert isinstance(l2_size, int)
# Compute the actual total memory (avoid double counting)
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
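The rewrite drops try/except around gather because return_exceptions=True never raises: failures, including CancelledError (which subclasses BaseException rather than Exception on Python 3.8+), come back inside the results list. A runnable illustration:

import asyncio

async def fetch(name: str, delay: float) -> int:
    await asyncio.sleep(delay)
    return len(name)

async def main() -> None:
    results = await asyncio.gather(
        asyncio.wait_for(fetch("l1", 0.01), timeout=1.0),
        asyncio.wait_for(fetch("l2", 10.0), timeout=0.05),  # will time out
        return_exceptions=True,
    )
    l1, l2 = results
    # Check BaseException so cancellations are caught too, then fall back.
    if isinstance(l2, BaseException):
        l2 = 0
    print(l1, l2)  # 2 0

asyncio.run(main())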
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
try:
from src.config.config import global_config
+assert global_config is not None
db_config = global_config.database
# Check whether caching is enabled

View File

@@ -10,8 +10,8 @@ import asyncio
import functools
import hashlib
import time
-from collections.abc import Awaitable, Callable
-from typing import Any, TypeVar
+from collections.abc import Awaitable, Callable, Coroutine
+from typing import Any, ParamSpec, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
@@ -56,8 +56,9 @@ def generate_cache_key(
return ":".join(cache_key_parts)
T = TypeVar("T")
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
P = ParamSpec("P")
R = TypeVar("R")
def retry(
@@ -77,14 +78,13 @@ def retry(
exceptions: exception types that trigger a retry
Example:
@retry(max_attempts=3, delay=1.0)
async def query_data():
return await session.execute(stmt)
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
-async def wrapper(*args: Any, **kwargs: Any) -> T:
+async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None
current_delay = delay
@@ -107,7 +107,9 @@ def retry(
)
# All attempts failed
-raise last_exception
+if last_exception:
+    raise last_exception
+raise RuntimeError(f"Retry failed after {max_attempts} attempts")
return wrapper
@@ -128,9 +130,9 @@ def timeout(seconds: float):
return await session.execute(complex_stmt)
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
-async def wrapper(*args: Any, **kwargs: Any) -> T:
+async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
except asyncio.TimeoutError:
@@ -164,9 +166,9 @@ def cached(
return await query_user(user_id)
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
-async def wrapper(*args: Any, **kwargs: Any) -> T:
+async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Lazy import to avoid a circular dependency
from src.common.database.optimization import get_cache
@@ -225,9 +227,9 @@ def measure_time(log_slow: float | None = None):
return await session.execute(stmt)
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
-async def wrapper(*args: Any, **kwargs: Any) -> T:
+async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.perf_counter()
try:
@@ -267,21 +269,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
The function must accept a session parameter
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
-async def wrapper(*args: Any, **kwargs: Any) -> T:
+async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Locate the session argument
-session = None
-if args:
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.ext.asyncio import AsyncSession
+session: AsyncSession | None = None
+if args:
for arg in args:
if isinstance(arg, AsyncSession):
session = arg
break
if not session and "session" in kwargs:
-session = kwargs["session"]
+possible_session = kwargs["session"]
+if isinstance(possible_session, AsyncSession):
+session = possible_session
if not session:
logger.warning(f"{func.__name__} 未找到session参数跳过事务管理")
@@ -330,7 +334,7 @@ def db_operation(
return await complex_operation()
"""
-def decorator(func: Callable[..., T]) -> Callable[..., T]:
+def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
# Apply the decorators from innermost to outermost
wrapped = func
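The composition comment above spells out the order even though the hunk ends here. A hedged sketch of how such stacking typically continues, reusing this module's own decorators (parameter names beyond max_attempts/delay/log_slow are assumptions):

def db_operation_sketch(max_retries: int = 3, slow_threshold: float = 1.0):
    def decorator(func):
        # Apply decorators from innermost to outermost, per the comment above:
        # the innermost wrapper runs closest to the real database call.
        wrapped = func
        wrapped = measure_time(log_slow=slow_threshold)(wrapped)  # timing first
        wrapped = retry(max_attempts=max_retries)(wrapped)        # retries wrap timing
        return wrapped
    return decorator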