依旧修pyright喵喵喵~

This commit is contained in:
ikun-11451
2025-11-29 22:20:55 +08:00
parent 574c2384a2
commit acafc074b1
18 changed files with 106 additions and 53 deletions

View File

@@ -38,6 +38,9 @@ class CacheManager:
初始化缓存管理器。
"""
if not hasattr(self, "_initialized"):
assert global_config is not None
assert model_config is not None
self.default_ttl = default_ttl or 3600
self.semantic_cache_collection_name = "semantic_cache"
@@ -87,6 +90,7 @@ class CacheManager:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
assert global_config is not None
expected_dim = (
getattr(CacheManager, "embedding_dimension", None)
or global_config.lpmm_knowledge.embedding_dimension

View File

@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
candidates: list[int | None] = []
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
if model_config is not None:
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None)
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
if global_config is not None:
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None)
candidates.append(fallback)
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global:
if resolved and sync_global and global_config is not None:
try:
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]

View File

@@ -10,7 +10,7 @@ import asyncio
import functools
import hashlib
import time
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, ParamSpec, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError
@@ -82,7 +82,7 @@ def retry(
return await session.execute(stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None
@@ -130,7 +130,7 @@ def timeout(seconds: float):
return await session.execute(complex_stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
@@ -166,7 +166,7 @@ def cached(
return await query_user(user_id)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 延迟导入避免循环依赖
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
return await session.execute(stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.perf_counter()
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
函数需要接受session参数
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 查找session参数
@@ -335,7 +335,7 @@ def db_operation(
return await complex_operation()
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
# 从内到外应用装饰器
wrapped = func

View File

@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
class DaemonQueueListener(QueueListener):
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
def _configure_listener(self):
super()._configure_listener()
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
self._thread.daemon = True # type: ignore[attr-defined]
def start(self):
"""Start the listener.
This starts up a background thread to monitor the queue for
LogRecords to process.
"""
# 覆盖 start 方法以设置 daemon=True
# 注意_monitor 是 QueueListener 的内部方法
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
self._thread.start()
def stop(self):
"""停止监听器,避免在退出时无限期阻塞。"""

View File

@@ -1,4 +1,5 @@
import os
from typing import Any
from mofox_wire import MessageServer
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
if global_api is not None:
return global_api
assert global_config is not None
bus_config = global_config.message_bus
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
except ValueError:
port = 8000
kwargs: dict[str, object] = {
kwargs: dict[str, Any] = {
"host": host,
"port": port,
"app": get_global_server().get_app(),

View File

@@ -52,6 +52,7 @@ async def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
assert global_config is not None
async with get_db_session() as session:
query = select(Messages)

View File

@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
@staticmethod
def _get_sys_info() -> dict[str, str]:
"""获取系统信息"""
assert global_config is not None
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),

View File

@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
FastAPI 依赖项用于验证API密钥。
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
"""
assert bot_config is not None
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
if not valid_keys:
logger.warning("API密钥认证已启用但未配置任何有效的API密钥。所有请求都将被拒绝。")

View File

@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
class Server:
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
assert bot_config is not None
# 根据配置初始化速率限制器
limiter = Limiter(
key_func=get_remote_address,