依旧修pyright喵喵喵~
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
# import tqdm
|
||||
import aiofiles
|
||||
@@ -121,7 +122,7 @@ class EmbeddingStore:
|
||||
|
||||
self.store = {}
|
||||
|
||||
self.faiss_index = None
|
||||
self.faiss_index: Any = None
|
||||
self.idx2hash = None
|
||||
|
||||
@staticmethod
|
||||
@@ -158,6 +159,8 @@ class EmbeddingStore:
|
||||
from src.config.config import model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
|
||||
assert model_config is not None
|
||||
|
||||
# 限制 chunk_size 和 max_workers 在合理范围内
|
||||
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
||||
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
||||
@@ -402,6 +405,7 @@ class EmbeddingStore:
|
||||
|
||||
def build_faiss_index(self) -> None:
|
||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||
assert global_config is not None
|
||||
# 获取所有的embedding
|
||||
array = []
|
||||
self.idx2hash = {}
|
||||
|
||||
@@ -344,14 +344,15 @@ class ImageManager:
|
||||
# --- 新的帧选择逻辑:均匀抽取4帧 ---
|
||||
num_frames = len(all_frames)
|
||||
if num_frames <= 4:
|
||||
# 如果总帧数小于等于4,则全部选中
|
||||
# 如果总宽度小于等于4,则全部选中
|
||||
selected_frames = all_frames
|
||||
indices = list(range(num_frames))
|
||||
else:
|
||||
# 使用linspace计算4个均匀分布的索引
|
||||
indices = np.linspace(0, num_frames - 1, 4, dtype=int)
|
||||
selected_frames = [all_frames[i] for i in indices]
|
||||
|
||||
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}")
|
||||
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}")
|
||||
# --- 帧选择逻辑结束 ---
|
||||
|
||||
# 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None
|
||||
|
||||
@@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock()
|
||||
|
||||
logger = get_logger("utils_video")
|
||||
|
||||
from inkfox import video
|
||||
from inkfox import video # type: ignore
|
||||
|
||||
|
||||
class VideoAnalyzer:
|
||||
@@ -123,7 +123,6 @@ class VideoAnalyzer:
|
||||
# ---- 批量分析 ----
|
||||
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||
from src.llm_models.utils_model import RequestType
|
||||
|
||||
prompt = self.batch_analysis_prompt.format(
|
||||
personality_core=self.personality_core, personality_side=self.personality_side
|
||||
@@ -139,12 +138,7 @@ class VideoAnalyzer:
|
||||
for b64, _ in frames:
|
||||
mb.add_image_content("jpeg", b64)
|
||||
message = mb.build()
|
||||
model_info, api_provider, client = self.video_llm._select_model()
|
||||
resp = await self.video_llm._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
resp = await self.video_llm.execute_with_messages(
|
||||
message_list=[message],
|
||||
temperature=None,
|
||||
max_tokens=None,
|
||||
|
||||
@@ -31,9 +31,9 @@ def _extract_frames_worker(
|
||||
max_image_size: int,
|
||||
frame_extraction_mode: str,
|
||||
frame_interval_seconds: float | None,
|
||||
) -> list[Any] | list[tuple[str, str]]:
|
||||
) -> list[tuple[str, float]] | list[tuple[str, str]]:
|
||||
"""线程池中提取视频帧的工作函数"""
|
||||
frames = []
|
||||
frames: list[tuple[str, float]] = []
|
||||
try:
|
||||
cap = cv2.VideoCapture(video_path)
|
||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||
@@ -42,7 +42,7 @@ def _extract_frames_worker(
|
||||
|
||||
if frame_extraction_mode == "time_interval":
|
||||
# 新模式:按时间间隔抽帧
|
||||
time_interval = frame_interval_seconds
|
||||
time_interval = frame_interval_seconds or 2.0
|
||||
next_frame_time = 0.0
|
||||
extracted_count = 0 # 初始化提取帧计数器
|
||||
|
||||
@@ -61,7 +61,7 @@ def _extract_frames_worker(
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > max_image_size:
|
||||
ratio = max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为base64
|
||||
@@ -240,6 +240,7 @@ class LegacyVideoAnalyzer:
|
||||
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
|
||||
else:
|
||||
estimated_frames = self.max_frames
|
||||
frame_interval = 1
|
||||
|
||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
|
||||
|
||||
@@ -276,7 +277,7 @@ class LegacyVideoAnalyzer:
|
||||
return await self._extract_frames_fallback(video_path)
|
||||
|
||||
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
|
||||
return frames
|
||||
return frames # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"线程池帧提取失败: {e}")
|
||||
@@ -315,7 +316,7 @@ class LegacyVideoAnalyzer:
|
||||
# 调整图像大小
|
||||
if max(pil_image.size) > self.max_image_size:
|
||||
ratio = self.max_image_size / max(pil_image.size)
|
||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
||||
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
|
||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||
|
||||
# 转换为base64
|
||||
@@ -463,11 +464,11 @@ class LegacyVideoAnalyzer:
|
||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
||||
|
||||
# 获取模型信息和客户端
|
||||
model_info, api_provider, client = self.video_llm._select_model()
|
||||
model_info, api_provider, client = self.video_llm._select_model() # type: ignore
|
||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
||||
|
||||
# 直接执行多图片请求
|
||||
api_response = await self.video_llm._execute_request(
|
||||
api_response = await self.video_llm._execute_request( # type: ignore
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
|
||||
@@ -38,6 +38,9 @@ class CacheManager:
|
||||
初始化缓存管理器。
|
||||
"""
|
||||
if not hasattr(self, "_initialized"):
|
||||
assert global_config is not None
|
||||
assert model_config is not None
|
||||
|
||||
self.default_ttl = default_ttl or 3600
|
||||
self.semantic_cache_collection_name = "semantic_cache"
|
||||
|
||||
@@ -87,6 +90,7 @@ class CacheManager:
|
||||
embedding_array = embedding_array.flatten()
|
||||
|
||||
# 检查维度是否符合预期
|
||||
assert global_config is not None
|
||||
expected_dim = (
|
||||
getattr(CacheManager, "embedding_dimension", None)
|
||||
or global_config.lpmm_knowledge.embedding_dimension
|
||||
|
||||
@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
|
||||
|
||||
candidates: list[int | None] = []
|
||||
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if model_config is not None:
|
||||
try:
|
||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||
if embedding_task is not None:
|
||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
if global_config is not None:
|
||||
try:
|
||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||
except Exception:
|
||||
candidates.append(None)
|
||||
else:
|
||||
candidates.append(None)
|
||||
|
||||
candidates.append(fallback)
|
||||
|
||||
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||
|
||||
if resolved and sync_global:
|
||||
if resolved and sync_global and global_config is not None:
|
||||
try:
|
||||
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||
|
||||
@@ -10,7 +10,7 @@ import asyncio
|
||||
import functools
|
||||
import hashlib
|
||||
import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from collections.abc import Awaitable, Callable, Coroutine
|
||||
from typing import Any, ParamSpec, TypeVar
|
||||
|
||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||
@@ -82,7 +82,7 @@ def retry(
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
last_exception = None
|
||||
@@ -130,7 +130,7 @@ def timeout(seconds: float):
|
||||
return await session.execute(complex_stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
try:
|
||||
@@ -166,7 +166,7 @@ def cached(
|
||||
return await query_user(user_id)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 延迟导入避免循环依赖
|
||||
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
|
||||
return await session.execute(stmt)
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
start_time = time.perf_counter()
|
||||
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
||||
函数需要接受session参数
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
@functools.wraps(func)
|
||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||
# 查找session参数
|
||||
@@ -335,7 +335,7 @@ def db_operation(
|
||||
return await complex_operation()
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
||||
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||
# 从内到外应用装饰器
|
||||
wrapped = func
|
||||
|
||||
|
||||
@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
|
||||
class DaemonQueueListener(QueueListener):
|
||||
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
|
||||
|
||||
def _configure_listener(self):
|
||||
super()._configure_listener()
|
||||
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
|
||||
self._thread.daemon = True # type: ignore[attr-defined]
|
||||
def start(self):
|
||||
"""Start the listener.
|
||||
This starts up a background thread to monitor the queue for
|
||||
LogRecords to process.
|
||||
"""
|
||||
# 覆盖 start 方法以设置 daemon=True
|
||||
# 注意:_monitor 是 QueueListener 的内部方法
|
||||
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
|
||||
self._thread.start()
|
||||
|
||||
def stop(self):
|
||||
"""停止监听器,避免在退出时无限期阻塞。"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from mofox_wire import MessageServer
|
||||
|
||||
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
|
||||
if global_api is not None:
|
||||
return global_api
|
||||
|
||||
assert global_config is not None
|
||||
|
||||
bus_config = global_config.message_bus
|
||||
host = os.getenv("HOST", "127.0.0.1")
|
||||
port_str = os.getenv("PORT", "8000")
|
||||
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
|
||||
except ValueError:
|
||||
port = 8000
|
||||
|
||||
kwargs: dict[str, object] = {
|
||||
kwargs: dict[str, Any] = {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"app": get_global_server().get_app(),
|
||||
|
||||
@@ -52,6 +52,7 @@ async def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
assert global_config is not None
|
||||
async with get_db_session() as session:
|
||||
query = select(Messages)
|
||||
|
||||
|
||||
@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
||||
@staticmethod
|
||||
def _get_sys_info() -> dict[str, str]:
|
||||
"""获取系统信息"""
|
||||
assert global_config is not None
|
||||
info_dict = {
|
||||
"os_type": "Unknown",
|
||||
"py_version": platform.python_version(),
|
||||
|
||||
@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
|
||||
FastAPI 依赖项,用于验证API密钥。
|
||||
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
|
||||
"""
|
||||
assert bot_config is not None
|
||||
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
|
||||
if not valid_keys:
|
||||
logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。")
|
||||
|
||||
@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
|
||||
|
||||
class Server:
|
||||
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
|
||||
assert bot_config is not None
|
||||
# 根据配置初始化速率限制器
|
||||
limiter = Limiter(
|
||||
key_func=get_remote_address,
|
||||
|
||||
@@ -176,7 +176,7 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
|
||||
if key not in reference:
|
||||
del target[key]
|
||||
elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table):
|
||||
_remove_obsolete_keys(target[key], reference[key])
|
||||
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
|
||||
|
||||
|
||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||
@@ -433,9 +433,9 @@ class Config(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -177,7 +177,9 @@ class ValidatedConfigBase(BaseModel):
|
||||
element_index = field_path[1]
|
||||
|
||||
# 尝试获取父字段的类型信息
|
||||
parent_field_info = cls.model_fields.get(parent_field)
|
||||
parent_field_info = None
|
||||
if isinstance(parent_field, str):
|
||||
parent_field_info = cls.model_fields.get(parent_field)
|
||||
|
||||
if parent_field_info and hasattr(parent_field_info, "annotation"):
|
||||
expected_type = parent_field_info.annotation
|
||||
@@ -214,7 +216,9 @@ class ValidatedConfigBase(BaseModel):
|
||||
# 处理模型类型错误
|
||||
elif error_type in ["model_type", "dict_type", "is_instance_of"]:
|
||||
field_name = field_path[0] if field_path else "unknown"
|
||||
field_info = cls.model_fields.get(field_name)
|
||||
field_info = None
|
||||
if isinstance(field_name, str):
|
||||
field_info = cls.model_fields.get(field_name)
|
||||
|
||||
if field_info and hasattr(field_info, "annotation"):
|
||||
expected_type = field_info.annotation
|
||||
|
||||
@@ -638,15 +638,15 @@ class PlanningSystemConfig(ValidatedConfigBase):
|
||||
"""规划系统配置 (日程与月度计划)"""
|
||||
|
||||
# --- 日程生成 (原 ScheduleConfig) ---
|
||||
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能")
|
||||
schedule_guidelines: str = Field("", description="日程生成指导原则")
|
||||
schedule_enable: bool = Field(default=True, description="是否启用每日日程生成功能")
|
||||
schedule_guidelines: str = Field(default="", description="日程生成指导原则")
|
||||
|
||||
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
|
||||
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统")
|
||||
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则")
|
||||
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量")
|
||||
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划")
|
||||
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成")
|
||||
monthly_plan_enable: bool = Field(default=True, description="是否启用月度计划系统")
|
||||
monthly_plan_guidelines: str = Field(default="", description="月度计划生成指导原则")
|
||||
max_plans_per_month: int = Field(default=10, description="每月最多生成的计划数量")
|
||||
avoid_repetition_days: int = Field(default=7, description="避免在多少天内重复使用同一个月度计划")
|
||||
completion_threshold: int = Field(default=3, description="一个月度计划被使用多少次后算作完成")
|
||||
|
||||
|
||||
class DependencyManagementConfig(ValidatedConfigBase):
|
||||
|
||||
@@ -6,10 +6,18 @@ import orjson
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config, model_config
|
||||
from src.config.config import global_config as _global_config, model_config as _model_config
|
||||
from src.llm_models.utils_model import LLMRequest
|
||||
from src.person_info.person_info import get_person_info_manager
|
||||
|
||||
if _global_config is None:
|
||||
raise ValueError("global_config is not initialized")
|
||||
if _model_config is None:
|
||||
raise ValueError("model_config is not initialized")
|
||||
|
||||
global_config = _global_config
|
||||
model_config = _model_config
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
logger = get_logger("individuality")
|
||||
|
||||
@@ -1152,6 +1152,25 @@ class LLMRequest:
|
||||
|
||||
return embeddings, model_info.name # type: ignore[return-value]
|
||||
|
||||
async def execute_with_messages(
|
||||
self,
|
||||
message_list: list[Message],
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> APIResponse:
|
||||
"""
|
||||
使用自定义消息列表执行请求(支持多模态/多图)。
|
||||
"""
|
||||
start_time = time.time()
|
||||
response, model_info = await self._strategy.execute_with_failover(
|
||||
RequestType.RESPONSE,
|
||||
message_list=message_list,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
)
|
||||
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||
return response
|
||||
|
||||
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
|
||||
"""
|
||||
记录模型使用情况。
|
||||
|
||||
Reference in New Issue
Block a user