依旧修pyright喵喵喵~
This commit is contained in:
@@ -2,6 +2,7 @@ import asyncio
|
|||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
# import tqdm
|
# import tqdm
|
||||||
import aiofiles
|
import aiofiles
|
||||||
@@ -121,7 +122,7 @@ class EmbeddingStore:
|
|||||||
|
|
||||||
self.store = {}
|
self.store = {}
|
||||||
|
|
||||||
self.faiss_index = None
|
self.faiss_index: Any = None
|
||||||
self.idx2hash = None
|
self.idx2hash = None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -158,6 +159,8 @@ class EmbeddingStore:
|
|||||||
from src.config.config import model_config
|
from src.config.config import model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
|
|
||||||
|
assert model_config is not None
|
||||||
|
|
||||||
# 限制 chunk_size 和 max_workers 在合理范围内
|
# 限制 chunk_size 和 max_workers 在合理范围内
|
||||||
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
chunk_size = max(MIN_CHUNK_SIZE, min(chunk_size, MAX_CHUNK_SIZE))
|
||||||
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
max_workers = max(MIN_WORKERS, min(max_workers, MAX_WORKERS))
|
||||||
@@ -402,6 +405,7 @@ class EmbeddingStore:
|
|||||||
|
|
||||||
def build_faiss_index(self) -> None:
|
def build_faiss_index(self) -> None:
|
||||||
"""重新构建Faiss索引,以余弦相似度为度量"""
|
"""重新构建Faiss索引,以余弦相似度为度量"""
|
||||||
|
assert global_config is not None
|
||||||
# 获取所有的embedding
|
# 获取所有的embedding
|
||||||
array = []
|
array = []
|
||||||
self.idx2hash = {}
|
self.idx2hash = {}
|
||||||
|
|||||||
@@ -344,14 +344,15 @@ class ImageManager:
|
|||||||
# --- 新的帧选择逻辑:均匀抽取4帧 ---
|
# --- 新的帧选择逻辑:均匀抽取4帧 ---
|
||||||
num_frames = len(all_frames)
|
num_frames = len(all_frames)
|
||||||
if num_frames <= 4:
|
if num_frames <= 4:
|
||||||
# 如果总帧数小于等于4,则全部选中
|
# 如果总宽度小于等于4,则全部选中
|
||||||
selected_frames = all_frames
|
selected_frames = all_frames
|
||||||
|
indices = list(range(num_frames))
|
||||||
else:
|
else:
|
||||||
# 使用linspace计算4个均匀分布的索引
|
# 使用linspace计算4个均匀分布的索引
|
||||||
indices = np.linspace(0, num_frames - 1, 4, dtype=int)
|
indices = np.linspace(0, num_frames - 1, 4, dtype=int)
|
||||||
selected_frames = [all_frames[i] for i in indices]
|
selected_frames = [all_frames[i] for i in indices]
|
||||||
|
|
||||||
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices if num_frames > 4 else list(range(num_frames))}")
|
logger.debug(f"GIF Frame Analysis: Total frames={num_frames}, Selected indices={indices}")
|
||||||
# --- 帧选择逻辑结束 ---
|
# --- 帧选择逻辑结束 ---
|
||||||
|
|
||||||
# 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None
|
# 如果选择后连一帧都没有(比如GIF只有一帧且后续处理失败?)或者原始GIF就没帧,也返回None
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ _locks_guard = asyncio.Lock()
|
|||||||
|
|
||||||
logger = get_logger("utils_video")
|
logger = get_logger("utils_video")
|
||||||
|
|
||||||
from inkfox import video
|
from inkfox import video # type: ignore
|
||||||
|
|
||||||
|
|
||||||
class VideoAnalyzer:
|
class VideoAnalyzer:
|
||||||
@@ -123,7 +123,6 @@ class VideoAnalyzer:
|
|||||||
# ---- 批量分析 ----
|
# ---- 批量分析 ----
|
||||||
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
from src.llm_models.utils_model import RequestType
|
|
||||||
|
|
||||||
prompt = self.batch_analysis_prompt.format(
|
prompt = self.batch_analysis_prompt.format(
|
||||||
personality_core=self.personality_core, personality_side=self.personality_side
|
personality_core=self.personality_core, personality_side=self.personality_side
|
||||||
@@ -139,12 +138,7 @@ class VideoAnalyzer:
|
|||||||
for b64, _ in frames:
|
for b64, _ in frames:
|
||||||
mb.add_image_content("jpeg", b64)
|
mb.add_image_content("jpeg", b64)
|
||||||
message = mb.build()
|
message = mb.build()
|
||||||
model_info, api_provider, client = self.video_llm._select_model()
|
resp = await self.video_llm.execute_with_messages(
|
||||||
resp = await self.video_llm._execute_request(
|
|
||||||
api_provider=api_provider,
|
|
||||||
client=client,
|
|
||||||
request_type=RequestType.RESPONSE,
|
|
||||||
model_info=model_info,
|
|
||||||
message_list=[message],
|
message_list=[message],
|
||||||
temperature=None,
|
temperature=None,
|
||||||
max_tokens=None,
|
max_tokens=None,
|
||||||
|
|||||||
@@ -31,9 +31,9 @@ def _extract_frames_worker(
|
|||||||
max_image_size: int,
|
max_image_size: int,
|
||||||
frame_extraction_mode: str,
|
frame_extraction_mode: str,
|
||||||
frame_interval_seconds: float | None,
|
frame_interval_seconds: float | None,
|
||||||
) -> list[Any] | list[tuple[str, str]]:
|
) -> list[tuple[str, float]] | list[tuple[str, str]]:
|
||||||
"""线程池中提取视频帧的工作函数"""
|
"""线程池中提取视频帧的工作函数"""
|
||||||
frames = []
|
frames: list[tuple[str, float]] = []
|
||||||
try:
|
try:
|
||||||
cap = cv2.VideoCapture(video_path)
|
cap = cv2.VideoCapture(video_path)
|
||||||
fps = cap.get(cv2.CAP_PROP_FPS)
|
fps = cap.get(cv2.CAP_PROP_FPS)
|
||||||
@@ -42,7 +42,7 @@ def _extract_frames_worker(
|
|||||||
|
|
||||||
if frame_extraction_mode == "time_interval":
|
if frame_extraction_mode == "time_interval":
|
||||||
# 新模式:按时间间隔抽帧
|
# 新模式:按时间间隔抽帧
|
||||||
time_interval = frame_interval_seconds
|
time_interval = frame_interval_seconds or 2.0
|
||||||
next_frame_time = 0.0
|
next_frame_time = 0.0
|
||||||
extracted_count = 0 # 初始化提取帧计数器
|
extracted_count = 0 # 初始化提取帧计数器
|
||||||
|
|
||||||
@@ -61,7 +61,7 @@ def _extract_frames_worker(
|
|||||||
# 调整图像大小
|
# 调整图像大小
|
||||||
if max(pil_image.size) > max_image_size:
|
if max(pil_image.size) > max_image_size:
|
||||||
ratio = max_image_size / max(pil_image.size)
|
ratio = max_image_size / max(pil_image.size)
|
||||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
|
||||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# 转换为base64
|
# 转换为base64
|
||||||
@@ -240,6 +240,7 @@ class LegacyVideoAnalyzer:
|
|||||||
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
|
estimated_frames = min(self.max_frames, total_frames // frame_interval + 1)
|
||||||
else:
|
else:
|
||||||
estimated_frames = self.max_frames
|
estimated_frames = self.max_frames
|
||||||
|
frame_interval = 1
|
||||||
|
|
||||||
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
|
logger.info(f"计算得出帧间隔: {frame_interval} (将提取约{estimated_frames}帧)")
|
||||||
|
|
||||||
@@ -276,7 +277,7 @@ class LegacyVideoAnalyzer:
|
|||||||
return await self._extract_frames_fallback(video_path)
|
return await self._extract_frames_fallback(video_path)
|
||||||
|
|
||||||
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
|
logger.info(f"✅ 成功提取{len(frames)}帧 (线程池模式)")
|
||||||
return frames
|
return frames # type: ignore
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"线程池帧提取失败: {e}")
|
logger.error(f"线程池帧提取失败: {e}")
|
||||||
@@ -315,7 +316,7 @@ class LegacyVideoAnalyzer:
|
|||||||
# 调整图像大小
|
# 调整图像大小
|
||||||
if max(pil_image.size) > self.max_image_size:
|
if max(pil_image.size) > self.max_image_size:
|
||||||
ratio = self.max_image_size / max(pil_image.size)
|
ratio = self.max_image_size / max(pil_image.size)
|
||||||
new_size = tuple(int(dim * ratio) for dim in pil_image.size)
|
new_size = (int(pil_image.size[0] * ratio), int(pil_image.size[1] * ratio))
|
||||||
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
|
||||||
|
|
||||||
# 转换为base64
|
# 转换为base64
|
||||||
@@ -463,11 +464,11 @@ class LegacyVideoAnalyzer:
|
|||||||
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
|
||||||
|
|
||||||
# 获取模型信息和客户端
|
# 获取模型信息和客户端
|
||||||
model_info, api_provider, client = self.video_llm._select_model()
|
model_info, api_provider, client = self.video_llm._select_model() # type: ignore
|
||||||
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
|
||||||
|
|
||||||
# 直接执行多图片请求
|
# 直接执行多图片请求
|
||||||
api_response = await self.video_llm._execute_request(
|
api_response = await self.video_llm._execute_request( # type: ignore
|
||||||
api_provider=api_provider,
|
api_provider=api_provider,
|
||||||
client=client,
|
client=client,
|
||||||
request_type=RequestType.RESPONSE,
|
request_type=RequestType.RESPONSE,
|
||||||
|
|||||||
@@ -38,6 +38,9 @@ class CacheManager:
|
|||||||
初始化缓存管理器。
|
初始化缓存管理器。
|
||||||
"""
|
"""
|
||||||
if not hasattr(self, "_initialized"):
|
if not hasattr(self, "_initialized"):
|
||||||
|
assert global_config is not None
|
||||||
|
assert model_config is not None
|
||||||
|
|
||||||
self.default_ttl = default_ttl or 3600
|
self.default_ttl = default_ttl or 3600
|
||||||
self.semantic_cache_collection_name = "semantic_cache"
|
self.semantic_cache_collection_name = "semantic_cache"
|
||||||
|
|
||||||
@@ -87,6 +90,7 @@ class CacheManager:
|
|||||||
embedding_array = embedding_array.flatten()
|
embedding_array = embedding_array.flatten()
|
||||||
|
|
||||||
# 检查维度是否符合预期
|
# 检查维度是否符合预期
|
||||||
|
assert global_config is not None
|
||||||
expected_dim = (
|
expected_dim = (
|
||||||
getattr(CacheManager, "embedding_dimension", None)
|
getattr(CacheManager, "embedding_dimension", None)
|
||||||
or global_config.lpmm_knowledge.embedding_dimension
|
or global_config.lpmm_knowledge.embedding_dimension
|
||||||
|
|||||||
@@ -14,23 +14,29 @@ def resolve_embedding_dimension(fallback: int | None = None, *, sync_global: boo
|
|||||||
|
|
||||||
candidates: list[int | None] = []
|
candidates: list[int | None] = []
|
||||||
|
|
||||||
try:
|
if model_config is not None:
|
||||||
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
try:
|
||||||
if embedding_task is not None:
|
embedding_task = getattr(model_config.model_task_config, "embedding", None)
|
||||||
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
if embedding_task is not None:
|
||||||
except Exception:
|
candidates.append(getattr(embedding_task, "embedding_dimension", None))
|
||||||
|
except Exception:
|
||||||
|
candidates.append(None)
|
||||||
|
else:
|
||||||
candidates.append(None)
|
candidates.append(None)
|
||||||
|
|
||||||
try:
|
if global_config is not None:
|
||||||
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
try:
|
||||||
except Exception:
|
candidates.append(getattr(global_config.lpmm_knowledge, "embedding_dimension", None))
|
||||||
|
except Exception:
|
||||||
|
candidates.append(None)
|
||||||
|
else:
|
||||||
candidates.append(None)
|
candidates.append(None)
|
||||||
|
|
||||||
candidates.append(fallback)
|
candidates.append(fallback)
|
||||||
|
|
||||||
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
resolved: int | None = next((int(dim) for dim in candidates if dim and int(dim) > 0), None)
|
||||||
|
|
||||||
if resolved and sync_global:
|
if resolved and sync_global and global_config is not None:
|
||||||
try:
|
try:
|
||||||
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
if getattr(global_config.lpmm_knowledge, "embedding_dimension", None) != resolved:
|
||||||
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
global_config.lpmm_knowledge.embedding_dimension = resolved # type: ignore[attr-defined]
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ import asyncio
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import time
|
import time
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable, Coroutine
|
||||||
from typing import Any, ParamSpec, TypeVar
|
from typing import Any, ParamSpec, TypeVar
|
||||||
|
|
||||||
from sqlalchemy.exc import DBAPIError, OperationalError
|
from sqlalchemy.exc import DBAPIError, OperationalError
|
||||||
@@ -82,7 +82,7 @@ def retry(
|
|||||||
return await session.execute(stmt)
|
return await session.execute(stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
last_exception = None
|
last_exception = None
|
||||||
@@ -130,7 +130,7 @@ def timeout(seconds: float):
|
|||||||
return await session.execute(complex_stmt)
|
return await session.execute(complex_stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
try:
|
try:
|
||||||
@@ -166,7 +166,7 @@ def cached(
|
|||||||
return await query_user(user_id)
|
return await query_user(user_id)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# 延迟导入避免循环依赖
|
# 延迟导入避免循环依赖
|
||||||
@@ -228,7 +228,7 @@ def measure_time(log_slow: float | None = None):
|
|||||||
return await session.execute(stmt)
|
return await session.execute(stmt)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -270,7 +270,7 @@ def transactional(auto_commit: bool = True, auto_rollback: bool = True):
|
|||||||
函数需要接受session参数
|
函数需要接受session参数
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
@functools.wraps(func)
|
@functools.wraps(func)
|
||||||
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
|
||||||
# 查找session参数
|
# 查找session参数
|
||||||
@@ -335,7 +335,7 @@ def db_operation(
|
|||||||
return await complex_operation()
|
return await complex_operation()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
|
def decorator(func: Callable[P, Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
|
||||||
# 从内到外应用装饰器
|
# 从内到外应用装饰器
|
||||||
wrapped = func
|
wrapped = func
|
||||||
|
|
||||||
|
|||||||
@@ -21,10 +21,15 @@ from structlog.typing import EventDict, WrappedLogger
|
|||||||
class DaemonQueueListener(QueueListener):
|
class DaemonQueueListener(QueueListener):
|
||||||
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
|
"""QueueListener 的工作线程作为守护进程运行,以避免阻塞关闭。"""
|
||||||
|
|
||||||
def _configure_listener(self):
|
def start(self):
|
||||||
super()._configure_listener()
|
"""Start the listener.
|
||||||
if hasattr(self, "_thread") and self._thread is not None: # type: ignore[attr-defined]
|
This starts up a background thread to monitor the queue for
|
||||||
self._thread.daemon = True # type: ignore[attr-defined]
|
LogRecords to process.
|
||||||
|
"""
|
||||||
|
# 覆盖 start 方法以设置 daemon=True
|
||||||
|
# 注意:_monitor 是 QueueListener 的内部方法
|
||||||
|
self._thread = threading.Thread(target=self._monitor, daemon=True) # type: ignore
|
||||||
|
self._thread.start()
|
||||||
|
|
||||||
def stop(self):
|
def stop(self):
|
||||||
"""停止监听器,避免在退出时无限期阻塞。"""
|
"""停止监听器,避免在退出时无限期阻塞。"""
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from mofox_wire import MessageServer
|
from mofox_wire import MessageServer
|
||||||
|
|
||||||
@@ -18,6 +19,8 @@ def get_global_api() -> MessageServer:
|
|||||||
if global_api is not None:
|
if global_api is not None:
|
||||||
return global_api
|
return global_api
|
||||||
|
|
||||||
|
assert global_config is not None
|
||||||
|
|
||||||
bus_config = global_config.message_bus
|
bus_config = global_config.message_bus
|
||||||
host = os.getenv("HOST", "127.0.0.1")
|
host = os.getenv("HOST", "127.0.0.1")
|
||||||
port_str = os.getenv("PORT", "8000")
|
port_str = os.getenv("PORT", "8000")
|
||||||
@@ -27,7 +30,7 @@ def get_global_api() -> MessageServer:
|
|||||||
except ValueError:
|
except ValueError:
|
||||||
port = 8000
|
port = 8000
|
||||||
|
|
||||||
kwargs: dict[str, object] = {
|
kwargs: dict[str, Any] = {
|
||||||
"host": host,
|
"host": host,
|
||||||
"port": port,
|
"port": port,
|
||||||
"app": get_global_server().get_app(),
|
"app": get_global_server().get_app(),
|
||||||
|
|||||||
@@ -52,6 +52,7 @@ async def find_messages(
|
|||||||
消息字典列表,如果出错则返回空列表。
|
消息字典列表,如果出错则返回空列表。
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
assert global_config is not None
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
query = select(Messages)
|
query = select(Messages)
|
||||||
|
|
||||||
|
|||||||
@@ -42,6 +42,7 @@ class TelemetryHeartBeatTask(AsyncTask):
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_sys_info() -> dict[str, str]:
|
def _get_sys_info() -> dict[str, str]:
|
||||||
"""获取系统信息"""
|
"""获取系统信息"""
|
||||||
|
assert global_config is not None
|
||||||
info_dict = {
|
info_dict = {
|
||||||
"os_type": "Unknown",
|
"os_type": "Unknown",
|
||||||
"py_version": platform.python_version(),
|
"py_version": platform.python_version(),
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
|
|||||||
FastAPI 依赖项,用于验证API密钥。
|
FastAPI 依赖项,用于验证API密钥。
|
||||||
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
|
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
|
||||||
"""
|
"""
|
||||||
|
assert bot_config is not None
|
||||||
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
|
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
|
||||||
if not valid_keys:
|
if not valid_keys:
|
||||||
logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。")
|
logger.warning("API密钥认证已启用,但未配置任何有效的API密钥。所有请求都将被拒绝。")
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
|
|||||||
|
|
||||||
class Server:
|
class Server:
|
||||||
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
|
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
|
||||||
|
assert bot_config is not None
|
||||||
# 根据配置初始化速率限制器
|
# 根据配置初始化速率限制器
|
||||||
limiter = Limiter(
|
limiter = Limiter(
|
||||||
key_func=get_remote_address,
|
key_func=get_remote_address,
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ def _remove_obsolete_keys(target: TOMLDocument | dict | Table, reference: TOMLDo
|
|||||||
if key not in reference:
|
if key not in reference:
|
||||||
del target[key]
|
del target[key]
|
||||||
elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table):
|
elif isinstance(target.get(key), dict | Table) and isinstance(reference.get(key), dict | Table):
|
||||||
_remove_obsolete_keys(target[key], reference[key])
|
_remove_obsolete_keys(target[key], reference[key]) # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
def _update_dict(target: TOMLDocument | dict | Table, source: TOMLDocument | dict):
|
||||||
@@ -433,9 +433,9 @@ class Config(ValidatedConfigBase):
|
|||||||
class APIAdapterConfig(ValidatedConfigBase):
|
class APIAdapterConfig(ValidatedConfigBase):
|
||||||
"""API Adapter配置类"""
|
"""API Adapter配置类"""
|
||||||
|
|
||||||
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||||
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||||
|
|
||||||
def __init__(self, **data):
|
def __init__(self, **data):
|
||||||
super().__init__(**data)
|
super().__init__(**data)
|
||||||
|
|||||||
@@ -177,7 +177,9 @@ class ValidatedConfigBase(BaseModel):
|
|||||||
element_index = field_path[1]
|
element_index = field_path[1]
|
||||||
|
|
||||||
# 尝试获取父字段的类型信息
|
# 尝试获取父字段的类型信息
|
||||||
parent_field_info = cls.model_fields.get(parent_field)
|
parent_field_info = None
|
||||||
|
if isinstance(parent_field, str):
|
||||||
|
parent_field_info = cls.model_fields.get(parent_field)
|
||||||
|
|
||||||
if parent_field_info and hasattr(parent_field_info, "annotation"):
|
if parent_field_info and hasattr(parent_field_info, "annotation"):
|
||||||
expected_type = parent_field_info.annotation
|
expected_type = parent_field_info.annotation
|
||||||
@@ -214,7 +216,9 @@ class ValidatedConfigBase(BaseModel):
|
|||||||
# 处理模型类型错误
|
# 处理模型类型错误
|
||||||
elif error_type in ["model_type", "dict_type", "is_instance_of"]:
|
elif error_type in ["model_type", "dict_type", "is_instance_of"]:
|
||||||
field_name = field_path[0] if field_path else "unknown"
|
field_name = field_path[0] if field_path else "unknown"
|
||||||
field_info = cls.model_fields.get(field_name)
|
field_info = None
|
||||||
|
if isinstance(field_name, str):
|
||||||
|
field_info = cls.model_fields.get(field_name)
|
||||||
|
|
||||||
if field_info and hasattr(field_info, "annotation"):
|
if field_info and hasattr(field_info, "annotation"):
|
||||||
expected_type = field_info.annotation
|
expected_type = field_info.annotation
|
||||||
|
|||||||
@@ -638,15 +638,15 @@ class PlanningSystemConfig(ValidatedConfigBase):
|
|||||||
"""规划系统配置 (日程与月度计划)"""
|
"""规划系统配置 (日程与月度计划)"""
|
||||||
|
|
||||||
# --- 日程生成 (原 ScheduleConfig) ---
|
# --- 日程生成 (原 ScheduleConfig) ---
|
||||||
schedule_enable: bool = Field(True, description="是否启用每日日程生成功能")
|
schedule_enable: bool = Field(default=True, description="是否启用每日日程生成功能")
|
||||||
schedule_guidelines: str = Field("", description="日程生成指导原则")
|
schedule_guidelines: str = Field(default="", description="日程生成指导原则")
|
||||||
|
|
||||||
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
|
# --- 月度计划 (原 MonthlyPlanSystemConfig) ---
|
||||||
monthly_plan_enable: bool = Field(True, description="是否启用月度计划系统")
|
monthly_plan_enable: bool = Field(default=True, description="是否启用月度计划系统")
|
||||||
monthly_plan_guidelines: str = Field("", description="月度计划生成指导原则")
|
monthly_plan_guidelines: str = Field(default="", description="月度计划生成指导原则")
|
||||||
max_plans_per_month: int = Field(10, description="每月最多生成的计划数量")
|
max_plans_per_month: int = Field(default=10, description="每月最多生成的计划数量")
|
||||||
avoid_repetition_days: int = Field(7, description="避免在多少天内重复使用同一个月度计划")
|
avoid_repetition_days: int = Field(default=7, description="避免在多少天内重复使用同一个月度计划")
|
||||||
completion_threshold: int = Field(3, description="一个月度计划被使用多少次后算作完成")
|
completion_threshold: int = Field(default=3, description="一个月度计划被使用多少次后算作完成")
|
||||||
|
|
||||||
|
|
||||||
class DependencyManagementConfig(ValidatedConfigBase):
|
class DependencyManagementConfig(ValidatedConfigBase):
|
||||||
|
|||||||
@@ -6,10 +6,18 @@ import orjson
|
|||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config as _global_config, model_config as _model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.person_info.person_info import get_person_info_manager
|
from src.person_info.person_info import get_person_info_manager
|
||||||
|
|
||||||
|
if _global_config is None:
|
||||||
|
raise ValueError("global_config is not initialized")
|
||||||
|
if _model_config is None:
|
||||||
|
raise ValueError("model_config is not initialized")
|
||||||
|
|
||||||
|
global_config = _global_config
|
||||||
|
model_config = _model_config
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
logger = get_logger("individuality")
|
logger = get_logger("individuality")
|
||||||
|
|||||||
@@ -1152,6 +1152,25 @@ class LLMRequest:
|
|||||||
|
|
||||||
return embeddings, model_info.name # type: ignore[return-value]
|
return embeddings, model_info.name # type: ignore[return-value]
|
||||||
|
|
||||||
|
async def execute_with_messages(
|
||||||
|
self,
|
||||||
|
message_list: list[Message],
|
||||||
|
temperature: float | None = None,
|
||||||
|
max_tokens: int | None = None,
|
||||||
|
) -> APIResponse:
|
||||||
|
"""
|
||||||
|
使用自定义消息列表执行请求(支持多模态/多图)。
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
response, model_info = await self._strategy.execute_with_failover(
|
||||||
|
RequestType.RESPONSE,
|
||||||
|
message_list=message_list,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
)
|
||||||
|
await self._record_usage(model_info, response.usage, time.time() - start_time, "/chat/completions")
|
||||||
|
return response
|
||||||
|
|
||||||
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
|
async def _record_usage(self, model_info: ModelInfo, usage: UsageRecord | None, time_cost: float, endpoint: str):
|
||||||
"""
|
"""
|
||||||
记录模型使用情况。
|
记录模型使用情况。
|
||||||
|
|||||||
Reference in New Issue
Block a user