Merge branch 'dev' into feature/kfc
This commit is contained in:
@@ -38,6 +38,9 @@ class CacheManager:
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, "_initialized"):
|
||||
assert global_config is not None
|
||||
assert model_config is not None
|
||||
|
||||
self.default_ttl = default_ttl or 3600
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -87,6 +90,7 @@ class CacheManager:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
# 检查维度是否符合预期
|
||||
assert global_config is not None
|
||||
expected_dim = (
|
||||
getattr(CacheManager, "embedding_dimension", None)
|
||||
or global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
|
||||
|
||||
candidates: list[int | None] = []
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if model_config is not None:
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if global_config is not None:
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
candidates.append(fallback)
|
||||
|
||||
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
|
||||
if resolved and sync_global:
|
||||
if resolved and sync_global and global_config is not None:
|
||||
try:
|
||||
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||
|
||||
@@ -62,7 +62,9 @@ class StreamContext(BaseDataModel):
|
||||
stream_id: str
|
||||
chat_type: ChatType = ChatType.PRIVATE # 聊天类型,默认为私聊
|
||||
chat_mode: ChatMode = ChatMode.FOCUS # 聊天模式,默认为专注模式
|
||||
max_context_size: int = field(default_factory=lambda: getattr(global_config.chat, "max_context_size", 100))
|
||||
max_context_size: int = field(
|
||||
default_factory=lambda: getattr(global_config.chat, "max_context_size", 100) if global_config else 100
|
||||
)
|
||||
unread_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
history_messages: list["DatabaseMessages"] = field(default_factory=list)
|
||||
last_check_time: float = field(default_factory=time.time)
|
||||
@@ -98,7 +100,9 @@ class StreamContext(BaseDataModel):
|
||||
def __post_init__(self):
|
||||
"""初始化历史消息异步加载"""
|
||||
if not self.max_context_size or self.max_context_size <= 0:
|
||||
self.max_context_size = getattr(global_config.chat, "max_context_size", 100)
|
||||
self.max_context_size = (
|
||||
getattr(global_config.chat, "max_context_size", 100) if global_config else 100
|
||||
)
|
||||
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
@@ -118,6 +122,7 @@ class StreamContext(BaseDataModel):
|
||||
async def add_message(self, message: "DatabaseMessages", skip_energy_update: bool = False) -> bool:
|
||||
"""添加消息到上下文,支持跳过能量更新的选项"""
|
||||
try:
|
||||
assert global_config is not None
|
||||
cache_enabled = global_config.chat.enable_message_cache
|
||||
if cache_enabled and not self.is_cache_enabled:
|
||||
self.enable_cache(True)
|
||||
@@ -150,7 +155,7 @@ class StreamContext(BaseDataModel):
|
||||
# ͬ<><CDAC><EFBFBD><EFBFBD><EFBFBD>ݵ<EFBFBD>ͳһ<CDB3><D2BB><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
|
||||
try:
|
||||
if global_config.memory and global_config.memory.enable:
|
||||
unified_manager = _get_unified_memory_manager()
|
||||
unified_manager: Any = _get_unified_memory_manager()
|
||||
if unified_manager:
|
||||
message_dict = {
|
||||
"message_id": str(message.message_id),
|
||||
|
||||
@@ -9,9 +9,10 @@
|
||||
import operator
|
||||
from collections.abc import Callable
|
||||
from functools import lru_cache
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from sqlalchemy import delete, func, select, update
|
||||
from sqlalchemy.engine import CursorResult, Result
|
||||
|
||||
from src.common.database.core.models import Base
|
||||
from src.common.database.core.session import get_db_session
|
||||
@@ -25,23 +26,23 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.crud")
|
||||
|
||||
T = TypeVar("T", bound=Base)
|
||||
T = TypeVar("T", bound=Any)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_column_names(model: type[Base]) -> tuple[str, ...]:
|
||||
def _get_model_column_names(model: type[Any]) -> tuple[str, ...]:
|
||||
"""获取模型的列名称列表"""
|
||||
return tuple(column.name for column in model.__table__.columns)
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_field_set(model: type[Base]) -> frozenset[str]:
|
||||
def _get_model_field_set(model: type[Any]) -> frozenset[str]:
|
||||
"""获取模型的有效字段集合"""
|
||||
return frozenset(_get_model_column_names(model))
|
||||
|
||||
|
||||
@lru_cache(maxsize=256)
|
||||
def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, ...]]:
|
||||
def _get_model_value_fetcher(model: type[Any]) -> Callable[[Any], tuple[Any, ...]]:
|
||||
"""为模型准备attrgetter,用于批量获取属性值"""
|
||||
column_names = _get_model_column_names(model)
|
||||
|
||||
@@ -51,21 +52,21 @@ def _get_model_value_fetcher(model: type[Base]) -> Callable[[Base], tuple[Any, .
|
||||
if len(column_names) == 1:
|
||||
attr_name = column_names[0]
|
||||
|
||||
def _single(instance: Base) -> tuple[Any, ...]:
|
||||
def _single(instance: Any) -> tuple[Any, ...]:
|
||||
return (getattr(instance, attr_name),)
|
||||
|
||||
return _single
|
||||
|
||||
getter = operator.attrgetter(*column_names)
|
||||
|
||||
def _multi(instance: Base) -> tuple[Any, ...]:
|
||||
def _multi(instance: Any) -> tuple[Any, ...]:
|
||||
values = getter(instance)
|
||||
return values if isinstance(values, tuple) else (values,)
|
||||
|
||||
return _multi
|
||||
|
||||
|
||||
def _model_to_dict(instance: Base) -> dict[str, Any]:
|
||||
def _model_to_dict(instance: Any) -> dict[str, Any]:
|
||||
"""将 SQLAlchemy 模型实例转换为字典
|
||||
|
||||
Args:
|
||||
@@ -113,7 +114,7 @@ def _dict_to_model(model_class: type[T], data: dict[str, Any]) -> T:
|
||||
return instance
|
||||
|
||||
|
||||
class CRUDBase:
|
||||
class CRUDBase(Generic[T]):
|
||||
"""基础CRUD操作类
|
||||
|
||||
提供通用的增删改查操作,自动集成缓存和批处理
|
||||
@@ -246,7 +247,7 @@ class CRUDBase:
|
||||
cached_dicts = await cache.get(cache_key)
|
||||
if cached_dicts is not None:
|
||||
# 从字典列表恢复对象列表
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts]
|
||||
return [_dict_to_model(self.model, d) for d in cached_dicts] # type: ignore
|
||||
|
||||
# 从数据库查询
|
||||
async with get_db_session() as session:
|
||||
@@ -275,7 +276,7 @@ class CRUDBase:
|
||||
await cache.set(cache_key, instances_dicts)
|
||||
|
||||
# 从字典列表重建对象列表返回(detached状态,所有字段已加载)
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts]
|
||||
return [_dict_to_model(self.model, d) for d in instances_dicts] # type: ignore
|
||||
|
||||
async def create(
|
||||
self,
|
||||
@@ -417,7 +418,7 @@ class CRUDBase:
|
||||
async with get_db_session() as session:
|
||||
stmt = delete(self.model).where(self.model.id == id)
|
||||
result = await session.execute(stmt)
|
||||
success = result.rowcount > 0
|
||||
success = result.rowcount > 0 # type: ignore
|
||||
# 注意:commit在get_db_session的context manager退出时自动执行
|
||||
|
||||
# 清除缓存
|
||||
@@ -452,7 +453,7 @@ class CRUDBase:
|
||||
stmt = stmt.where(getattr(self.model, key) == value)
|
||||
|
||||
result = await session.execute(stmt)
|
||||
return result.scalar()
|
||||
return int(result.scalar() or 0)
|
||||
|
||||
async def exists(
|
||||
self,
|
||||
@@ -546,7 +547,7 @@ class CRUDBase:
|
||||
.values(**obj_in)
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
count += result.rowcount
|
||||
count += result.rowcount # type: ignore
|
||||
|
||||
# 清除缓存
|
||||
cache_key = f"{self.model_name}:id:{id}"
|
||||
|
||||
@@ -20,7 +20,7 @@ from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger("database.query")
|
||||
|
||||
T = TypeVar("T", bound="Base")
|
||||
T = TypeVar("T", bound=Any)
|
||||
|
||||
|
||||
class QueryBuilder(Generic[T]):
|
||||
@@ -327,7 +327,7 @@ class QueryBuilder(Generic[T]):
|
||||
|
||||
items = await self.all()
|
||||
|
||||
return items, total
|
||||
return items, total # type: ignore
|
||||
|
||||
|
||||
class AggregateQuery:
|
||||
|
||||
@@ -122,7 +122,7 @@ async def get_recent_actions(
|
||||
动作记录列表
|
||||
"""
|
||||
query = QueryBuilder(ActionRecords)
|
||||
return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all()
|
||||
return await query.filter(chat_id=chat_id).order_by("-time").limit(limit).all() # type: ignore
|
||||
|
||||
|
||||
# ===== Messages 业务API =====
|
||||
@@ -148,7 +148,7 @@ async def get_chat_history(
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
.all()
|
||||
)
|
||||
) # type: ignore
|
||||
|
||||
|
||||
async def get_message_count(stream_id: str) -> int:
|
||||
@@ -292,7 +292,7 @@ async def get_active_streams(
|
||||
if platform:
|
||||
query = query.filter(platform=platform)
|
||||
|
||||
return await query.order_by("-last_message_time").limit(limit).all()
|
||||
return await query.order_by("-last_message_time").limit(limit).all() # type: ignore
|
||||
|
||||
|
||||
# ===== LLMUsage 业务API =====
|
||||
@@ -390,7 +390,7 @@ async def get_usage_statistics(
|
||||
# 聚合统计
|
||||
total_input = await query.sum("input_tokens")
|
||||
total_output = await query.sum("output_tokens")
|
||||
total_count = await query.filter().count() if hasattr(query, "count") else 0
|
||||
total_count = await getattr(query.filter(), "count")() if hasattr(query, "count") else 0
|
||||
|
||||
return {
|
||||
"total_input_tokens": int(total_input),
|
||||
|
||||
@@ -123,7 +123,7 @@ async def build_filters(model_class, filters: dict[str, Any]):
|
||||
return conditions
|
||||
|
||||
|
||||
def _model_to_dict(instance) -> dict[str, Any]:
|
||||
def _model_to_dict(instance) -> dict[str, Any] | None:
|
||||
"""将数据库模型实例转换为字典(兼容旧API
|
||||
|
||||
Args:
|
||||
@@ -238,7 +238,7 @@ async def db_query(
|
||||
return None
|
||||
|
||||
# 更新记录
|
||||
updated = await crud.update(instance.id, data)
|
||||
updated = await crud.update(instance.id, data) # type: ignore
|
||||
return _model_to_dict(updated)
|
||||
|
||||
elif query_type == "delete":
|
||||
@@ -257,7 +257,7 @@ async def db_query(
|
||||
return None
|
||||
|
||||
# 删除记录
|
||||
success = await crud.delete(instance.id)
|
||||
success = await crud.delete(instance.id) # type: ignore
|
||||
return {"deleted": success}
|
||||
|
||||
elif query_type == "count":
|
||||
|
||||
@@ -46,6 +46,7 @@ async def get_engine() -> AsyncEngine:
|
||||
if _engine_lock is None:
|
||||
_engine_lock = asyncio.Lock()
|
||||
|
||||
assert _engine_lock is not None
|
||||
# 使用锁保护初始化过程
|
||||
async with _engine_lock:
|
||||
# 双重检查锁定模式
|
||||
@@ -55,6 +56,7 @@ async def get_engine() -> AsyncEngine:
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
config = global_config.database
|
||||
db_type = config.database_type
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
db_type = global_config.database.database_type
|
||||
|
||||
# MySQL 索引需要指定长度的 VARCHAR
|
||||
|
||||
@@ -75,6 +75,7 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
|
||||
# 可以设置 schema 搜索路径等
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
schema = global_config.database.postgresql_schema
|
||||
if schema and schema != "public":
|
||||
await session.execute(text(f"SET search_path TO {schema}"))
|
||||
@@ -114,6 +115,7 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||
# 获取数据库类型并应用特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
@@ -142,6 +144,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
|
||||
# 应用数据库特定设置
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
await _apply_session_settings(session, global_config.database.database_type)
|
||||
|
||||
yield session
|
||||
|
||||
@@ -13,7 +13,7 @@ from collections import defaultdict, deque
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from enum import IntEnum
|
||||
from typing import Any, TypeVar
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import delete, insert, select, update
|
||||
|
||||
@@ -23,8 +23,6 @@ from src.common.memory_utils import estimate_size_smart
|
||||
|
||||
logger = get_logger("batch_scheduler")
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class Priority(IntEnum):
|
||||
"""操作优先级"""
|
||||
@@ -429,7 +427,7 @@ class AdaptiveBatchScheduler:
|
||||
|
||||
# 执行更新(但不commit)
|
||||
result = await session.execute(stmt)
|
||||
results.append((op, result.rowcount))
|
||||
results.append((op, result.rowcount)) # type: ignore
|
||||
|
||||
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||
|
||||
@@ -471,7 +469,7 @@ class AdaptiveBatchScheduler:
|
||||
|
||||
# 执行删除(但不commit)
|
||||
result = await session.execute(stmt)
|
||||
results.append((op, result.rowcount))
|
||||
results.append((op, result.rowcount)) # type: ignore
|
||||
|
||||
# 注意:commit 由 get_db_session_direct 上下文管理器自动处理
|
||||
|
||||
|
||||
@@ -398,47 +398,48 @@ class MultiLevelCache:
|
||||
l2_stats_task = asyncio.create_task(self._get_cache_stats_safe(self.l2_cache, "L2"))
|
||||
|
||||
# 使用超时避免死锁
|
||||
try:
|
||||
l1_stats, l2_stats = await asyncio.gather(
|
||||
asyncio.wait_for(l1_stats_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_stats_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("缓存统计获取超时,使用基本统计")
|
||||
l1_stats = await self.l1_cache.get_stats()
|
||||
l2_stats = await self.l2_cache.get_stats()
|
||||
results = await asyncio.gather(
|
||||
asyncio.wait_for(l1_stats_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_stats_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
l1_stats = results[0]
|
||||
l2_stats = results[1]
|
||||
|
||||
# 处理异常情况
|
||||
if isinstance(l1_stats, Exception):
|
||||
if isinstance(l1_stats, BaseException):
|
||||
logger.error(f"L1统计获取失败: {l1_stats}")
|
||||
l1_stats = CacheStats()
|
||||
if isinstance(l2_stats, Exception):
|
||||
if isinstance(l2_stats, BaseException):
|
||||
logger.error(f"L2统计获取失败: {l2_stats}")
|
||||
l2_stats = CacheStats()
|
||||
|
||||
assert isinstance(l1_stats, CacheStats)
|
||||
assert isinstance(l2_stats, CacheStats)
|
||||
|
||||
# 🔧 修复:并行获取键集合,避免锁嵌套
|
||||
l1_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l1_cache))
|
||||
l2_keys_task = asyncio.create_task(self._get_cache_keys_safe(self.l2_cache))
|
||||
|
||||
try:
|
||||
l1_keys, l2_keys = await asyncio.gather(
|
||||
asyncio.wait_for(l1_keys_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_keys_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("缓存键获取超时,使用默认值")
|
||||
l1_keys, l2_keys = set(), set()
|
||||
results = await asyncio.gather(
|
||||
asyncio.wait_for(l1_keys_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_keys_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
l1_keys = results[0]
|
||||
l2_keys = results[1]
|
||||
|
||||
# 处理异常情况
|
||||
if isinstance(l1_keys, Exception):
|
||||
if isinstance(l1_keys, BaseException):
|
||||
logger.warning(f"L1键获取失败: {l1_keys}")
|
||||
l1_keys = set()
|
||||
if isinstance(l2_keys, Exception):
|
||||
if isinstance(l2_keys, BaseException):
|
||||
logger.warning(f"L2键获取失败: {l2_keys}")
|
||||
l2_keys = set()
|
||||
|
||||
assert isinstance(l1_keys, set)
|
||||
assert isinstance(l2_keys, set)
|
||||
|
||||
# 计算共享键和独占键
|
||||
shared_keys = l1_keys & l2_keys
|
||||
l1_only_keys = l1_keys - l2_keys
|
||||
@@ -448,24 +449,25 @@ class MultiLevelCache:
|
||||
l1_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l1_cache, l1_keys))
|
||||
l2_size_task = asyncio.create_task(self._calculate_memory_usage_safe(self.l2_cache, l2_keys))
|
||||
|
||||
try:
|
||||
l1_size, l2_size = await asyncio.gather(
|
||||
asyncio.wait_for(l1_size_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_size_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning("内存计算超时,使用统计值")
|
||||
l1_size, l2_size = l1_stats.total_size, l2_stats.total_size
|
||||
results = await asyncio.gather(
|
||||
asyncio.wait_for(l1_size_task, timeout=1.0),
|
||||
asyncio.wait_for(l2_size_task, timeout=1.0),
|
||||
return_exceptions=True
|
||||
)
|
||||
l1_size = results[0]
|
||||
l2_size = results[1]
|
||||
|
||||
# 处理异常情况
|
||||
if isinstance(l1_size, Exception):
|
||||
if isinstance(l1_size, BaseException):
|
||||
logger.warning(f"L1内存计算失败: {l1_size}")
|
||||
l1_size = l1_stats.total_size
|
||||
if isinstance(l2_size, Exception):
|
||||
if isinstance(l2_size, BaseException):
|
||||
logger.warning(f"L2内存计算失败: {l2_size}")
|
||||
l2_size = l2_stats.total_size
|
||||
|
||||
assert isinstance(l1_size, int)
|
||||
assert isinstance(l2_size, int)
|
||||
|
||||
# 计算实际总内存(避免重复计数)
|
||||
actual_total_size = l1_size + l2_size - min(l1_stats.total_size, l2_stats.total_size)
|
||||
|
||||
@@ -769,6 +771,7 @@ async def get_cache() -> MultiLevelCache:
|
||||
try:
|
||||
from src.config.config import global_config
|
||||
|
||||
assert global_config is not None
|
||||
db_config = global_config.database
|
||||
|
||||
# 检查是否启用缓存
|
||||
|
||||
@@ -10,8 +10,8 @@ import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any, TypeVar
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
from sqlalchemy.exc import TimeoutError as SQLTimeoutError
|
||||
@@ -56,8 +56,9 @@ def generate_cache_key(
|
||||
|
||||
return ":".join(cache_key_parts)
|
||||
|
||||
T = TypeVar("T")
|
||||
F = TypeVar("F", bound=Callable[..., Awaitable[Any]])
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
def retry(
|
||||
@@ -77,14 +78,13 @@ def retry(
|
||||
exceptions: 需要重试的异常类型
|
||||
|
||||
Example:
|
||||
@retry(max_attempts=3, delay=1.0)
|
||||
async def query_data():
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
last_exception = None
|
||||
current_delay = delay
|
||||
|
||||
@@ -107,7 +107,9 @@ def retry(
|
||||
)
|
||||
|
||||
# 所有尝试都失败
|
||||
raise last_exception
|
||||
if last_exception:
|
||||
raise last_exception
|
||||
raise RuntimeError(f"Retry failed after {max_attempts} attempts")
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -128,9 +130,9 @@ def timeout(seconds: float):
|
||||
return await session.execute(complex_stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
return await asyncio.wait_for(func(*args, **kwargs), timeout=seconds)
|
||||
except asyncio.TimeoutError:
|
||||
@@ -164,9 +166,9 @@ def cached(
|
||||
return await query_user(user_id)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 延迟导入避免循环依赖
|
||||
from src.common.database.optimization import get_cache
|
||||
|
||||
@@ -225,9 +227,9 @@ def measure_time(log_slow: float | None = None):
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.perf_counter()
|
||||
|
||||
try:
|
||||
@@ -267,21 +269,23 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
函数需要接受session参数
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: Any, **kwargs: Any) -> T:
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 查找session参数
|
||||
session = None
|
||||
if args:
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
session: AsyncSession | None = None
|
||||
if args:
|
||||
for arg in args:
|
||||
if isinstance(arg, AsyncSession):
|
||||
session = arg
|
||||
break
|
||||
|
||||
if not session and "session" in kwargs:
|
||||
session = kwargs["session"]
|
||||
possible_session = kwargs["session"]
|
||||
if isinstance(possible_session, AsyncSession):
|
||||
session = possible_session
|
||||
|
||||
if not session:
|
||||
logger.warning(f"{func.__name__} 未找到session参数,跳过事务管理")
|
||||
@@ -330,7 +334,7 @@ def db_operation(
|
||||
return await complex_operation()
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., T]) -> Callable[..., T]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
# 从内到外应用装饰器
|
||||
wrapped = func
|
||||
|
||||
|
||||
@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
|
||||
class DaemonQueueListener(QueueListener):
|
||||
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
|
||||
|
||||
def _configure_listener(self):
|
||||
super()._configure_listener()
|
||||
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
|
||||
self._thread.daemon = True # type: ignore[attr-defined]
|
||||
def start(self):
|
||||
"""Start the listener.
|
||||
This starts up a background thread to monitor the queue for
|
||||
LogRecords to process.
|
||||
"""
|
||||
# 覆盖 start 方法以设置 daemon=True
|
||||
# 注意:_monitor 是 QueueListener 的内部方法
|
||||
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""停止监听器,避免在退出时无限期阻塞。"""
|
||||
@@ -345,10 +350,12 @@ def load_log_config(): # sourcery skip: use-contextlib-suppress
|
||||
"websockets",
|
||||
"httpcore",
|
||||
"requests",
|
||||
"aiosqlite",
|
||||
"peewee",
|
||||
"openai",
|
||||
"uvicorn",
|
||||
"rjieba",
|
||||
"message_bus",
|
||||
],
|
||||
"library_log_levels": {"aiohttp": "WARNING"},
|
||||
}
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mofox_wire import MessageServer
|
||||
|
||||
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
|
||||
if global_api is not None:
|
||||
return global_api
|
||||
|
||||
assert global_config is not None
|
||||
|
||||
bus_config = global_config.message_bus
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
kwargs: dict[str, object] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"app": get_global_server().get_app(),
|
||||
|
||||
@@ -52,6 +52,7 @@ async def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
assert global_config is not None
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _get_sys_info() -> dict[str, str]:
|
||||
"""获取系统信息"""
|
||||
assert global_config is not None
|
||||
info_dict = {
|
||||
"os_type": "Unknown",
|
||||
"py_version": platform.python_version(),
|
||||
|
||||
@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
|
||||
FastAPI 依赖项,用于验证API密钥。
|
||||
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
|
||||
"""
|
||||
assert bot_config is not None
|
||||
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
|
||||
if not valid_keys:
|
||||
logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。")
|
||||
|
||||
@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
|
||||
assert bot_config is not None
|
||||
# 根据配置初始化速率限制器
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
|
||||
Reference in New Issue
Block a user