re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -1,6 +1,7 @@
from typing import List, Dict, Any, Literal, Union, Optional
from pydantic import Field
from threading import Lock
from typing import Any, Literal
from pydantic import Field
from src.config.config_base import ValidatedConfigBase
@@ -10,7 +11,7 @@ class APIProvider(ValidatedConfigBase):
name: str = Field(..., min_length=1, description="API提供商名称")
base_url: str = Field(..., description="API基础URL")
api_key: Union[str, List[str]] = Field(..., min_length=1, description="API密钥支持单个密钥或密钥列表轮询")
api_key: str | list[str] = Field(..., min_length=1, description="API密钥支持单个密钥或密钥列表轮询")
client_type: Literal["openai", "gemini", "aiohttp_gemini"] = Field(
default="openai", description="客户端类型如openai/google等默认为openai"
)
@@ -70,7 +71,7 @@ class ModelInfo(ValidatedConfigBase):
price_in: float = Field(default=0.0, ge=0, description="每M token输入价格")
price_out: float = Field(default=0.0, ge=0, description="每M token输出价格")
force_stream_mode: bool = Field(default=False, description="是否强制使用流式输出模式")
extra_params: Dict[str, Any] = Field(default_factory=dict, description="额外参数用于API调用时的额外配置")
extra_params: dict[str, Any] = Field(default_factory=dict, description="额外参数用于API调用时的额外配置")
anti_truncation: bool = Field(default=False, description="是否启用反截断功能,防止模型输出被截断")
@classmethod
@@ -101,11 +102,11 @@ class ModelInfo(ValidatedConfigBase):
class TaskConfig(ValidatedConfigBase):
"""任务配置类"""
model_list: List[str] = Field(..., description="任务使用的模型列表")
model_list: list[str] = Field(..., description="任务使用的模型列表")
max_tokens: int = Field(default=800, description="任务最大输出token数")
temperature: float = Field(default=0.7, description="模型温度")
concurrency_count: int = Field(default=1, description="并发请求数量")
embedding_dimension: Optional[int] = Field(
embedding_dimension: int | None = Field(
default=None,
description="嵌入模型输出向量维度,仅在嵌入任务中使用",
ge=1,
@@ -168,9 +169,9 @@ class ModelTaskConfig(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo] = Field(..., min_length=1, description="模型列表")
models: list[ModelInfo] = Field(..., min_length=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: List[APIProvider] = Field(..., min_length=1, description="API提供商列表")
api_providers: list[APIProvider] = Field(..., min_length=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)

View File

@@ -1,59 +1,58 @@
import os
import tomlkit
import shutil
import sys
from datetime import datetime
from tomlkit import TOMLDocument
from tomlkit.items import Table, KeyType
from rich.traceback import install
from typing import List, Optional
import tomlkit
from pydantic import Field
from rich.traceback import install
from tomlkit import TOMLDocument
from tomlkit.items import KeyType, Table
from src.common.logger import get_logger
from src.config.config_base import ValidatedConfigBase
from src.config.official_configs import (
DatabaseConfig,
AffinityFlowConfig,
AntiPromptInjectionConfig,
BotConfig,
PersonalityConfig,
ExpressionConfig,
ChatConfig,
EmojiConfig,
MemoryConfig,
MoodConfig,
KeywordReactionConfig,
ChineseTypoConfig,
CommandConfig,
CrossContextConfig,
CustomPromptConfig,
DatabaseConfig,
DebugConfig,
DependencyManagementConfig,
EmojiConfig,
ExperimentalConfig,
ExpressionConfig,
KeywordReactionConfig,
LPMMKnowledgeConfig,
MaimMessageConfig,
MemoryConfig,
MessageReceiveConfig,
MoodConfig,
NormalChatConfig,
PermissionConfig,
PersonalityConfig,
PlanningSystemConfig,
ProactiveThinkingConfig,
RelationshipConfig,
ResponsePostProcessConfig,
ResponseSplitterConfig,
ExperimentalConfig,
MessageReceiveConfig,
MaimMessageConfig,
LPMMKnowledgeConfig,
RelationshipConfig,
ToolConfig,
VoiceConfig,
DebugConfig,
CustomPromptConfig,
VideoAnalysisConfig,
DependencyManagementConfig,
WebSearchConfig,
AntiPromptInjectionConfig,
SleepSystemConfig,
CrossContextConfig,
PermissionConfig,
CommandConfig,
PlanningSystemConfig,
AffinityFlowConfig,
ProactiveThinkingConfig,
ToolConfig,
VideoAnalysisConfig,
VoiceConfig,
WebSearchConfig,
)
from .api_ada_configs import (
ModelTaskConfig,
ModelInfo,
APIProvider,
ModelInfo,
ModelTaskConfig,
)
install(extra_lines=3)
@@ -154,11 +153,11 @@ def compare_default_values(new, old, path=None, logs=None, changes=None):
return logs, changes
def _get_version_from_toml(toml_path) -> Optional[str]:
def _get_version_from_toml(toml_path) -> str | None:
"""从TOML文件中获取版本号"""
if not os.path.exists(toml_path):
return None
with open(toml_path, "r", encoding="utf-8") as f:
with open(toml_path, encoding="utf-8") as f:
doc = tomlkit.load(f)
if "inner" in doc and "version" in doc["inner"]: # type: ignore
return doc["inner"]["version"] # type: ignore
@@ -270,17 +269,17 @@ def _update_config_generic(config_name: str, template_name: str):
# 先读取 compare 下的模板(如果有),用于默认值变动检测
if os.path.exists(compare_path):
with open(compare_path, "r", encoding="utf-8") as f:
with open(compare_path, encoding="utf-8") as f:
compare_config = tomlkit.load(f)
# 读取当前模板
with open(template_path, "r", encoding="utf-8") as f:
with open(template_path, encoding="utf-8") as f:
new_config = tomlkit.load(f)
# 检查默认值变化并处理(只有 compare_config 存在时才做)
if compare_config:
# 读取旧配置
with open(old_config_path, "r", encoding="utf-8") as f:
with open(old_config_path, encoding="utf-8") as f:
old_config = tomlkit.load(f)
logs, changes = compare_default_values(new_config, compare_config)
if logs:
@@ -318,7 +317,7 @@ def _update_config_generic(config_name: str, template_name: str):
# 读取旧配置文件和模板文件(如果前面没读过 old_config这里再读一次
if old_config is None:
with open(old_config_path, "r", encoding="utf-8") as f:
with open(old_config_path, encoding="utf-8") as f:
old_config = tomlkit.load(f)
# new_config 已经读取
@@ -364,7 +363,7 @@ def _update_config_generic(config_name: str, template_name: str):
# 移除在新模板中已不存在的旧配置项
logger.info(f"开始移除{config_name}中已废弃的配置项...")
with open(template_path, "r", encoding="utf-8") as f:
with open(template_path, encoding="utf-8") as f:
template_doc = tomlkit.load(f)
_remove_obsolete_keys(new_config, template_doc)
logger.info(f"已移除{config_name}中已废弃的配置项")
@@ -442,9 +441,9 @@ class Config(ValidatedConfigBase):
class APIAdapterConfig(ValidatedConfigBase):
"""API Adapter配置类"""
models: List[ModelInfo] = Field(..., min_items=1, description="模型列表")
models: list[ModelInfo] = Field(..., min_items=1, description="模型列表")
model_task_config: ModelTaskConfig = Field(..., description="模型任务配置")
api_providers: List[APIProvider] = Field(..., min_items=1, description="API提供商列表")
api_providers: list[APIProvider] = Field(..., min_items=1, description="API提供商列表")
def __init__(self, **data):
super().__init__(**data)
@@ -508,7 +507,7 @@ def load_config(config_path: str) -> Config:
Config对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
# 创建Config对象各个配置类会自动进行 Pydantic 验证)
@@ -531,7 +530,7 @@ def api_ada_load_config(config_path: str) -> APIAdapterConfig:
APIAdapterConfig对象
"""
# 读取配置文件
with open(config_path, "r", encoding="utf-8") as f:
with open(config_path, encoding="utf-8") as f:
config_data = tomlkit.load(f)
config_dict = dict(config_data)

View File

@@ -1,6 +1,8 @@
from dataclasses import dataclass, fields, MISSING
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from dataclasses import MISSING, dataclass, fields
from typing import Any, Literal, TypeVar, get_args, get_origin
from pydantic import BaseModel, ValidationError
from typing_extensions import Self
T = TypeVar("T", bound="ConfigBase")
@@ -19,7 +21,7 @@ class ConfigBase:
"""配置类的基类"""
@classmethod
def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
def from_dict(cls, data: dict[str, Any]) -> Self:
"""从字典加载配置字段"""
if not isinstance(data, dict):
raise TypeError(f"Expected a dictionary, got {type(data).__name__}")
@@ -53,7 +55,7 @@ class ConfigBase:
return cls()
@classmethod
def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
"""
转换字段值为指定类型

View File

@@ -1,4 +1,5 @@
from typing import Literal, Optional, List
from typing import Literal
from pydantic import Field
from src.config.config_base import ValidatedConfigBase
@@ -42,7 +43,7 @@ class BotConfig(ValidatedConfigBase):
platform: str = Field(..., description="平台")
qq_account: int = Field(..., description="QQ账号")
nickname: str = Field(..., description="昵称")
alias_names: List[str] = Field(default_factory=list, description="别名列表")
alias_names: list[str] = Field(default_factory=list, description="别名列表")
class PersonalityConfig(ValidatedConfigBase):
@@ -54,7 +55,7 @@ class PersonalityConfig(ValidatedConfigBase):
background_story: str = Field(
default="", description="世界观背景故事这部分内容会作为背景知识LLM被指导不应主动复述"
)
safety_guidelines: List[str] = Field(
safety_guidelines: list[str] = Field(
default_factory=list, description="安全与互动底线Bot在任何情况下都必须遵守的原则"
)
reply_style: str = Field(default="", description="表达风格")
@@ -63,7 +64,7 @@ class PersonalityConfig(ValidatedConfigBase):
compress_identity: bool = Field(default=True, description="是否压缩身份")
# 回复规则配置
reply_targeting_rules: List[str] = Field(
reply_targeting_rules: list[str] = Field(
default_factory=lambda: [
"拒绝任何包含骚扰、冒犯、暴力、色情或危险内容的请求。",
"在拒绝时,请使用符合你人设的、坚定的语气。",
@@ -72,7 +73,7 @@ class PersonalityConfig(ValidatedConfigBase):
description="安全与互动底线规则Bot在任何情况下都必须遵守的原则",
)
message_targeting_analysis: List[str] = Field(
message_targeting_analysis: list[str] = Field(
default_factory=lambda: [
"**直接针对你**@你、回复你、明确询问你 → 必须回应",
"**间接相关**:涉及你感兴趣的话题但未直接问你 → 谨慎参与",
@@ -82,7 +83,7 @@ class PersonalityConfig(ValidatedConfigBase):
description="消息针对性分析规则,用于判断是否需要回复",
)
reply_principles: List[str] = Field(
reply_principles: list[str] = Field(
default_factory=lambda: [
"明确回应目标消息,而不是宽泛地评论。",
"可以分享你的看法、提出相关问题,或者开个合适的玩笑。",
@@ -111,7 +112,7 @@ class ChatConfig(ValidatedConfigBase):
at_bot_inevitable_reply: bool = Field(default=False, description="@机器人的必然回复")
allow_reply_self: bool = Field(default=False, description="是否允许回复自己说的话")
focus_value: float = Field(default=1.0, description="专注值")
focus_mode_quiet_groups: List[str] = Field(
focus_mode_quiet_groups: list[str] = Field(
default_factory=list,
description='专注模式下需要保持安静的群组列表, 格式: ["platform:group_id1", "platform:group_id2"]',
)
@@ -140,8 +141,8 @@ class ChatConfig(ValidatedConfigBase):
class MessageReceiveConfig(ValidatedConfigBase):
"""消息接收配置类"""
ban_words: List[str] = Field(default_factory=lambda: list(), description="禁用词列表")
ban_msgs_regex: List[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表")
ban_words: list[str] = Field(default_factory=lambda: list(), description="禁用词列表")
ban_msgs_regex: list[str] = Field(default_factory=lambda: list(), description="禁用消息正则列表")
class NormalChatConfig(ValidatedConfigBase):
@@ -155,16 +156,16 @@ class ExpressionRule(ValidatedConfigBase):
use_expression: bool = Field(default=True, description="是否使用学到的表达")
learn_expression: bool = Field(default=True, description="是否学习表达")
learning_strength: float = Field(default=1.0, description="学习强度")
group: Optional[str] = Field(default=None, description="表达共享组")
group: str | None = Field(default=None, description="表达共享组")
class ExpressionConfig(ValidatedConfigBase):
"""表达配置类"""
rules: List[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
rules: list[ExpressionRule] = Field(default_factory=list, description="表达学习规则")
@staticmethod
def _parse_stream_config_to_chat_id(stream_config_str: str) -> Optional[str]:
def _parse_stream_config_to_chat_id(stream_config_str: str) -> str | None:
"""
解析流配置字符串并生成对应的 chat_id
@@ -199,7 +200,7 @@ class ExpressionConfig(ValidatedConfigBase):
except (ValueError, IndexError):
return None
def get_expression_config_for_chat(self, chat_stream_id: Optional[str] = None) -> tuple[bool, bool, float]:
def get_expression_config_for_chat(self, chat_stream_id: str | None = None) -> tuple[bool, bool, float]:
"""
根据聊天流ID获取表达配置
@@ -362,7 +363,7 @@ class KeywordRuleConfig(ValidatedConfigBase):
try:
re.compile(pattern)
except re.error as e:
raise ValueError(f"无效的正则表达式 '{pattern}': {str(e)}") from e
raise ValueError(f"无效的正则表达式 '{pattern}': {e!s}") from e
class KeywordReactionConfig(ValidatedConfigBase):
@@ -561,10 +562,10 @@ class SleepSystemConfig(ValidatedConfigBase):
# --- 失眠机制相关参数 ---
enable_insomnia_system: bool = Field(default=True, description="是否启用失眠系统")
insomnia_trigger_delay_minutes: List[int] = Field(
insomnia_trigger_delay_minutes: list[int] = Field(
default_factory=lambda: [30, 60], description="入睡后触发失眠判定的延迟时间范围(分钟)"
)
insomnia_duration_minutes: List[int] = Field(
insomnia_duration_minutes: list[int] = Field(
default_factory=lambda: [15, 45], description="单次失眠状态的持续时间范围(分钟)"
)
sleep_pressure_threshold: float = Field(default=30.0, description="触发“压力不足型失眠”的睡眠压力阈值")
@@ -590,7 +591,7 @@ class ContextGroup(ValidatedConfigBase):
"""上下文共享组配置"""
name: str = Field(..., description="共享组的名称")
chat_ids: List[List[str]] = Field(
chat_ids: list[list[str]] = Field(
...,
description='属于该组的聊天ID列表格式为 [["type", "chat_id"], ...],例如 [["group", "123456"], ["private", "789012"]]',
)
@@ -600,20 +601,20 @@ class CrossContextConfig(ValidatedConfigBase):
"""跨群聊上下文共享配置"""
enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能")
groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
groups: list[ContextGroup] = Field(default_factory=list, description="上下文共享组列表")
class CommandConfig(ValidatedConfigBase):
"""命令系统配置类"""
command_prefixes: List[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表")
command_prefixes: list[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表")
class PermissionConfig(ValidatedConfigBase):
"""权限系统配置类"""
# Master用户配置拥有最高权限无视所有权限节点
master_users: List[List[str]] = Field(
master_users: list[list[str]] = Field(
default_factory=list, description="Master用户列表格式: [[platform, user_id], ...]"
)
@@ -668,10 +669,10 @@ class ProactiveThinkingConfig(ValidatedConfigBase):
# --- 作用范围 ---
enable_in_private: bool = Field(default=True, description="是否允许在私聊中主动发起对话")
enable_in_group: bool = Field(default=True, description="是否允许在群聊中主动发起对话")
enabled_private_chats: List[str] = Field(
enabled_private_chats: list[str] = Field(
default_factory=list, description='私聊白名单,为空则对所有私聊生效。格式: ["platform:user_id", ...]'
)
enabled_group_chats: List[str] = Field(
enabled_group_chats: list[str] = Field(
default_factory=list, description='群聊白名单,为空则对所有群聊生效。格式: ["platform:group_id", ...]'
)