依旧修pyright喵喵喵~

This commit is contained in:
ikun-11451
2025-11-29 22:20:55 +08:00
parent 574c2384a2
commit acafc074b1
18 changed files with 106 additions and 53 deletions

View File

@@ -2,6 +2,7 @@ import asyncio
import math import math
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any
# import tqdm # import tqdm
import aiofiles import aiofiles
@@ -121,7 +122,7 @@ class EmbeddingStore:
self.store = {} self.store = {}
self.faiss_index = None self.faiss_index: Any = None
self.idx2hash = None self.idx2hash = None
@staticmethod @staticmethod
@@ -158,6 +159,8 @@ class EmbeddingStore:
from src.config.config import model_config from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
assert model_config is not None
# 限制 chunk_size 和 max_workers 在合理范围内 # 限制 chunk_size 和 max_workers 在合理范围内
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE)) chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS)) max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
@@ -402,6 +405,7 @@ class EmbeddingStore:
def build_faiss_index(self) -> None: def build_faiss_index(self) -> None:
"""重新构建Faiss索引以余弦相似度为度量""" """重新构建Faiss索引以余弦相似度为度量"""
assert global_config is not None
# 获取所有的embedding # 获取所有的embedding
array = [] array = []
self.idx2hash = {} self.idx2hash = {}

View File

@@ -344,14 +344,15 @@ class ImageManager:
# --- 新的帧选择逻辑均匀抽取4帧 --- # --- 新的帧选择逻辑均匀抽取4帧 ---
num_frames = len(all_frames) num_frames = len(all_frames)
if num_frames <= 4: if num_frames <= 4:
# 如果总帧数小于等于4则全部选中 # 如果总宽度小于等于4则全部选中
selected_frames = all_frames selected_frames = all_frames
indices = list(range(num_frames))
else: else:
# 使用linspace计算4个均匀分布的索引 # 使用linspace计算4个均匀分布的索引
indices = np.linspace(0, num_frames - 1, 4, dtype=int) indices = np.linspace(0, num_frames - 1, 4, dtype=int)
selected_frames = [all_frames[i] for i in indices] selected_frames = [all_frames[i] for i in indices]
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}") logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}")
# --- 帧选择逻辑结束 --- # --- 帧选择逻辑结束 ---
# 如果选择后连一帧都没有比如GIF只有一帧且后续处理失败或者原始GIF就没帧也返回None # 如果选择后连一帧都没有比如GIF只有一帧且后续处理失败或者原始GIF就没帧也返回None

View File

@@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock()
logger = get_logger("utils_video") logger = get_logger("utils_video")
from inkfox import video from inkfox import video # type: ignore
class VideoAnalyzer: class VideoAnalyzer:
@@ -123,7 +123,6 @@ class VideoAnalyzer:
# ---- 批量分析 ---- # ---- 批量分析 ----
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str: async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format( prompt = self.batch_analysis_prompt.format(
personality_core=self.personality_core, personality_side=self.personality_side personality_core=self.personality_core, personality_side=self.personality_side
@@ -139,12 +138,7 @@ class VideoAnalyzer:
for b64, _ in frames: for b64, _ in frames:
mb.add_image_content("jpeg", b64) mb.add_image_content("jpeg", b64)
message = mb.build() message = mb.build()
model_info, api_provider, client = self.video_llm._select_model() resp = await self.video_llm.execute_with_messages(
resp = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,
model_info=model_info,
message_list=[message], message_list=[message],
temperature=None, temperature=None,
max_tokens=None, max_tokens=None,

View File

@@ -31,9 +31,9 @@ def _extract_frames_worker(
max_image_size: int, max_image_size: int,
frame_extraction_mode: str, frame_extraction_mode: str,
frame_interval_seconds: float | None, frame_interval_seconds: float | None,
) -> list[Any] | list[tuple[str, str]]: ) -> list[tuple[str, float]] | list[tuple[str, str]]:
"""线程池中提取视频帧的工作函数""" """线程池中提取视频帧的工作函数"""
frames = [] frames: list[tuple[str, float]] = []
try: try:
cap = cv2.VideoCapture(video_path) cap = cv2.VideoCapture(video_path)
fps = cap.get(cv2.CAP_PROP_FPS) fps = cap.get(cv2.CAP_PROP_FPS)
@@ -42,7 +42,7 @@ def _extract_frames_worker(
if frame_extraction_mode == "time_interval": if frame_extraction_mode == "time_interval":
# 新模式:按时间间隔抽帧 # 新模式:按时间间隔抽帧
time_interval = frame_interval_seconds time_interval = frame_interval_seconds or 2.0
next_frame_time = 0.0 next_frame_time = 0.0
extracted_count = 0 # 初始化提取帧计数器 extracted_count = 0 # 初始化提取帧计数器
@@ -61,7 +61,7 @@ def _extract_frames_worker(
# 调整图像大小 # 调整图像大小
if max(pil_image.size) > max_image_size: if max(pil_image.size) > max_image_size:
ratio = max_image_size / max(pil_image.size) ratio = max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size) new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64 # 转换为base64
@@ -240,6 +240,7 @@ class LegacyVideoAnalyzer:
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1) estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
else: else:
estimated_frames = self.max_frames estimated_frames = self.max_frames
frame_interval = 1
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)") logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
@@ -276,7 +277,7 @@ class LegacyVideoAnalyzer:
return await self._extract_frames_fallback(video_path) return await self._extract_frames_fallback(video_path)
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)") logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
return frames return frames # type: ignore
except Exception as e: except Exception as e:
logger.error(f"线程池帧提取失败: {e}") logger.error(f"线程池帧提取失败: {e}")
@@ -315,7 +316,7 @@ class LegacyVideoAnalyzer:
# 调整图像大小 # 调整图像大小
if max(pil_image.size) > self.max_image_size: if max(pil_image.size) > self.max_image_size:
ratio = self.max_image_size / max(pil_image.size) ratio = self.max_image_size / max(pil_image.size)
new_size = tuple(int(dim * ratio) for dim in pil_image.size) new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
# 转换为base64 # 转换为base64
@@ -463,11 +464,11 @@ class LegacyVideoAnalyzer:
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片") # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端 # 获取模型信息和客户端
model_info, api_provider, client = self.video_llm._select_model() model_info, api_provider, client = self.video_llm._select_model() # type: ignore
# logger.info(f"使用模型: {model_info.name} 进行多帧分析") # logger.info(f"使用模型: {model_info.name} 进行多帧分析")
# 直接执行多图片请求 # 直接执行多图片请求
api_response = await self.video_llm._execute_request( api_response = await self.video_llm._execute_request( # type: ignore
api_provider=api_provider, api_provider=api_provider,
client=client, client=client,
request_type=RequestType.RESPONSE, request_type=RequestType.RESPONSE,

View File

@@ -38,6 +38,9 @@ class CacheManager:
初始化缓存管理器。 初始化缓存管理器。
""" """
if not hasattr(self, "_initialized"): if not hasattr(self, "_initialized"):
assert global_config is not None
assert model_config is not None
self.default_ttl = default_ttl or 3600 self.default_ttl = default_ttl or 3600
self.semantic_cache_collection_name = "semantic_cache" self.semantic_cache_collection_name = "semantic_cache"
@@ -87,6 +90,7 @@ class CacheManager:
embedding_array = embedding_array.flatten() embedding_array = embedding_array.flatten()
# 检查维度是否符合预期 # 检查维度是否符合预期
assert global_config is not None
expected_dim = ( expected_dim = (
getattr(CacheManager, "embedding_dimension", None) getattr(CacheManager, "embedding_dimension", None)
or global_config.lpmm_knowledge.embedding_dimension or global_config.lpmm_knowledge.embedding_dimension

View File

@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
candidates: list[int | None] = [] candidates: list[int | None] = []
try: if model_config is not None:
embedding_task = getattr(model_config.model_task_config, "embedding", None) try:
if embedding_task is not None: embedding_task = getattr(model_config.model_task_config, "embedding", None)
candidates.append(getattr(embedding_task, "embedding_dimension", None)) if embedding_task is not None:
except Exception: candidates.append(getattr(embedding_task, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None) candidates.append(None)
try: if global_config is not None:
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None)) try:
except Exception: candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
except Exception:
candidates.append(None)
else:
candidates.append(None) candidates.append(None)
candidates.append(fallback) candidates.append(fallback)
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None) resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
if resolved and sync_global: if resolved and sync_global and global_config is not None:
try: try:
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved: if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined] global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]

View File

@@ -10,7 +10,7 @@ import asyncio
import functools import functools
import hashlib import hashlib
import time import time
from collections.abc import Awaitable, Callable from collections.abc import Awaitable, Callable, Coroutine
from typing import Any, ParamSpec, TypeVar from typing import Any, ParamSpec, TypeVar
from sqlalchemy.exc import DBAPIError, OperationalError from sqlalchemy.exc import DBAPIError, OperationalError
@@ -82,7 +82,7 @@ def retry(
return await session.execute(stmt) return await session.execute(stmt)
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
last_exception = None last_exception = None
@@ -130,7 +130,7 @@ def timeout(seconds: float):
return await session.execute(complex_stmt) return await session.execute(complex_stmt)
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
try: try:
@@ -166,7 +166,7 @@ def cached(
return await query_user(user_id) return await query_user(user_id)
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 延迟导入避免循环依赖 # 延迟导入避免循环依赖
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
return await session.execute(stmt) return await session.execute(stmt)
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
start_time = time.perf_counter() start_time = time.perf_counter()
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
函数需要接受session参数 函数需要接受session参数
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
@functools.wraps(func) @functools.wraps(func)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# 查找session参数 # 查找session参数
@@ -335,7 +335,7 @@ def db_operation(
return await complex_operation() return await complex_operation()
""" """
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]: def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
# 从内到外应用装饰器 # 从内到外应用装饰器
wrapped = func wrapped = func

View File

@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
class DaemonQueueListener(QueueListener): class DaemonQueueListener(QueueListener):
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。""" """QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
def _configure_listener(self): def start(self):
super()._configure_listener() """Start the listener.
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined] This starts up a background thread to monitor the queue for
self._thread.daemon = True # type: ignore[attr-defined] LogRecords to process.
"""
# 覆盖 start 方法以设置 daemon=True
# 注意_monitor 是 QueueListener 的内部方法
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
self._thread.start()
def stop(self): def stop(self):
"""停止监听器,避免在退出时无限期阻塞。""" """停止监听器,避免在退出时无限期阻塞。"""

View File

@@ -1,4 +1,5 @@
import os import os
from typing import Any
from mofox_wire import MessageServer from mofox_wire import MessageServer
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
if global_api is not None: if global_api is not None:
return global_api return global_api
assert global_config is not None
bus_config = global_config.message_bus bus_config = global_config.message_bus
host = os.getenv("HOST", "127.0.0.1") host = os.getenv("HOST", "127.0.0.1")
port_str = os.getenv("PORT", "8000") port_str = os.getenv("PORT", "8000")
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
except ValueError: except ValueError:
port = 8000 port = 8000
kwargs: dict[str, object] = { kwargs: dict[str, Any] = {
"host": host, "host": host,
"port": port, "port": port,
"app": get_global_server().get_app(), "app": get_global_server().get_app(),

View File

@@ -52,6 +52,7 @@ async def find_messages(
消息字典列表,如果出错则返回空列表。 消息字典列表,如果出错则返回空列表。
""" """
try: try:
assert global_config is not None
async with get_db_session() as session: async with get_db_session() as session:
query = select(Messages) query = select(Messages)

View File

@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
@staticmethod @staticmethod
def _get_sys_info() -> dict[str, str]: def _get_sys_info() -> dict[str, str]:
"""获取系统信息""" """获取系统信息"""
assert global_config is not None
info_dict = { info_dict = {
"os_type": "Unknown", "os_type": "Unknown",
"py_version": platform.python_version(), "py_version": platform.python_version(),

View File

@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
FastAPI 依赖项用于验证API密钥。 FastAPI 依赖项用于验证API密钥。
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。 从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
""" """
assert bot_config is not None
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
if not valid_keys: if not valid_keys:
logger.warning("API密钥认证已启用但未配置任何有效的API密钥。所有请求都将被拒绝。") logger.warning("API密钥认证已启用但未配置任何有效的API密钥。所有请求都将被拒绝。")

View File

@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
class Server: class Server:
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"): def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
assert bot_config is not None
# 根据配置初始化速率限制器 # 根据配置初始化速率限制器
limiter = Limiter( limiter = Limiter(
key_func=get_remote_address, key_func=get_remote_address,

View File

@@ -176,7 +176,7 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
if key not in reference: if key not in reference:
del target[key] del target[key]
elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table): elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table):
_remove_obsolete_keys(target[key], reference[key]) _remove_obsolete_keys(target[key], reference[key]) # type: ignore
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict): def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
@@ -433,9 +433,9 @@ class Config(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase): class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类""" """API Adapter配置类"""
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表") models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置") model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表") api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
def __init__(self, **data): def __init__(self, **data):
super().__init__(**data) super().__init__(**data)

View File

@@ -177,7 +177,9 @@ class ValidatedConfigBase(BaseModel):
element_index = field_path[1] element_index = field_path[1]
# 尝试获取父字段的类型信息 # 尝试获取父字段的类型信息
parent_field_info = cls.model_fields.get(parent_field) parent_field_info = None
if isinstance(parent_field, str):
parent_field_info = cls.model_fields.get(parent_field)
if parent_field_info and hasattr(parent_field_info, "annotation"): if parent_field_info and hasattr(parent_field_info, "annotation"):
expected_type = parent_field_info.annotation expected_type = parent_field_info.annotation
@@ -214,7 +216,9 @@ class ValidatedConfigBase(BaseModel):
# 处理模型类型错误 # 处理模型类型错误
elif error_type in ["model_type", "dict_type", "is_instance_of"]: elif error_type in ["model_type", "dict_type", "is_instance_of"]:
field_name = field_path[0] if field_path else "unknown" field_name = field_path[0] if field_path else "unknown"
field_info = cls.model_fields.get(field_name) field_info = None
if isinstance(field_name, str):
field_info = cls.model_fields.get(field_name)
if field_info and hasattr(field_info, "annotation"): if field_info and hasattr(field_info, "annotation"):
expected_type = field_info.annotation expected_type = field_info.annotation

View File

@@ -638,15 +638,15 @@ class PlanningSystemConfig(ValidatedConfigBase):
"""规划系统配置 (日程与月度计划)""" """规划系统配置 (日程与月度计划)"""
# --- 日程生成 (原 ScheduleConfig) --- # --- 日程生成 (原 ScheduleConfig) ---
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能") schedule_enable: bool = Field(default=True, description="是否启用每日日程生成功能")
schedule_guidelines: str = Field("", description="日程生成指导原则") schedule_guidelines: str = Field(default="", description="日程生成指导原则")
# --- 月度计划 (原 MonthlyPlanSystemConfig) --- # --- 月度计划 (原 MonthlyPlanSystemConfig) ---
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统") monthly_plan_enable: bool = Field(default=True, description="是否启用月度计划系统")
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则") monthly_plan_guidelines: str = Field(default="", description="月度计划生成指导原则")
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量") max_plans_per_month: int = Field(default=10, description="每月最多生成的计划数量")
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划") avoid_repetition_days: int = Field(default=7, description="避免在多少天内重复使用同一个月度计划")
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成") completion_threshold: int = Field(default=3, description="一个月度计划被使用多少次后算作完成")
class DependencyManagementConfig(ValidatedConfigBase): class DependencyManagementConfig(ValidatedConfigBase):

View File

@@ -6,10 +6,18 @@ import orjson
from rich.traceback import install from rich.traceback import install
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import global_config, model_config from src.config.config import global_config as _global_config, model_config as _model_config
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.person_info.person_info import get_person_info_manager from src.person_info.person_info import get_person_info_manager
if _global_config is None:
raise ValueError("global_config is not initialized")
if _model_config is None:
raise ValueError("model_config is not initialized")
global_config = _global_config
model_config = _model_config
install(extra_lines=3) install(extra_lines=3)
logger = get_logger("individuality") logger = get_logger("individuality")

View File

@@ -1152,6 +1152,25 @@ class LLMRequest:
return embeddings, model_info.name # type: ignore[return-value] return embeddings, model_info.name # type: ignore[return-value]
async def execute_with_messages(
self,
message_list: list[Message],
temperature: float | None = None,
max_tokens: int | None = None,
) -> APIResponse:
"""
使用自定义消息列表执行请求(支持多模态/多图)。
"""
start_time = time.time()
response, model_info = await self._strategy.execute_with_failover(
RequestType.RESPONSE,
message_list=message_list,
temperature=temperature,
max_tokens=max_tokens,
)
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
return response
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str): async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
""" """
记录模型使用情况。 记录模型使用情况。