Merge branch 'master' of https://github.com/MaiBot-Plus/MaiMbot-Pro-Max
This commit is contained in:
@@ -1,4 +1,4 @@
|
|||||||
from typing import List, Dict, Any
|
from typing import List, Dict, Any, Literal
|
||||||
from pydantic import Field, field_validator
|
from pydantic import Field, field_validator
|
||||||
|
|
||||||
from src.config.config_base import ValidatedConfigBase
|
from src.config.config_base import ValidatedConfigBase
|
||||||
@@ -10,7 +10,7 @@ class APIProvider(ValidatedConfigBase):
|
|||||||
name: str = Field(..., min_length=1, description="API提供商名称")
|
name: str = Field(..., min_length=1, description="API提供商名称")
|
||||||
base_url: str = Field(..., description="API基础URL")
|
base_url: str = Field(..., description="API基础URL")
|
||||||
api_key: str = Field(..., min_length=1, description="API密钥")
|
api_key: str = Field(..., min_length=1, description="API密钥")
|
||||||
client_type: str = Field(default="openai", description="客户端类型(如openai/google等,默认为openai)")
|
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(default="openai", description="客户端类型(如openai/google等,默认为openai)")
|
||||||
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
max_retry: int = Field(default=2, ge=0, description="最大重试次数(单个模型API调用失败,最多重试的次数)")
|
||||||
timeout: int = Field(default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)")
|
timeout: int = Field(default=10, ge=1, description="API调用的超时时长(超过这个时长,本次请求将被视为'请求超时',单位:秒)")
|
||||||
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
retry_interval: int = Field(default=10, ge=0, description="重试间隔(如果API调用失败,重试的间隔时间,单位:秒)")
|
||||||
@@ -33,15 +33,6 @@ class APIProvider(ValidatedConfigBase):
|
|||||||
raise ValueError("API密钥不能为空")
|
raise ValueError("API密钥不能为空")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
@field_validator('client_type')
|
|
||||||
@classmethod
|
|
||||||
def validate_client_type(cls, v):
|
|
||||||
"""验证客户端类型"""
|
|
||||||
allowed_types = ["openai", "gemini","aiohttp_gemini"]
|
|
||||||
if v not in allowed_types:
|
|
||||||
raise ValueError(f"客户端类型必须是以下之一: {allowed_types}")
|
|
||||||
return v
|
|
||||||
|
|
||||||
def get_api_key(self) -> str:
|
def get_api_key(self) -> str:
|
||||||
return self.api_key
|
return self.api_key
|
||||||
|
|
||||||
|
|||||||
@@ -142,6 +142,7 @@ class ValidatedConfigBase(BaseModel):
|
|||||||
"extra": "allow", # 允许额外字段
|
"extra": "allow", # 允许额外字段
|
||||||
"validate_assignment": True, # 验证赋值
|
"validate_assignment": True, # 验证赋值
|
||||||
"arbitrary_types_allowed": True, # 允许任意类型
|
"arbitrary_types_allowed": True, # 允许任意类型
|
||||||
|
"strict": True, # 如果设为 True 会完全禁用类型转换
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -246,8 +246,8 @@ class ChatConfig(ValidatedConfigBase):
|
|||||||
class MessageReceiveConfig(ValidatedConfigBase):
|
class MessageReceiveConfig(ValidatedConfigBase):
|
||||||
"""消息接收配置类"""
|
"""消息接收配置类"""
|
||||||
|
|
||||||
ban_words: set[str] = Field(default_factory=lambda: set(), description="禁用词列表")
|
ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表")
|
||||||
ban_msgs_regex: set[str] = Field(default_factory=lambda: set(), description="禁用消息正则列表")
|
ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -426,7 +426,7 @@ class MemoryConfig(ValidatedConfigBase):
|
|||||||
|
|
||||||
enable_memory: bool = Field(default=True, description="启用记忆")
|
enable_memory: bool = Field(default=True, description="启用记忆")
|
||||||
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
|
memory_build_interval: int = Field(default=600, description="记忆构建间隔")
|
||||||
memory_build_distribution: tuple = Field(default_factory=lambda: (6.0, 3.0, 0.6, 32.0, 12.0, 0.4), description="记忆构建分布")
|
memory_build_distribution: list[float] = Field(default_factory=lambda: [6.0, 3.0, 0.6, 32.0, 12.0, 0.4], description="记忆构建分布")
|
||||||
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
|
memory_build_sample_num: int = Field(default=8, description="记忆构建样本数量")
|
||||||
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
|
memory_build_sample_length: int = Field(default=40, description="记忆构建样本长度")
|
||||||
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
|
memory_compress_rate: float = Field(default=0.1, description="记忆压缩率")
|
||||||
@@ -618,7 +618,7 @@ class WebSearchConfig(ValidatedConfigBase):
|
|||||||
enable_web_search_tool: bool = Field(default=True, description="启用网络搜索工具")
|
enable_web_search_tool: bool = Field(default=True, description="启用网络搜索工具")
|
||||||
enable_url_tool: bool = Field(default=True, description="启用URL工具")
|
enable_url_tool: bool = Field(default=True, description="启用URL工具")
|
||||||
enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎")
|
enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎")
|
||||||
search_strategy: str = Field(default="single", description="搜索策略")
|
search_strategy: Literal["fallback","single","parallel"] = Field(default="single", description="搜索策略")
|
||||||
|
|
||||||
|
|
||||||
class AntiPromptInjectionConfig(ValidatedConfigBase):
|
class AntiPromptInjectionConfig(ValidatedConfigBase):
|
||||||
|
|||||||
Reference in New Issue
Block a user