From 921d07e30abed227b83b9bab0a729431794e8c72 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:27:47 +0800 Subject: [PATCH 1/2] Enforce strict type validation and update config types Enabled strict type checking in ValidatedConfigBase to fully disable type coercion. Updated MessageReceiveConfig and MemoryConfig fields from set/tuple to list types for compatibility with strict validation. --- src/config/config_base.py | 1 + src/config/official_configs.py | 6 +++--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/config/config_base.py b/src/config/config_base.py index 62c585c22..5e27c9de0 100644 --- a/src/config/config_base.py +++ b/src/config/config_base.py @@ -142,6 +142,7 @@ class ValidatedConfigBase(BaseModel): "extra": "allow", # 允许额外字段 "validate_assignment": True, # 验证赋值 "arbitrary_types_allowed": True, # 允许任意类型 + "strict": True, # 如果设为 True 会完全禁用类型转换 } @classmethod diff --git a/src/config/official_configs.py b/src/config/official_configs.py index e6c4869f9..704e58690 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -246,8 +246,8 @@ class ChatConfig(ValidatedConfigBase): class MessageReceiveConfig(ValidatedConfigBase): """消息接收配置类""" - ban_words: set[str] = Field(default_factory=lambda: set(), description="禁用词列表") - ban_msgs_regex: set[str] = Field(default_factory=lambda: set(), description="禁用消息正则列表") + ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表") + ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表") @@ -426,7 +426,7 @@ class MemoryConfig(ValidatedConfigBase): enable_memory: bool = Field(default=True, description="启用记忆") memory_build_interval: int = Field(default=600, description="记忆构建间隔") - memory_build_distribution: tuple = Field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4), description="记忆构建分布") + memory_build_distribution: list[float] = Field(default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布") memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量") memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度") memory_compress_rate: float = Field(default=0.1, description="记忆压缩率") From f959ca6bb2fdd5e91e1a7611ec9569fac74c2f3a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Wed, 20 Aug 2025 19:38:37 +0800 Subject: [PATCH 2/2] Use Literal types for config field validation Replaced manual string validation with Python's Literal type for 'client_type' in APIProvider and 'search_strategy' in WebSearchConfig. This simplifies validation and improves type safety by restricting allowed values at the type level. --- src/config/api_ada_configs.py | 13 ++----------- src/config/official_configs.py | 2 +- 2 files changed, 3 insertions(+), 12 deletions(-) diff --git a/src/config/api_ada_configs.py b/src/config/api_ada_configs.py index 3c7827b81..c39c0ea13 100644 --- a/src/config/api_ada_configs.py +++ b/src/config/api_ada_configs.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any +from typing import List, Dict, Any, Literal from pydantic import Field, field_validator from src.config.config_base import ValidatedConfigBase @@ -10,7 +10,7 @@ class APIProvider(ValidatedConfigBase): name: str = Field(..., min_length=1, description="API提供商名称") base_url: str = Field(..., description="API基础URL") api_key: str = Field(..., min_length=1, description="API密钥") - client_type: str = Field(default="openai", description="客户端类型(如openai/google等,默认为openai)") + client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(default="openai", description="客户端类型(如openai/google等,默认为openai)") max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)") timeout: int = Field(default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)") retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)") @@ -33,15 +33,6 @@ class APIProvider(ValidatedConfigBase): raise ValueError("API密钥不能为空") return v - @field_validator('client_type') - @classmethod - def validate_client_type(cls, v): - """验证客户端类型""" - allowed_types = ["openai", "gemini","aiohttp_gemini"] - if v not in allowed_types: - raise ValueError(f"客户端类型必须是以下之一: {allowed_types}") - return v - def get_api_key(self) -> str: return self.api_key diff --git a/src/config/official_configs.py b/src/config/official_configs.py index 704e58690..9b31e6a57 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -618,7 +618,7 @@ class WebSearchConfig(ValidatedConfigBase): enable_web_search_tool: bool = Field(default=True, description="启用网络搜索工具") enable_url_tool: bool = Field(default=True, description="启用URL工具") enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎") - search_strategy: str = Field(default="single", description="搜索策略") + search_strategy: Literal["fallback","single","parallel"] = Field(default="single", description="搜索策略") class AntiPromptInjectionConfig(ValidatedConfigBase):