re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -1,6 +1,7 @@
|
||||
from typing import List, Dict, Any, Literal, Union, Optional
|
||||
from pydantic import Field
|
||||
from threading import Lock
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
|
||||
@@ -10,7 +11,7 @@ class APIProvider(ValidatedConfigBase):
|
||||
|
||||
name: str = Field(..., min_length=1, description="API提供商名称")
|
||||
base_url: str = Field(..., description="API基础URL")
|
||||
api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||||
api_key: str | list[str] = Field(..., min_length=1, description="API密钥,支持单个密钥或密钥列表轮询")
|
||||
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
|
||||
default="openai", description="客户端类型(如openai/google等,默认为openai)"
|
||||
)
|
||||
@@ -70,7 +71,7 @@ class ModelInfo(ValidatedConfigBase):
|
||||
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
|
||||
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
|
||||
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
|
||||
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)")
|
||||
extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数(用于API调用时的额外配置)")
|
||||
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
|
||||
|
||||
@classmethod
|
||||
@@ -101,11 +102,11 @@ class ModelInfo(ValidatedConfigBase):
|
||||
class TaskConfig(ValidatedConfigBase):
|
||||
"""任务配置类"""
|
||||
|
||||
model_list: List[str] = Field(..., description="任务使用的模型列表")
|
||||
model_list: list[str] = Field(..., description="任务使用的模型列表")
|
||||
max_tokens: int = Field(default=800, description="任务最大输出token数")
|
||||
temperature: float = Field(default=0.7, description="模型温度")
|
||||
concurrency_count: int = Field(default=1, description="并发请求数量")
|
||||
embedding_dimension: Optional[int] = Field(
|
||||
embedding_dimension: int | None = Field(
|
||||
default=None,
|
||||
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
|
||||
ge=1,
|
||||
@@ -168,9 +169,9 @@ class ModelTaskConfig(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: List[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
@@ -1,59 +1,58 @@
|
||||
import os
|
||||
import tomlkit
|
||||
import shutil
|
||||
import sys
|
||||
|
||||
from datetime import datetime
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import Table, KeyType
|
||||
from rich.traceback import install
|
||||
from typing import List, Optional
|
||||
|
||||
import tomlkit
|
||||
from pydantic import Field
|
||||
from rich.traceback import install
|
||||
from tomlkit import TOMLDocument
|
||||
from tomlkit.items import KeyType, Table
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
from src.config.official_configs import (
|
||||
DatabaseConfig,
|
||||
AffinityFlowConfig,
|
||||
AntiPromptInjectionConfig,
|
||||
BotConfig,
|
||||
PersonalityConfig,
|
||||
ExpressionConfig,
|
||||
ChatConfig,
|
||||
EmojiConfig,
|
||||
MemoryConfig,
|
||||
MoodConfig,
|
||||
KeywordReactionConfig,
|
||||
ChineseTypoConfig,
|
||||
CommandConfig,
|
||||
CrossContextConfig,
|
||||
CustomPromptConfig,
|
||||
DatabaseConfig,
|
||||
DebugConfig,
|
||||
DependencyManagementConfig,
|
||||
EmojiConfig,
|
||||
ExperimentalConfig,
|
||||
ExpressionConfig,
|
||||
KeywordReactionConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
MaimMessageConfig,
|
||||
MemoryConfig,
|
||||
MessageReceiveConfig,
|
||||
MoodConfig,
|
||||
NormalChatConfig,
|
||||
PermissionConfig,
|
||||
PersonalityConfig,
|
||||
PlanningSystemConfig,
|
||||
ProactiveThinkingConfig,
|
||||
RelationshipConfig,
|
||||
ResponsePostProcessConfig,
|
||||
ResponseSplitterConfig,
|
||||
ExperimentalConfig,
|
||||
MessageReceiveConfig,
|
||||
MaimMessageConfig,
|
||||
LPMMKnowledgeConfig,
|
||||
RelationshipConfig,
|
||||
ToolConfig,
|
||||
VoiceConfig,
|
||||
DebugConfig,
|
||||
CustomPromptConfig,
|
||||
VideoAnalysisConfig,
|
||||
DependencyManagementConfig,
|
||||
WebSearchConfig,
|
||||
AntiPromptInjectionConfig,
|
||||
SleepSystemConfig,
|
||||
CrossContextConfig,
|
||||
PermissionConfig,
|
||||
CommandConfig,
|
||||
PlanningSystemConfig,
|
||||
AffinityFlowConfig,
|
||||
ProactiveThinkingConfig,
|
||||
ToolConfig,
|
||||
VideoAnalysisConfig,
|
||||
VoiceConfig,
|
||||
WebSearchConfig,
|
||||
)
|
||||
|
||||
from .api_ada_configs import (
|
||||
ModelTaskConfig,
|
||||
ModelInfo,
|
||||
APIProvider,
|
||||
ModelInfo,
|
||||
ModelTaskConfig,
|
||||
)
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
|
||||
@@ -154,11 +153,11 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
|
||||
return logs, changes
|
||||
|
||||
|
||||
def _get_version_from_toml(toml_path) -> Optional[str]:
|
||||
def _get_version_from_toml(toml_path) -> str | None:
|
||||
"""从TOML文件中获取版本号"""
|
||||
if not os.path.exists(toml_path):
|
||||
return None
|
||||
with open(toml_path, "r", encoding="utf-8") as f:
|
||||
with open(toml_path, encoding="utf-8") as f:
|
||||
doc = tomlkit.load(f)
|
||||
if "inner" in doc and "version" in doc["inner"]: # type: ignore
|
||||
return doc["inner"]["version"] # type: ignore
|
||||
@@ -270,17 +269,17 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
|
||||
# 先读取 compare 下的模板(如果有),用于默认值变动检测
|
||||
if os.path.exists(compare_path):
|
||||
with open(compare_path, "r", encoding="utf-8") as f:
|
||||
with open(compare_path, encoding="utf-8") as f:
|
||||
compare_config = tomlkit.load(f)
|
||||
|
||||
# 读取当前模板
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
new_config = tomlkit.load(f)
|
||||
|
||||
# 检查默认值变化并处理(只有 compare_config 存在时才做)
|
||||
if compare_config:
|
||||
# 读取旧配置
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
with open(old_config_path, encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
logs, changes = compare_default_values(new_config, compare_config)
|
||||
if logs:
|
||||
@@ -318,7 +317,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
|
||||
# 读取旧配置文件和模板文件(如果前面没读过 old_config,这里再读一次)
|
||||
if old_config is None:
|
||||
with open(old_config_path, "r", encoding="utf-8") as f:
|
||||
with open(old_config_path, encoding="utf-8") as f:
|
||||
old_config = tomlkit.load(f)
|
||||
# new_config 已经读取
|
||||
|
||||
@@ -364,7 +363,7 @@ def _update_config_generic(config_name: str, template_name: str):
|
||||
|
||||
# 移除在新模板中已不存在的旧配置项
|
||||
logger.info(f"开始移除{config_name}中已废弃的配置项...")
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
with open(template_path, encoding="utf-8") as f:
|
||||
template_doc = tomlkit.load(f)
|
||||
_remove_obsolete_keys(new_config, template_doc)
|
||||
logger.info(f"已移除{config_name}中已废弃的配置项")
|
||||
@@ -442,9 +441,9 @@ class Config(ValidatedConfigBase):
|
||||
class APIAdapterConfig(ValidatedConfigBase):
|
||||
"""API Adapter配置类"""
|
||||
|
||||
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
||||
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表")
|
||||
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
|
||||
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
||||
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
@@ -508,7 +507,7 @@ def load_config(config_path: str) -> Config:
|
||||
Config对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
# 创建Config对象(各个配置类会自动进行 Pydantic 验证)
|
||||
@@ -531,7 +530,7 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
|
||||
APIAdapterConfig对象
|
||||
"""
|
||||
# 读取配置文件
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
with open(config_path, encoding="utf-8") as f:
|
||||
config_data = tomlkit.load(f)
|
||||
|
||||
config_dict = dict(config_data)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from dataclasses import dataclass, fields, MISSING
|
||||
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
|
||||
from dataclasses import MISSING, dataclass, fields
|
||||
from typing import Any, Literal, TypeVar, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Self
|
||||
|
||||
T = TypeVar("T", bound="ConfigBase")
|
||||
|
||||
@@ -19,7 +21,7 @@ class ConfigBase:
|
||||
"""配置类的基类"""
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
|
||||
def from_dict(cls, data: dict[str, Any]) -> Self:
|
||||
"""从字典加载配置字段"""
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
|
||||
@@ -53,7 +55,7 @@ class ConfigBase:
|
||||
return cls()
|
||||
|
||||
@classmethod
|
||||
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
|
||||
def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
|
||||
"""
|
||||
转换字段值为指定类型
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Literal, Optional, List
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from src.config.config_base import ValidatedConfigBase
|
||||
@@ -42,7 +43,7 @@ class BotConfig(ValidatedConfigBase):
|
||||
platform: str = Field(..., description="平台")
|
||||
qq_account: int = Field(..., description="QQ账号")
|
||||
nickname: str = Field(..., description="昵称")
|
||||
alias_names: List[str] = Field(default_factory=list, description="别名列表")
|
||||
alias_names: list[str] = Field(default_factory=list, description="别名列表")
|
||||
|
||||
|
||||
class PersonalityConfig(ValidatedConfigBase):
|
||||
@@ -54,7 +55,7 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
background_story: str = Field(
|
||||
default="", description="世界观背景故事,这部分内容会作为背景知识,LLM被指导不应主动复述"
|
||||
)
|
||||
safety_guidelines: List[str] = Field(
|
||||
safety_guidelines: list[str] = Field(
|
||||
default_factory=list, description="安全与互动底线,Bot在任何情况下都必须遵守的原则"
|
||||
)
|
||||
reply_style: str = Field(default="", description="表达风格")
|
||||
@@ -63,7 +64,7 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
compress_identity: bool = Field(default=True, description="是否压缩身份")
|
||||
|
||||
# 回复规则配置
|
||||
reply_targeting_rules: List[str] = Field(
|
||||
reply_targeting_rules: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。",
|
||||
"在拒绝时,请使用符合你人设的、坚定的语气。",
|
||||
@@ -72,7 +73,7 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
description="安全与互动底线规则,Bot在任何情况下都必须遵守的原则",
|
||||
)
|
||||
|
||||
message_targeting_analysis: List[str] = Field(
|
||||
message_targeting_analysis: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"**直接针对你**:@你、回复你、明确询问你 → 必须回应",
|
||||
"**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与",
|
||||
@@ -82,7 +83,7 @@ class PersonalityConfig(ValidatedConfigBase):
|
||||
description="消息针对性分析规则,用于判断是否需要回复",
|
||||
)
|
||||
|
||||
reply_principles: List[str] = Field(
|
||||
reply_principles: list[str] = Field(
|
||||
default_factory=lambda: [
|
||||
"明确回应目标消息,而不是宽泛地评论。",
|
||||
"可以分享你的看法、提出相关问题,或者开个合适的玩笑。",
|
||||
@@ -111,7 +112,7 @@ class ChatConfig(ValidatedConfigBase):
|
||||
at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复")
|
||||
allow_reply_self: bool = Field(default=False, description="是否允许回复自己说的话")
|
||||
focus_value: float = Field(default=1.0, description="专注值")
|
||||
focus_mode_quiet_groups: List[str] = Field(
|
||||
focus_mode_quiet_groups: list[str] = Field(
|
||||
default_factory=list,
|
||||
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
|
||||
)
|
||||
@@ -140,8 +141,8 @@ class ChatConfig(ValidatedConfigBase):
|
||||
class MessageReceiveConfig(ValidatedConfigBase):
|
||||
"""消息接收配置类"""
|
||||
|
||||
ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表")
|
||||
ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表")
|
||||
ban_words: list[str] = Field(default_factory=lambda: list(), description="禁用词列表")
|
||||
ban_msgs_regex: list[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表")
|
||||
|
||||
|
||||
class NormalChatConfig(ValidatedConfigBase):
|
||||
@@ -155,16 +156,16 @@ class ExpressionRule(ValidatedConfigBase):
|
||||
use_expression: bool = Field(default=True, description="是否使用学到的表达")
|
||||
learn_expression: bool = Field(default=True, description="是否学习表达")
|
||||
learning_strength: float = Field(default=1.0, description="学习强度")
|
||||
group: Optional[str] = Field(default=None, description="表达共享组")
|
||||
group: str | None = Field(default=None, description="表达共享组")
|
||||
|
||||
|
||||
class ExpressionConfig(ValidatedConfigBase):
|
||||
"""表达配置类"""
|
||||
|
||||
rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
|
||||
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
|
||||
|
||||
@staticmethod
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
|
||||
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
|
||||
"""
|
||||
解析流配置字符串并生成对应的 chat_id
|
||||
|
||||
@@ -199,7 +200,7 @@ class ExpressionConfig(ValidatedConfigBase):
|
||||
except (ValueError, IndexError):
|
||||
return None
|
||||
|
||||
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]:
|
||||
def get_expression_config_for_chat(self, chat_stream_id: str | None = None) -> tuple[bool, bool, float]:
|
||||
"""
|
||||
根据聊天流ID获取表达配置
|
||||
|
||||
@@ -362,7 +363,7 @@ class KeywordRuleConfig(ValidatedConfigBase):
|
||||
try:
|
||||
re.compile(pattern)
|
||||
except re.error as e:
|
||||
raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e
|
||||
raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e
|
||||
|
||||
|
||||
class KeywordReactionConfig(ValidatedConfigBase):
|
||||
@@ -561,10 +562,10 @@ class SleepSystemConfig(ValidatedConfigBase):
|
||||
|
||||
# --- 失眠机制相关参数 ---
|
||||
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
|
||||
insomnia_trigger_delay_minutes: List[int] = Field(
|
||||
insomnia_trigger_delay_minutes: list[int] = Field(
|
||||
default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
|
||||
)
|
||||
insomnia_duration_minutes: List[int] = Field(
|
||||
insomnia_duration_minutes: list[int] = Field(
|
||||
default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
|
||||
)
|
||||
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
|
||||
@@ -590,7 +591,7 @@ class ContextGroup(ValidatedConfigBase):
|
||||
"""上下文共享组配置"""
|
||||
|
||||
name: str = Field(..., description="共享组的名称")
|
||||
chat_ids: List[List[str]] = Field(
|
||||
chat_ids: list[list[str]] = Field(
|
||||
...,
|
||||
description='属于该组的聊天ID列表,格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
|
||||
)
|
||||
@@ -600,20 +601,20 @@ class CrossContextConfig(ValidatedConfigBase):
|
||||
"""跨群聊上下文共享配置"""
|
||||
|
||||
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
|
||||
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||
groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
|
||||
|
||||
|
||||
class CommandConfig(ValidatedConfigBase):
|
||||
"""命令系统配置类"""
|
||||
|
||||
command_prefixes: List[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表")
|
||||
command_prefixes: list[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表")
|
||||
|
||||
|
||||
class PermissionConfig(ValidatedConfigBase):
|
||||
"""权限系统配置类"""
|
||||
|
||||
# Master用户配置(拥有最高权限,无视所有权限节点)
|
||||
master_users: List[List[str]] = Field(
|
||||
master_users: list[list[str]] = Field(
|
||||
default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]"
|
||||
)
|
||||
|
||||
@@ -668,10 +669,10 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
|
||||
# --- 作用范围 ---
|
||||
enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话")
|
||||
enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话")
|
||||
enabled_private_chats: List[str] = Field(
|
||||
enabled_private_chats: list[str] = Field(
|
||||
default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]'
|
||||
)
|
||||
enabled_group_chats: List[str] = Field(
|
||||
enabled_group_chats: list[str] = Field(
|
||||
default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]'
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user