diff --git a/src/chat/knowledge/embedding_store.py b/src/chat/knowledge/embedding_store.py index b286fa968..1751b198d 100644 --- a/src/chat/knowledge/embedding_store.py +++ b/src/chat/knowledge/embedding_store.py @@ -2,6 +2,7 @@ import asyncio import math import os from dataclasses import dataclass +from typing import Any # import tqdm import aiofiles @@ -121,7 +122,7 @@ class EmbeddingStore: self.store = {} - self.faiss_index = None + self.faiss_index: Any = None self.idx2hash = None @staticmethod @@ -158,6 +159,8 @@ class EmbeddingStore: from src.config.config import model_config from src.llm_models.utils_model import LLMRequest + assert model_config is not None + # 限制 chunk_size 和 max_workers 在合理范围内 chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) @@ -402,6 +405,7 @@ class EmbeddingStore: def build_faiss_index(self) -> None: """重新构建Faiss索引,以余弦相似度为度量""" + assert global_config is not None # 获取所有的embedding array = [] self.idx2hash = {} diff --git a/src/chat/utils/utils_image.py b/src/chat/utils/utils_image.py index 967d5af08..dd2033122 100644 --- a/src/chat/utils/utils_image.py +++ b/src/chat/utils/utils_image.py @@ -344,14 +344,15 @@ class ImageManager: # --- 新的帧选择逻辑:均匀抽取4帧 --- num_frames = len(all_frames) if num_frames <= 4: - # 如果总帧数小于等于4,则全部选中 + # 如果总宽度小于等于4,则全部选中 selected_frames = all_frames + indices = list(range(num_frames)) else: # 使用linspace计算4个均匀分布的索引 indices = np.linspace(0, num_frames - 1, 4, dtype=int) selected_frames = [all_frames[i] for i in indices] - logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}") + logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}") # --- 帧选择逻辑结束 --- # 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 8ed85e2cc..f03159364 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock() logger = get_logger("utils_video") -from inkfox import video +from inkfox import video # type: ignore class VideoAnalyzer: @@ -123,7 +123,6 @@ class VideoAnalyzer: # ---- 批量分析 ---- async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str: from src.llm_models.payload_content.message import MessageBuilder, RoleType - from src.llm_models.utils_model import RequestType prompt = self.batch_analysis_prompt.format( personality_core=self.personality_core, personality_side=self.personality_side @@ -139,12 +138,7 @@ class VideoAnalyzer: for b64, _ in frames: mb.add_image_content("jpeg", b64) message = mb.build() - model_info, api_provider, client = self.video_llm._select_model() - resp = await self.video_llm._execute_request( - api_provider=api_provider, - client=client, - request_type=RequestType.RESPONSE, - model_info=model_info, + resp = await self.video_llm.execute_with_messages( message_list=[message], temperature=None, max_tokens=None, diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py index 91219d402..5bf3b769d 100644 --- a/src/chat/utils/utils_video_legacy.py +++ b/src/chat/utils/utils_video_legacy.py @@ -31,9 +31,9 @@ def _extract_frames_worker( max_image_size: int, frame_extraction_mode: str, frame_interval_seconds: float | None, -) -> list[Any] | list[tuple[str, str]]: +) -> list[tuple[str, float]] | list[tuple[str, str]]: """线程池中提取视频帧的工作函数""" - frames = [] + frames: list[tuple[str, float]] = [] try: cap = cv2.VideoCapture(video_path) fps = cap.get(cv2.CAP_PROP_FPS) @@ -42,7 +42,7 @@ def _extract_frames_worker( if frame_extraction_mode == "time_interval": # 新模式:按时间间隔抽帧 - time_interval = frame_interval_seconds + time_interval = frame_interval_seconds or 2.0 next_frame_time = 0.0 extracted_count = 0 # 初始化提取帧计数器 @@ -61,7 +61,7 @@ def _extract_frames_worker( # 调整图像大小 if max(pil_image.size) > max_image_size: ratio = max_image_size / max(pil_image.size) - new_size = tuple(int(dim * ratio) for dim in pil_image.size) + new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio)) pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) # 转换为base64 @@ -240,6 +240,7 @@ class LegacyVideoAnalyzer: estimated_frames = min(self.max_frames, total_frames // frame_interval + 1) else: estimated_frames = self.max_frames + frame_interval = 1 logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)") @@ -276,7 +277,7 @@ class LegacyVideoAnalyzer: return await self._extract_frames_fallback(video_path) logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)") - return frames + return frames # type: ignore except Exception as e: logger.error(f"线程池帧提取失败: {e}") @@ -315,7 +316,7 @@ class LegacyVideoAnalyzer: # 调整图像大小 if max(pil_image.size) > self.max_image_size: ratio = self.max_image_size / max(pil_image.size) - new_size = tuple(int(dim * ratio) for dim in pil_image.size) + new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio)) pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) # 转换为base64 @@ -463,11 +464,11 @@ class LegacyVideoAnalyzer: # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") # 获取模型信息和客户端 - model_info, api_provider, client = self.video_llm._select_model() + model_info, api_provider, client = self.video_llm._select_model() # type: ignore # logger.info(f"使用模型: {model_info.name} 进行多帧分析") # 直接执行多图片请求 - api_response = await self.video_llm._execute_request( + api_response = await self.video_llm._execute_request( # type: ignore api_provider=api_provider, client=client, request_type=RequestType.RESPONSE, diff --git a/src/common/cache_manager.py b/src/common/cache_manager.py index 6edb72169..9118adb40 100644 --- a/src/common/cache_manager.py +++ b/src/common/cache_manager.py @@ -38,6 +38,9 @@ class CacheManager: 初始化缓存管理器。 """ if not hasattr(self, "_initialized"): + assert global_config is not None + assert model_config is not None + self.default_ttl = default_ttl or 3600 self.semantic_cache_collection_name = "semantic_cache" @@ -87,6 +90,7 @@ class CacheManager: embedding_array = embedding_array.flatten() # 检查维度是否符合预期 + assert global_config is not None expected_dim = ( getattr(CacheManager, "embedding_dimension", None) or global_config.lpmm_knowledge.embedding_dimension diff --git a/src/common/config_helpers.py b/src/common/config_helpers.py index f5460fece..463edaff6 100644 --- a/src/common/config_helpers.py +++ b/src/common/config_helpers.py @@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo candidates: list[int | None] = [] - try: - embedding_task = getattr(model_config.model_task_config, "embedding", None) - if embedding_task is not None: - candidates.append(getattr(embedding_task, "embedding_dimension", None)) - except Exception: + if model_config is not None: + try: + embedding_task = getattr(model_config.model_task_config, "embedding", None) + if embedding_task is not None: + candidates.append(getattr(embedding_task, "embedding_dimension", None)) + except Exception: + candidates.append(None) + else: candidates.append(None) - try: - candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None)) - except Exception: + if global_config is not None: + try: + candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None)) + except Exception: + candidates.append(None) + else: candidates.append(None) candidates.append(fallback) resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) - if resolved and sync_global: + if resolved and sync_global and global_config is not None: try: if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved: global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined] diff --git a/src/common/database/utils/decorators.py b/src/common/database/utils/decorators.py index b2baad6db..091c3dc53 100644 --- a/src/common/database/utils/decorators.py +++ b/src/common/database/utils/decorators.py @@ -10,7 +10,7 @@ import asyncio import functools import hashlib import time -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Coroutine from typing import Any, ParamSpec, TypeVar from sqlalchemy.exc import DBAPIError, OperationalError @@ -82,7 +82,7 @@ def retry( return await session.execute(stmt) """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: last_exception = None @@ -130,7 +130,7 @@ def timeout(seconds: float): return await session.execute(complex_stmt) """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: try: @@ -166,7 +166,7 @@ def cached( return await query_user(user_id) """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # 延迟导入避免循环依赖 @@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None): return await session.execute(stmt) """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: start_time = time.perf_counter() @@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True): 函数需要接受session参数 """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: @functools.wraps(func) async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: # 查找session参数 @@ -335,7 +335,7 @@ def db_operation( return await complex_operation() """ - def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: + def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]: # 从内到外应用装饰器 wrapped = func diff --git a/src/common/logger.py b/src/common/logger.py index 896f9b26f..b4f2dc261 100644 --- a/src/common/logger.py +++ b/src/common/logger.py @@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger class DaemonQueueListener(QueueListener): """QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。""" - def _configure_listener(self): - super()._configure_listener() - if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined] - self._thread.daemon = True # type: ignore[attr-defined] + def start(self): + """Start the listener. + This starts up a background thread to monitor the queue for + LogRecords to process. + """ + # 覆盖 start 方法以设置 daemon=True + # 注意:_monitor 是 QueueListener 的内部方法 + self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore + self._thread.start() def stop(self): """停止监听器,避免在退出时无限期阻塞。""" diff --git a/src/common/message/api.py b/src/common/message/api.py index 34eb01fcb..754c3b793 100644 --- a/src/common/message/api.py +++ b/src/common/message/api.py @@ -1,4 +1,5 @@ import os +from typing import Any from mofox_wire import MessageServer @@ -18,6 +19,8 @@ def get_global_api() -> MessageServer: if global_api is not None: return global_api + assert global_config is not None + bus_config = global_config.message_bus host = os.getenv("HOST", "127.0.0.1") port_str = os.getenv("PORT", "8000") @@ -27,7 +30,7 @@ def get_global_api() -> MessageServer: except ValueError: port = 8000 - kwargs: dict[str, object] = { + kwargs: dict[str, Any] = { "host": host, "port": port, "app": get_global_server().get_app(), diff --git a/src/common/message_repository.py b/src/common/message_repository.py index 14be6640a..b74b76f20 100644 --- a/src/common/message_repository.py +++ b/src/common/message_repository.py @@ -52,6 +52,7 @@ async def find_messages( 消息字典列表,如果出错则返回空列表。 """ try: + assert global_config is not None async with get_db_session() as session: query = select(Messages) diff --git a/src/common/remote.py b/src/common/remote.py index f6396a037..7f76e8f84 100644 --- a/src/common/remote.py +++ b/src/common/remote.py @@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask): @staticmethod def _get_sys_info() -> dict[str, str]: """获取系统信息""" + assert global_config is not None info_dict = { "os_type": "Unknown", "py_version": platform.python_version(), diff --git a/src/common/security.py b/src/common/security.py index b151dfd09..104e1cf94 100644 --- a/src/common/security.py +++ b/src/common/security.py @@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str: FastAPI 依赖项,用于验证API密钥。 从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。 """ + assert bot_config is not None valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys if not valid_keys: logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。") diff --git a/src/common/server.py b/src/common/server.py index 527663be2..15f5de16a 100644 --- a/src/common/server.py +++ b/src/common/server.py @@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response: class Server: def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"): + assert bot_config is not None # 根据配置初始化速率限制器 limiter = Limiter( key_func=get_remote_address, diff --git a/src/config/config.py b/src/config/config.py index 13f352ed3..98ae97646 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -176,7 +176,7 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo if key not in reference: del target[key] elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table): - _remove_obsolete_keys(target[key], reference[key]) + _remove_obsolete_keys(target[key], reference[key]) # type: ignore def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): @@ -433,9 +433,9 @@ class Config(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase): """API Adapter配置类""" - models: list[ModelInfo] = Field(..., min_items=1, description="模型列表") + models: list[ModelInfo] = Field(..., min_length=1, description="模型列表") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") - api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表") + api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表") def __init__(self, **data): super().__init__(**data) diff --git a/src/config/config_base.py b/src/config/config_base.py index a80740a46..551326fa3 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -177,7 +177,9 @@ class ValidatedConfigBase(BaseModel): element_index = field_path[1] # 尝试获取父字段的类型信息 - parent_field_info = cls.model_fields.get(parent_field) + parent_field_info = None + if isinstance(parent_field, str): + parent_field_info = cls.model_fields.get(parent_field) if parent_field_info and hasattr(parent_field_info, "annotation"): expected_type = parent_field_info.annotation @@ -214,7 +216,9 @@ class ValidatedConfigBase(BaseModel): # 处理模型类型错误 elif error_type in ["model_type", "dict_type", "is_instance_of"]: field_name = field_path[0] if field_path else "unknown" - field_info = cls.model_fields.get(field_name) + field_info = None + if isinstance(field_name, str): + field_info = cls.model_fields.get(field_name) if field_info and hasattr(field_info, "annotation"): expected_type = field_info.annotation diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 63f4fa4bc..c0b8a5e0c 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -638,15 +638,15 @@ class PlanningSystemConfig(ValidatedConfigBase): """规划系统配置 (日程与月度计划)""" # --- 日程生成 (原 ScheduleConfig) --- - schedule_enable: bool = Field(True, description="是否启用每日日程生成功能") - schedule_guidelines: str = Field("", description="日程生成指导原则") + schedule_enable: bool = Field(default=True, description="是否启用每日日程生成功能") + schedule_guidelines: str = Field(default="", description="日程生成指导原则") # --- 月度计划 (原 MonthlyPlanSystemConfig) --- - monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统") - monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则") - max_plans_per_month: int = Field(10, description="每月最多生成的计划数量") - avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划") - completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成") + monthly_plan_enable: bool = Field(default=True, description="是否启用月度计划系统") + monthly_plan_guidelines: str = Field(default="", description="月度计划生成指导原则") + max_plans_per_month: int = Field(default=10, description="每月最多生成的计划数量") + avoid_repetition_days: int = Field(default=7, description="避免在多少天内重复使用同一个月度计划") + completion_threshold: int = Field(default=3, description="一个月度计划被使用多少次后算作完成") class DependencyManagementConfig(ValidatedConfigBase): diff --git a/src/individuality/individuality.py b/src/individuality/individuality.py index 3ef490e57..36806e0bb 100644 --- a/src/individuality/individuality.py +++ b/src/individuality/individuality.py @@ -6,10 +6,18 @@ import orjson from rich.traceback import install from src.common.logger import get_logger -from src.config.config import global_config, model_config +from src.config.config import global_config as _global_config, model_config as _model_config from src.llm_models.utils_model import LLMRequest from src.person_info.person_info import get_person_info_manager +if _global_config is None: + raise ValueError("global_config is not initialized") +if _model_config is None: + raise ValueError("model_config is not initialized") + +global_config = _global_config +model_config = _model_config + install(extra_lines=3) logger = get_logger("individuality") diff --git a/src/llm_models/utils_model.py b/src/llm_models/utils_model.py index 1601c2eb9..f45b31041 100644 --- a/src/llm_models/utils_model.py +++ b/src/llm_models/utils_model.py @@ -1152,6 +1152,25 @@ class LLMRequest: return embeddings, model_info.name # type: ignore[return-value] + async def execute_with_messages( + self, + message_list: list[Message], + temperature: float | None = None, + max_tokens: int | None = None, + ) -> APIResponse: + """ + 使用自定义消息列表执行请求(支持多模态/多图)。 + """ + start_time = time.time() + response, model_info = await self._strategy.execute_with_failover( + RequestType.RESPONSE, + message_list=message_list, + temperature=temperature, + max_tokens=max_tokens, + ) + await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions") + return response + async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str): """ 记录模型使用情况。