依旧修pyright喵喵喵~
This commit is contained in:
@@ -38,6 +38,9 @@ class CacheManager:
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, "_initialized"):
|
||||
assert global_config is not None
|
||||
assert model_config is not None
|
||||
|
||||
self.default_ttl = default_ttl or 3600
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -87,6 +90,7 @@ class CacheManager:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
# 检查维度是否符合预期
|
||||
assert global_config is not None
|
||||
expected_dim = (
|
||||
getattr(CacheManager, "embedding_dimension", None)
|
||||
or global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
|
||||
|
||||
candidates: list[int | None] = []
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if model_config is not None:
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if global_config is not None:
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
candidates.append(fallback)
|
||||
|
||||
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
|
||||
if resolved and sync_global:
|
||||
if resolved and sync_global and global_config is not None:
|
||||
try:
|
||||
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||
|
||||
@@ -10,7 +10,7 @@ import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
@@ -82,7 +82,7 @@ def retry(
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
last_exception = None
|
||||
@@ -130,7 +130,7 @@ def timeout(seconds: float):
|
||||
return await session.execute(complex_stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
@@ -166,7 +166,7 @@ def cached(
|
||||
return await query_user(user_id)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 延迟导入避免循环依赖
|
||||
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.perf_counter()
|
||||
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
函数需要接受session参数
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 查找session参数
|
||||
@@ -335,7 +335,7 @@ def db_operation(
|
||||
return await complex_operation()
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
# 从内到外应用装饰器
|
||||
wrapped = func
|
||||
|
||||
|
||||
@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
|
||||
class DaemonQueueListener(QueueListener):
|
||||
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
|
||||
|
||||
def _configure_listener(self):
|
||||
super()._configure_listener()
|
||||
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
|
||||
self._thread.daemon = True # type: ignore[attr-defined]
|
||||
def start(self):
|
||||
"""Start the listener.
|
||||
This starts up a background thread to monitor the queue for
|
||||
LogRecords to process.
|
||||
"""
|
||||
# 覆盖 start 方法以设置 daemon=True
|
||||
# 注意:_monitor 是 QueueListener 的内部方法
|
||||
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""停止监听器,避免在退出时无限期阻塞。"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mofox_wire import MessageServer
|
||||
|
||||
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
|
||||
if global_api is not None:
|
||||
return global_api
|
||||
|
||||
assert global_config is not None
|
||||
|
||||
bus_config = global_config.message_bus
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
kwargs: dict[str, object] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"app": get_global_server().get_app(),
|
||||
|
||||
@@ -52,6 +52,7 @@ async def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
assert global_config is not None
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _get_sys_info() -> dict[str, str]:
|
||||
"""获取系统信息"""
|
||||
assert global_config is not None
|
||||
info_dict = {
|
||||
"os_type": "Unknown",
|
||||
"py_version": platform.python_version(),
|
||||
|
||||
@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
|
||||
FastAPI 依赖项,用于验证API密钥。
|
||||
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
|
||||
"""
|
||||
assert bot_config is not None
|
||||
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
|
||||
if not valid_keys:
|
||||
logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。")
|
||||
|
||||
@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
|
||||
assert bot_config is not None
|
||||
# 根据配置初始化速率限制器
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
|
||||
Reference in New Issue
Block a user