- 添加了UnifiedMemoryManager,以整合感知层、短期记忆层和长期记忆层。 - 实现了初始化、消息添加和内存搜索功能。 - 引入了记忆从短期存储到长期存储的自动转移机制。 - 开发了用于结构化内存表示的内存格式化工具。 - 增强日志记录功能,以便在内存操作过程中更好地进行追踪。
235 lines
11 KiB
Python
235 lines
11 KiB
Python
from threading import Lock
|
||
from typing import Any, Literal
|
||
|
||
from pydantic import Field
|
||
|
||
from src.config.config_base import ValidatedConfigBase
|
||
|
||
|
||
class APIProvider(ValidatedConfigBase):
|
||
"""API提供商配置类"""
|
||
|
||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||
base_url: str = Field(..., description="API基础URL")
|
||
api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||
)
|
||
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
||
timeout: int = Field(
|
||
default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)"
|
||
)
|
||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||
|
||
@classmethod
|
||
def validate_base_url(cls, v):
|
||
"""验证base_url,确保URL格式正确"""
|
||
if v and not (v.startswith("http://") or v.startswith("https://")):
|
||
raise ValueError("base_url必须以http://或https://开头")
|
||
return v
|
||
|
||
@classmethod
|
||
def validate_api_key(cls, v):
|
||
"""验证API密钥不能为空"""
|
||
if isinstance(v, str):
|
||
if not v.strip():
|
||
raise ValueError("API密钥不能为空")
|
||
elif isinstance(v, list):
|
||
if not v:
|
||
raise ValueError("API密钥列表不能为空")
|
||
for key in v:
|
||
if not isinstance(key, str) or not key.strip():
|
||
raise ValueError("API密钥列表中的密钥不能为空")
|
||
else:
|
||
raise ValueError("API密钥必须是字符串或字符串列表")
|
||
return v
|
||
|
||
def __init__(self, **data):
|
||
super().__init__(**data)
|
||
self._api_key_lock = Lock()
|
||
self._api_key_index = 0
|
||
|
||
def get_api_key(self) -> str:
|
||
with self._api_key_lock:
|
||
if isinstance(self.api_key, str):
|
||
return self.api_key
|
||
if not self.api_key:
|
||
raise ValueError("API密钥列表为空")
|
||
key = self.api_key[self._api_key_index]
|
||
self._api_key_index = (self._api_key_index + 1) % len(self.api_key)
|
||
return key
|
||
|
||
|
||
class ModelInfo(ValidatedConfigBase):
|
||
"""单个模型信息配置类"""
|
||
|
||
model_identifier: str = Field(..., min_length=1, description="模型标识符(用于URL调用)")
|
||
name: str = Field(..., min_length=1, description="模型名称(用于模块调用)")
|
||
api_provider: str = Field(..., min_length=1, description="API提供商(如OpenAI、Azure等)")
|
||
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
|
||
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
|
||
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
|
||
extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)")
|
||
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
|
||
enable_prompt_perturbation: bool = Field(default=False, description="是否启用提示词扰动(合并了内容混淆和注意力优化)")
|
||
perturbation_strength: Literal["light", "medium", "heavy"] = Field(
|
||
default="light", description="扰动强度(light/medium/heavy)"
|
||
)
|
||
enable_semantic_variants: bool = Field(default=False, description="是否启用语义变体作为扰动策略")
|
||
@classmethod
|
||
def validate_prices(cls, v):
|
||
"""验证价格必须为非负数"""
|
||
if v < 0:
|
||
raise ValueError("价格不能为负数")
|
||
return v
|
||
|
||
@classmethod
|
||
def validate_model_identifier(cls, v):
|
||
"""验证模型标识符不能为空且不能包含特殊字符"""
|
||
if not v or not v.strip():
|
||
raise ValueError("模型标识符不能为空")
|
||
# 检查是否包含危险字符
|
||
if any(char in v for char in [" ", "\n", "\t", "\r"]):
|
||
raise ValueError("模型标识符不能包含空格或换行符")
|
||
return v
|
||
|
||
@classmethod
|
||
def validate_name(cls, v):
|
||
"""验证模型名称不能为空"""
|
||
if not v or not v.strip():
|
||
raise ValueError("模型名称不能为空")
|
||
return v
|
||
|
||
|
||
class TaskConfig(ValidatedConfigBase):
|
||
"""任务配置类"""
|
||
|
||
model_list: list[str] = Field(..., description="任务使用的模型列表")
|
||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||
temperature: float = Field(default=0.7, description="模型温度")
|
||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||
embedding_dimension: int | None = Field(
|
||
default=None,
|
||
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
|
||
ge=1,
|
||
)
|
||
|
||
@classmethod
|
||
def validate_model_list(cls, v):
|
||
"""验证模型列表不能为空"""
|
||
if not v:
|
||
raise ValueError("模型列表不能为空")
|
||
if len(v) != len(set(v)):
|
||
raise ValueError("模型列表中不能有重复的模型")
|
||
return v
|
||
|
||
|
||
class ModelTaskConfig(ValidatedConfigBase):
|
||
"""模型配置类"""
|
||
|
||
# 必需配置项
|
||
utils: TaskConfig = Field(..., description="组件模型配置")
|
||
utils_small: TaskConfig = Field(..., description="组件小模型配置")
|
||
replyer: TaskConfig = Field(..., description="normal_chat首要回复模型模型配置")
|
||
maizone: TaskConfig = Field(..., description="maizone专用模型")
|
||
emotion: TaskConfig = Field(..., description="情绪模型配置")
|
||
vlm: TaskConfig = Field(..., description="视觉语言模型配置")
|
||
voice: TaskConfig = Field(..., description="语音识别模型配置")
|
||
tool_use: TaskConfig = Field(..., description="专注工具使用模型配置")
|
||
planner: TaskConfig = Field(..., description="规划模型配置")
|
||
embedding: TaskConfig = Field(..., description="嵌入模型配置")
|
||
lpmm_entity_extract: TaskConfig = Field(..., description="LPMM实体提取模型配置")
|
||
lpmm_rdf_build: TaskConfig = Field(..., description="LPMM RDF构建模型配置")
|
||
lpmm_qa: TaskConfig = Field(..., description="LPMM问答模型配置")
|
||
schedule_generator: TaskConfig = Field(..., description="日程生成模型配置")
|
||
monthly_plan_generator: TaskConfig = Field(..., description="月层计划生成模型配置")
|
||
emoji_vlm: TaskConfig = Field(..., description="表情包识别模型配置")
|
||
anti_injection: TaskConfig = Field(..., description="反注入检测专用模型配置")
|
||
relationship_tracker: TaskConfig = Field(..., description="关系追踪模型配置")
|
||
# 处理配置文件中命名不一致的问题
|
||
utils_video: TaskConfig = Field(..., description="视频分析模型配置(兼容配置文件中的命名)")
|
||
|
||
# 记忆系统专用模型配置
|
||
memory_short_term_builder: TaskConfig = Field(..., description="短期记忆构建模型配置(感知→短期格式化)")
|
||
memory_short_term_decider: TaskConfig = Field(..., description="短期记忆决策模型配置(合并/更新/新建/丢弃)")
|
||
memory_long_term_builder: TaskConfig = Field(..., description="长期记忆构建模型配置(短期→长期图结构)")
|
||
memory_judge: TaskConfig = Field(..., description="记忆检索裁判模型配置(判断检索是否充足)")
|
||
|
||
@property
|
||
def video_analysis(self) -> TaskConfig:
|
||
"""视频分析模型配置(提供向后兼容的属性访问)"""
|
||
return self.utils_video
|
||
|
||
def get_task(self, task_name: str) -> TaskConfig:
|
||
"""获取指定任务的配置"""
|
||
# 处理向后兼容性:如果请求video_analysis,返回utils_video
|
||
if task_name == "video_analysis":
|
||
task_name = "utils_video"
|
||
|
||
if hasattr(self, task_name):
|
||
config = getattr(self, task_name)
|
||
if config is None:
|
||
raise ValueError(f"任务 '{task_name}' 未配置")
|
||
return config
|
||
raise ValueError(f"任务 '{task_name}' 未找到对应的配置")
|
||
|
||
|
||
class APIAdapterConfig(ValidatedConfigBase):
|
||
"""API Adapter配置类"""
|
||
|
||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||
|
||
def __init__(self, **data):
|
||
super().__init__(**data)
|
||
self.api_providers_dict = {provider.name: provider for provider in self.api_providers}
|
||
self.models_dict = {model.name: model for model in self.models}
|
||
|
||
@classmethod
|
||
def validate_models_list(cls, v):
|
||
"""验证模型列表"""
|
||
if not v:
|
||
raise ValueError("模型列表不能为空,请在配置中设置有效的模型列表。")
|
||
|
||
# 检查模型名称是否重复
|
||
model_names = [model.name for model in v]
|
||
if len(model_names) != len(set(model_names)):
|
||
raise ValueError("模型名称存在重复,请检查配置文件。")
|
||
|
||
# 检查模型标识符是否有效
|
||
for model in v:
|
||
if not model.model_identifier:
|
||
raise ValueError(f"模型 '{model.name}' 的 model_identifier 不能为空")
|
||
|
||
return v
|
||
|
||
@classmethod
|
||
def validate_api_providers_list(cls, v):
|
||
"""验证API提供商列表"""
|
||
if not v:
|
||
raise ValueError("API提供商列表不能为空,请在配置中设置有效的API提供商列表。")
|
||
|
||
# 检查API提供商名称是否重复
|
||
provider_names = [provider.name for provider in v]
|
||
if len(provider_names) != len(set(provider_names)):
|
||
raise ValueError("API提供商名称存在重复,请检查配置文件。")
|
||
|
||
return v
|
||
|
||
def get_model_info(self, model_name: str) -> ModelInfo:
|
||
"""根据模型名称获取模型信息"""
|
||
if not model_name:
|
||
raise ValueError("模型名称不能为空")
|
||
if model_name not in self.models_dict:
|
||
raise KeyError(f"模型 '{model_name}' 不存在")
|
||
return self.models_dict[model_name]
|
||
|
||
def get_provider(self, provider_name: str) -> APIProvider:
|
||
"""根据提供商名称获取API提供商信息"""
|
||
if not provider_name:
|
||
raise ValueError("API提供商名称不能为空")
|
||
if provider_name not in self.api_providers_dict:
|
||
raise KeyError(f"API提供商 '{provider_name}' 不存在")
|
||
return self.api_providers_dict[provider_name]
|