依旧修pyright喵喵喵~

This commit is contained in:
ikun-11451
2025-11-29 22:20:55 +08:00
parent 574c2384a2
commit acafc074b1
18 changed files with 106 additions and 53 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import math
import os
from dataclasses import dataclass
from typing import Any
# import tqdm
import aiofiles
@@ -121,7 +122,7 @@ class EmbeddingStore:
self.store = {}
self.faiss_index = None
self.faiss_index: Any = None
self.idx2hash = None
@staticmethod
@@ -158,6 +159,8 @@ class EmbeddingStore:
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
assert model_config is not None
# 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
@@ -402,6 +405,7 @@ class EmbeddingStore:
def build_faiss_index(self) -> None:
"""重新构建Faiss索引以余弦相似度为度量"""
assert global_config is not None
# 获取所有的embedding
array = []
self.idx2hash = {}

View File

@@ -344,14 +344,15 @@ class ImageManager:
# --- 新的帧选择逻辑均匀抽取4帧 ---
num_frames = len(all_frames)
if num_frames <= 4:
# 如果总帧数小于等于4则全部选中
# 如果总帧数小于等于4,则全部选中
selected_frames = all_frames
indices = list(range(num_frames))
else:
# 使用linspace计算4个均匀分布的索引
indices = np.linspace(0, num_frames - 1, 4, dtype=int)
selected_frames = [all_frames[i] for i in indices]
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}")
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}")
# --- 帧选择逻辑结束 ---
# 如果选择后连一帧都没有比如GIF只有一帧且后续处理失败或者原始GIF就没帧也返回None

View File

@@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock()
logger = get_logger("utils_video")
from inkfox import video
from inkfox import video # type: ignore
class VideoAnalyzer:
@@ -123,7 +123,6 @@ class VideoAnalyzer:
# ---- 批量分析 ----
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side
@@ -139,12 +138,7 @@ class VideoAnalyzer:
for b64, _ in frames:
mb.add_image_content("jpeg", b64)
message = mb.build()
model_info, api_provider, client = self.video_llm._select_model()
resp = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
resp = await self.video_llm.execute_with_messages(
message_list=[message],
temperature=None,
max_tokens=None,

View File

@@ -31,9 +31,9 @@ def _extract_frames_worker(
max_image_size: int,
frame_extraction_mode: str,
frame_interval_seconds: float | None,
) -> list[Any] | list[tuple[str, str]]:
) -> list[tuple[str, float]] | list[tuple[str, str]]:
"""线程池中提取视频帧的工作函数"""
frames = []
frames: list[tuple[str, float]] = []
try:
cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS)
@@ -42,7 +42,7 @@ def _extract_frames_worker(
if frame_extraction_mode == "time_interval":
# 新模式:按时间间隔抽帧
time_interval = frame_interval_seconds
time_interval = frame_interval_seconds or 2.0
next_frame_time = 0.0
extracted_count = 0 # 初始化提取帧计数器
@@ -61,7 +61,7 @@ def _extract_frames_worker(
# 调整图像大小
if max(pil_image.size) > max_image_size:
ratio = max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64
@@ -240,6 +240,7 @@ class LegacyVideoAnalyzer:
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
else:
estimated_frames = self.max_frames
frame_interval = 1
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
@@ -276,7 +277,7 @@ class LegacyVideoAnalyzer:
return await self._extract_frames_fallback(video_path)
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
return frames
return frames # type: ignore
except Exception as e:
logger.error(f"线程池帧提取失败: {e}")
@@ -315,7 +316,7 @@ class LegacyVideoAnalyzer:
# 调整图像大小
if max(pil_image.size) > self.max_image_size:
ratio = self.max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64
@@ -463,11 +464,11 @@ class LegacyVideoAnalyzer:
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
model_info, api_provider, client = self.video_llm._select_model()
model_info, api_provider, client = self.video_llm._select_model() # type: ignore
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
# 直接执行多图片请求
api_response = await self.video_llm._execute_request(
api_response = await self.video_llm._execute_request( # type: ignore
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,

View File

@@ -38,6 +38,9 @@ class CacheManager:
初始化缓存管理器。
"""
if not hasattr(self, "_initialized"):
assert global_config is not None
assert model_config is not None
self.default_ttl = default_ttl or 3600
self.semantic_cache_collection_name = "semantic_cache"
@@ -87,6 +90,7 @@ class CacheManager:
embedding_array = embedding_array.flatten()
# 检查维度是否符合预期
assert global_config is not None
expected_dim = (
getattr(CacheManager, "embedding_dimension", None)
or global_config.lpmm_knowledge.embedding_dimension

View File

@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
candidates: list[int | None] = []
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
if model_config is not None:
try:
embedding_task = getattr(model_config.model_task_config, "embedding", None)
if embedding_task is not None:
candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None)
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
if global_config is not None:
try:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None)
candidates.append(fallback)
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global:
if resolved and sync_global and global_config is not None:
try:
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]

View File

@@ -10,7 +10,7 @@ import asyncio
import functools
import hashlib
import time
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, ParamSpec, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError
@@ -82,7 +82,7 @@ def retry(
return await session.execute(stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None
@@ -130,7 +130,7 @@ def timeout(seconds: float):
return await session.execute(complex_stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try:
@@ -166,7 +166,7 @@ def cached(
return await query_user(user_id)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 延迟导入避免循环依赖
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
return await session.execute(stmt)
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.perf_counter()
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
函数需要接受session参数
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 查找session参数
@@ -335,7 +335,7 @@ def db_operation(
return await complex_operation()
"""
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
# 从内到外应用装饰器
wrapped = func

View File

@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
class DaemonQueueListener(QueueListener):
"""QueueListener 的工作线程作为守护线程运行,以避免阻塞关闭。"""
def _configure_listener(self):
super()._configure_listener()
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
self._thread.daemon = True # type: ignore[attr-defined]
def start(self):
"""Start the listener.
This starts up a background thread to monitor the queue for
LogRecords to process.
"""
# 覆盖 start 方法以设置 daemon=True
# 注意_monitor 是 QueueListener 的内部方法
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
self._thread.start()
def stop(self):
"""停止监听器,避免在退出时无限期阻塞。"""

View File

@@ -1,4 +1,5 @@
import os
from typing import Any
from mofox_wire import MessageServer
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
if global_api is not None:
return global_api
assert global_config is not None
bus_config = global_config.message_bus
host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000")
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
except ValueError:
port = 8000
kwargs: dict[str, object] = {
kwargs: dict[str, Any] = {
"host": host,
"port": port,
"app": get_global_server().get_app(),

View File

@@ -52,6 +52,7 @@ async def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
assert global_config is not None
async with get_db_session() as session:
query = select(Messages)

View File

@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
@staticmethod
def _get_sys_info() -> dict[str, str]:
"""获取系统信息"""
assert global_config is not None
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),

View File

@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
FastAPI 依赖项用于验证API密钥。
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
"""
assert bot_config is not None
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
if not valid_keys:
logger.warning("API密钥认证已启用但未配置任何有效的API密钥。所有请求都将被拒绝。")

View File

@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
class Server:
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
assert bot_config is not None
# 根据配置初始化速率限制器
limiter = Limiter(
key_func=get_remote_address,

View File

@@ -176,7 +176,7 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
if key not in reference:
del target[key]
elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table):
_remove_obsolete_keys(target[key], reference[key])
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
@@ -433,9 +433,9 @@ class Config(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)

View File

@@ -177,7 +177,9 @@ class ValidatedConfigBase(BaseModel):
element_index = field_path[1]
# 尝试获取父字段的类型信息
parent_field_info = cls.model_fields.get(parent_field)
parent_field_info = None
if isinstance(parent_field, str):
parent_field_info = cls.model_fields.get(parent_field)
if parent_field_info and hasattr(parent_field_info, "annotation"):
expected_type = parent_field_info.annotation
@@ -214,7 +216,9 @@ class ValidatedConfigBase(BaseModel):
# 处理模型类型错误
elif error_type in ["model_type", "dict_type", "is_instance_of"]:
field_name = field_path[0] if field_path else "unknown"
field_info = cls.model_fields.get(field_name)
field_info = None
if isinstance(field_name, str):
field_info = cls.model_fields.get(field_name)
if field_info and hasattr(field_info, "annotation"):
expected_type = field_info.annotation

View File

@@ -638,15 +638,15 @@ class PlanningSystemConfig(ValidatedConfigBase):
"""规划系统配置 (日程与月度计划)"""
# --- 日程生成 (原 ScheduleConfig) ---
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能")
schedule_guidelines: str = Field("", description="日程生成指导原则")
schedule_enable: bool = Field(default=True, description="是否启用每日日程生成功能")
schedule_guidelines: str = Field(default="", description="日程生成指导原则")
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统")
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则")
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量")
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划")
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成")
monthly_plan_enable: bool = Field(default=True, description="是否启用月度计划系统")
monthly_plan_guidelines: str = Field(default="", description="月度计划生成指导原则")
max_plans_per_month: int = Field(default=10, description="每月最多生成的计划数量")
avoid_repetition_days: int = Field(default=7, description="避免在多少天内重复使用同一个月度计划")
completion_threshold: int = Field(default=3, description="一个月度计划被使用多少次后算作完成")
class DependencyManagementConfig(ValidatedConfigBase):

View File

@@ -6,10 +6,18 @@ import orjson
from rich.traceback import install
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.config.config import global_config as _global_config, model_config as _model_config
from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager
if _global_config is None:
raise ValueError("global_config is not initialized")
if _model_config is None:
raise ValueError("model_config is not initialized")
global_config = _global_config
model_config = _model_config
install(extra_lines=3)
logger = get_logger("individuality")

View File

@@ -1152,6 +1152,25 @@ class LLMRequest:
return embeddings, model_info.name # type: ignore[return-value]
async def execute_with_messages(
    self,
    message_list: list[Message],
    temperature: float | None = None,
    max_tokens: int | None = None,
) -> APIResponse:
    """Run a response request built from a caller-supplied message list.

    The messages are handed to the failover strategy as given, which allows
    multimodal payloads (for example several images in one request).

    Args:
        message_list: Messages to send with the request.
        temperature: Forwarded unchanged to the failover strategy.
        max_tokens: Forwarded unchanged to the failover strategy.

    Returns:
        The response returned by the failover strategy.
    """
    began_at = time.time()

    # Collect the per-request options once and forward them as keywords.
    request_options = {
        "message_list": message_list,
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    response, model_info = await self._strategy.execute_with_failover(
        RequestType.RESPONSE,
        **request_options,
    )

    # Usage is recorded against the chat-completions endpoint together with
    # the wall-clock time spent waiting for the strategy to finish.
    usage = response.usage
    elapsed = time.time() - began_at
    await self._record_usage(model_info, usage, elapsed, "/chat/completions")
    return response
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
"""
记录模型使用情况。