Files
Mofox-Core/src/config/api_ada_configs.py
Windpicker-owo dc3ad19809 feat: 采用三层内存系统实现统一内存管理器
- 添加了UnifiedMemoryManager,以整合感知层、短期记忆层和长期记忆层。
- 实现了初始化、消息添加和内存搜索功能。
- 引入了记忆从短期存储到长期存储的自动转移机制。
- 开发了用于结构化内存表示的内存格式化工具。
- 增强日志记录功能,以便在内存操作过程中更好地进行追踪。
2025-11-18 16:17:25 +08:00

235 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from threading import Lock
from typing import Any, Literal
from pydantic import Field
from src.config.config_base import ValidatedConfigBase
class APIProvider(ValidatedConfigBase):
"""API提供商配置类"""
name: str = Field(..., min_length=1, description="API提供商名称")
base_url: str = Field(..., description="API基础URL")
api_key: str | list[str] = Field(..., min_length=1, description="API密钥支持单个密钥或密钥列表轮询")
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
default="openai", description="客户端类型如openai/google等默认为openai"
)
max_retry: int = Field(default=2, ge=0, description="最大重试次数单个模型API调用失败最多重试的次数")
timeout: int = Field(
default=10, ge=1, description="API调用的超时时长超过这个时长本次请求将被视为'请求超时',单位:秒)"
)
retry_interval: int = Field(default=10, ge=0, description="重试间隔如果API调用失败重试的间隔时间单位")
@classmethod
def validate_base_url(cls, v):
"""验证base_url确保URL格式正确"""
if v and not (v.startswith("http://") or v.startswith("https://")):
raise ValueError("base_url必须以http://或https://开头")
return v
@classmethod
def validate_api_key(cls, v):
"""验证API密钥不能为空"""
if isinstance(v, str):
if not v.strip():
raise ValueError("API密钥不能为空")
elif isinstance(v, list):
if not v:
raise ValueError("API密钥列表不能为空")
for key in v:
if not isinstance(key, str) or not key.strip():
raise ValueError("API密钥列表中的密钥不能为空")
else:
raise ValueError("API密钥必须是字符串或字符串列表")
return v
def __init__(self, **data):
super().__init__(**data)
self._api_key_lock = Lock()
self._api_key_index = 0
def get_api_key(self) -> str:
with self._api_key_lock:
if isinstance(self.api_key, str):
return self.api_key
if not self.api_key:
raise ValueError("API密钥列表为空")
key = self.api_key[self._api_key_index]
self._api_key_index = (self._api_key_index + 1) % len(self.api_key)
return key
class ModelInfo(ValidatedConfigBase):
"""单个模型信息配置类"""
model_identifier: str = Field(..., min_length=1, description="模型标识符用于URL调用")
name: str = Field(..., min_length=1, description="模型名称(用于模块调用)")
api_provider: str = Field(..., min_length=1, description="API提供商如OpenAI、Azure等")
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数用于API调用时的额外配置")
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
enable_prompt_perturbation: bool = Field(default=False, description="是否启用提示词扰动(合并了内容混淆和注意力优化)")
perturbation_strength: Literal["light", "medium", "heavy"] = Field(
default="light", description="扰动强度light/medium/heavy"
)
enable_semantic_variants: bool = Field(default=False, description="是否启用语义变体作为扰动策略")
@classmethod
def validate_prices(cls, v):
"""验证价格必须为非负数"""
if v < 0:
raise ValueError("价格不能为负数")
return v
@classmethod
def validate_model_identifier(cls, v):
"""验证模型标识符不能为空且不能包含特殊字符"""
if not v or not v.strip():
raise ValueError("模型标识符不能为空")
# 检查是否包含危险字符
if any(char in v for char in [" ", "\n", "\t", "\r"]):
raise ValueError("模型标识符不能包含空格或换行符")
return v
@classmethod
def validate_name(cls, v):
"""验证模型名称不能为空"""
if not v or not v.strip():
raise ValueError("模型名称不能为空")
return v
class TaskConfig(ValidatedConfigBase):
"""任务配置类"""
model_list: list[str] = Field(..., description="任务使用的模型列表")
max_tokens: int = Field(default=800, description="任务最大输出token数")
temperature: float = Field(default=0.7, description="模型温度")
concurrency_count: int = Field(default=1, description="并发请求数量")
embedding_dimension: int | None = Field(
default=None,
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
ge=1,
)
@classmethod
def validate_model_list(cls, v):
"""验证模型列表不能为空"""
if not v:
raise ValueError("模型列表不能为空")
if len(v) != len(set(v)):
raise ValueError("模型列表中不能有重复的模型")
return v
class ModelTaskConfig(ValidatedConfigBase):
"""模型配置类"""
# 必需配置项
utils: TaskConfig = Field(..., description="组件模型配置")
utils_small: TaskConfig = Field(..., description="组件小模型配置")
replyer: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置")
maizone: TaskConfig = Field(..., description="maizone专用模型")
emotion: TaskConfig = Field(..., description="情绪模型配置")
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
voice: TaskConfig = Field(..., description="语音识别模型配置")
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
planner: TaskConfig = Field(..., description="规划模型配置")
embedding: TaskConfig = Field(..., description="嵌入模型配置")
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
lpmm_qa: TaskConfig = Field(..., description="LPMM问答模型配置")
schedule_generator: TaskConfig = Field(..., description="日程生成模型配置")
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置")
# 处理配置文件中命名不一致的问题
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
# 记忆系统专用模型配置
memory_short_term_builder: TaskConfig = Field(..., description="短期记忆构建模型配置(感知→短期格式化)")
memory_short_term_decider: TaskConfig = Field(..., description="短期记忆决策模型配置(合并/更新/新建/丢弃)")
memory_long_term_builder: TaskConfig = Field(..., description="长期记忆构建模型配置(短期→长期图结构)")
memory_judge: TaskConfig = Field(..., description="记忆检索裁判模型配置(判断检索是否充足)")
@property
def video_analysis(self) -> TaskConfig:
"""视频分析模型配置(提供向后兼容的属性访问)"""
return self.utils_video
def get_task(self, task_name: str) -> TaskConfig:
"""获取指定任务的配置"""
# 处理向后兼容性如果请求video_analysis返回utils_video
if task_name == "video_analysis":
task_name = "utils_video"
if hasattr(self, task_name):
config = getattr(self, task_name)
if config is None:
raise ValueError(f"任务 '{task_name}' 未配置")
return config
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
self.models_dict = {model.name: model for model in self.models}
@classmethod
def validate_models_list(cls, v):
"""验证模型列表"""
if not v:
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
# 检查模型名称是否重复
model_names = [model.name for model in v]
if len(model_names) != len(set(model_names)):
raise ValueError("模型名称存在重复,请检查配置文件。")
# 检查模型标识符是否有效
for model in v:
if not model.model_identifier:
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
return v
@classmethod
def validate_api_providers_list(cls, v):
"""验证API提供商列表"""
if not v:
raise ValueError("API提供商列表不能为空请在配置中设置有效的API提供商列表。")
# 检查API提供商名称是否重复
provider_names = [provider.name for provider in v]
if len(provider_names) != len(set(provider_names)):
raise ValueError("API提供商名称存在重复请检查配置文件。")
return v
def get_model_info(self, model_name: str) -> ModelInfo:
"""根据模型名称获取模型信息"""
if not model_name:
raise ValueError("模型名称不能为空")
if model_name not in self.models_dict:
raise KeyError(f"模型 '{model_name}' 不存在")
return self.models_dict[model_name]
def get_provider(self, provider_name: str) -> APIProvider:
"""根据提供商名称获取API提供商信息"""
if not provider_name:
raise ValueError("API提供商名称不能为空")
if provider_name not in self.api_providers_dict:
raise KeyError(f"API提供商 '{provider_name}' 不存在")
return self.api_providers_dict[provider_name]