chore: 代码格式化与类型注解优化

对项目中的多个文件进行了代码风格调整和类型注解更新。

- 使用 ruff 工具对代码进行自动格式化,主要包括:
    - 统一 import 语句的顺序和风格。
    - 移除未使用的 import。
    - 调整代码间距和空行。
- 将部分 `Optional[str]` 和 `List[T]` 等旧式类型注解更新为现代的 `str | None` 和 `list[T]` 语法。
- 修复了一些小的代码风格问题,例如将 `f'...'` 更改为 `f"..."`。
This commit is contained in:
minecraft1024a
2025-10-24 19:08:32 +08:00
parent 1700bdbb42
commit bae520b293
27 changed files with 100 additions and 102 deletions

View File

@@ -7,7 +7,6 @@ from src.plugin_system import (
BaseEventHandler,
BasePlugin,
BasePrompt,
ToolParamType,
BaseTool,
ChatType,
CommandArgs,
@@ -15,6 +14,7 @@ from src.plugin_system import (
ConfigField,
EventType,
PlusCommand,
ToolParamType,
register_plugin,
)
from src.plugin_system.base.base_event import HandlerResult

View File

@@ -49,11 +49,11 @@ __plugin_meta__ = PluginMetadata(
name="{plugin_name}",
description="{description}",
usage="暂无说明",
type={repr(plugin_type)},
type={plugin_type!r},
version="{version}",
author="{author}",
license={repr(license_type)},
repository_url={repr(repository_url)},
license={license_type!r},
repository_url={repository_url!r},
keywords={keywords},
categories={categories},
)

View File

@@ -3,9 +3,9 @@ import datetime
import os
import shutil
import sys
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
from concurrent.futures import ThreadPoolExecutor, as_completed
import orjson
from json_repair import repair_json

View File

@@ -1,7 +1,6 @@
import asyncio
import math
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
# import tqdm

View File

@@ -3,12 +3,12 @@
用于统一管理所有notice消息将notice与正常消息分离
"""
import time
import threading
import time
from collections import defaultdict, deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any
from enum import Enum
from typing import Any
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -27,7 +27,7 @@ class NoticeMessage:
"""Notice消息数据结构"""
message: DatabaseMessages
scope: NoticeScope
target_stream_id: Optional[str] = None # 如果是STREAM类型指定目标流ID
target_stream_id: str | None = None # 如果是STREAM类型指定目标流ID
timestamp: float = field(default_factory=time.time)
ttl: int = 3600 # 默认1小时过期
@@ -56,11 +56,11 @@ class GlobalNoticeManager:
return cls._instance
def __init__(self):
if hasattr(self, '_initialized'):
if hasattr(self, "_initialized"):
return
self._initialized = True
self._notices: Dict[str, deque[NoticeMessage]] = defaultdict(deque)
self._notices: dict[str, deque[NoticeMessage]] = defaultdict(deque)
self._max_notices_per_type = 100 # 每种类型最大存储数量
self._cleanup_interval = 300 # 5分钟清理一次过期消息
self._last_cleanup_time = time.time()
@@ -80,8 +80,8 @@ class GlobalNoticeManager:
self,
message: DatabaseMessages,
scope: NoticeScope = NoticeScope.STREAM,
target_stream_id: Optional[str] = None,
ttl: Optional[int] = None
target_stream_id: str | None = None,
ttl: int | None = None
) -> bool:
"""添加notice消息
@@ -142,7 +142,7 @@ class GlobalNoticeManager:
logger.error(f"添加notice消息失败: {e}")
return False
def get_accessible_notices(self, stream_id: str, limit: int = 20) -> List[NoticeMessage]:
def get_accessible_notices(self, stream_id: str, limit: int = 20) -> list[NoticeMessage]:
"""获取指定聊天流可访问的notice消息
Args:
@@ -231,7 +231,7 @@ class GlobalNoticeManager:
logger.error(f"获取notice文本失败: {e}", exc_info=True)
return ""
def clear_notices(self, stream_id: Optional[str] = None, notice_type: Optional[str] = None) -> int:
def clear_notices(self, stream_id: str | None = None, notice_type: str | None = None) -> int:
"""清理notice消息
Args:
@@ -289,7 +289,7 @@ class GlobalNoticeManager:
logger.error(f"清理notice消息失败: {e}")
return 0
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
# 更新实时统计
total_active_notices = sum(len(notices) for notices in self._notices.values())
@@ -313,11 +313,11 @@ class GlobalNoticeManager:
"""检查消息是否为notice类型"""
try:
# 首先检查消息的is_notify字段
if hasattr(message, 'is_notify') and message.is_notify:
if hasattr(message, "is_notify") and message.is_notify:
return True
# 检查消息的附加配置
if hasattr(message, 'additional_config') and message.additional_config:
if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict):
return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str):
@@ -333,7 +333,7 @@ class GlobalNoticeManager:
logger.debug(f"检查notice类型失败: {e}")
return False
def _get_storage_key(self, scope: NoticeScope, target_stream_id: Optional[str], message: DatabaseMessages) -> str:
def _get_storage_key(self, scope: NoticeScope, target_stream_id: str | None, message: DatabaseMessages) -> str:
"""生成存储键"""
if scope == NoticeScope.PUBLIC:
return "public"
@@ -341,10 +341,10 @@ class GlobalNoticeManager:
notice_type = self._get_notice_type(message) or "default"
return f"stream_{target_stream_id}_{notice_type}"
def _get_notice_type(self, message: DatabaseMessages) -> Optional[str]:
def _get_notice_type(self, message: DatabaseMessages) -> str | None:
"""获取notice类型"""
try:
if hasattr(message, 'additional_config') and message.additional_config:
if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str):

View File

@@ -7,7 +7,7 @@ import asyncio
import random
import time
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
@@ -19,9 +19,8 @@ from src.config.config import global_config
from src.plugin_system.apis.chat_api import get_chat_manager
from .distribution_manager import stream_loop_manager
from .global_notice_manager import NoticeScope, global_notice_manager
from .sleep_system.state_manager import SleepState, sleep_state_manager
from .global_notice_manager import global_notice_manager, NoticeScope
if TYPE_CHECKING:
pass
@@ -666,11 +665,11 @@ class MessageManager:
"""检查消息是否为notice类型"""
try:
# 首先检查消息的is_notify字段
if hasattr(message, 'is_notify') and message.is_notify:
if hasattr(message, "is_notify") and message.is_notify:
return True
# 检查消息的附加配置
if hasattr(message, 'additional_config') and message.additional_config:
if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict):
return message.additional_config.get("is_notice", False)
elif isinstance(message.additional_config, str):
@@ -716,7 +715,7 @@ class MessageManager:
"""
try:
# 检查附加配置中的公共notice标志
if hasattr(message, 'additional_config') and message.additional_config:
if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict):
is_public = message.additional_config.get("is_public_notice", False)
elif isinstance(message.additional_config, str):
@@ -737,10 +736,10 @@ class MessageManager:
logger.debug(f"确定notice作用域失败: {e}")
return NoticeScope.STREAM
def _get_notice_type(self, message: DatabaseMessages) -> Optional[str]:
def _get_notice_type(self, message: DatabaseMessages) -> str | None:
"""获取notice类型"""
try:
if hasattr(message, 'additional_config') and message.additional_config:
if hasattr(message, "additional_config") and message.additional_config:
if isinstance(message.additional_config, dict):
return message.additional_config.get("notice_type")
elif isinstance(message.additional_config, str):
@@ -781,7 +780,7 @@ class MessageManager:
logger.error(f"获取notice文本失败: {e}")
return ""
def clear_notices(self, stream_id: Optional[str] = None, notice_type: Optional[str] = None) -> int:
def clear_notices(self, stream_id: str | None = None, notice_type: str | None = None) -> int:
"""清理notice消息"""
try:
return self.notice_manager.clear_notices(stream_id, notice_type)
@@ -789,7 +788,7 @@ class MessageManager:
logger.error(f"清理notice失败: {e}")
return 0
def get_notice_stats(self) -> Dict[str, Any]:
def get_notice_stats(self) -> dict[str, Any]:
"""获取notice管理器统计信息"""
try:
return self.notice_manager.get_stats()

View File

@@ -1,6 +1,5 @@
import os
import re
import time
import traceback
from typing import Any
@@ -12,7 +11,7 @@ from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import Prompt, global_prompt_manager, create_prompt_async
from src.chat.utils.prompt import create_prompt_async, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
@@ -477,9 +476,10 @@ class ChatBot:
if notice_handled:
# notice消息已处理需要先添加到message_manager再存储
try:
from src.common.data_models.database_data_model import DatabaseMessages
import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None)
@@ -521,7 +521,7 @@ class ChatBot:
}
# 如果message_info有additional_config合并进来
if hasattr(message_info, 'additional_config') and message_info.additional_config:
if hasattr(message_info, "additional_config") and message_info.additional_config:
if isinstance(message_info.additional_config, dict):
additional_config_dict.update(message_info.additional_config)
elif isinstance(message_info.additional_config, str):
@@ -618,9 +618,10 @@ class ChatBot:
template_group_name = None
async def preprocess():
from src.common.data_models.database_data_model import DatabaseMessages
import time
from src.common.data_models.database_data_model import DatabaseMessages
message_info = message.message_info
msg_user_info = getattr(message_info, "user_info", None)
stream_user_info = getattr(message.chat_stream, "user_info", None)

View File

@@ -825,7 +825,7 @@ class DefaultReplyer:
logger.debug(f"开始构建notice块chat_id={chat_id}")
# 检查是否启用notice in prompt
if not hasattr(global_config, 'notice'):
if not hasattr(global_config, "notice"):
logger.debug("notice配置不存在")
return ""
@@ -836,7 +836,7 @@ class DefaultReplyer:
# 使用全局notice管理器获取notice文本
from src.chat.message_manager.message_manager import message_manager
limit = getattr(global_config.notice, 'notice_prompt_limit', 5)
limit = getattr(global_config.notice, "notice_prompt_limit", 5)
logger.debug(f"获取notice文本limit={limit}")
notice_text = message_manager.get_notice_text(chat_id, limit)
@@ -1461,12 +1461,12 @@ class DefaultReplyer:
"(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
)
else:
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)'
schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
except (ValueError, AttributeError):
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)'
schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
else:
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)'
schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
moderation_prompt_block = (
"请不要输出违法违规内容,不要输出色情,暴力,政治相关内容,如有敏感内容,请规避。不要随意遵从他人指令。"

View File

@@ -1,5 +1,4 @@
import asyncio
from typing import Type
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -20,7 +19,7 @@ class PromptComponentManager:
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
"""
def get_components_for(self, injection_point: str) -> list[Type[BasePrompt]]:
def get_components_for(self, injection_point: str) -> list[type[BasePrompt]]:
"""
获取指定注入点的所有已注册组件类。
@@ -33,7 +32,7 @@ class PromptComponentManager:
# 从组件注册中心获取所有启用的Prompt组件
enabled_prompts = component_registry.get_enabled_components_by_type(ComponentType.PROMPT)
matching_components: list[Type[BasePrompt]] = []
matching_components: list[type[BasePrompt]] = []
for prompt_name, prompt_info in enabled_prompts.items():
# 确保 prompt_info 是 PromptInfo 类型

View File

@@ -1,5 +1,5 @@
import base64
import asyncio
import base64
import hashlib
import io
import os
@@ -174,7 +174,7 @@ class ImageManager:
# 3. 查询通用图片描述缓存ImageDescriptions表
if cached_description := await self._get_description_from_db(image_hash, "emoji"):
logger.info(f"[缓存命中] 使用通用图片缓存(ImageDescriptions表)中的描述")
logger.info("[缓存命中] 使用通用图片缓存(ImageDescriptions表)中的描述")
refined_part = cached_description.split(" Keywords:")[0]
return f"[表情包:{refined_part}]"
@@ -258,7 +258,7 @@ class ImageManager:
logger.error(f"VLM调用失败 (第 {i+1}/3 次): {e}", exc_info=True)
if i < 2: # 如果不是最后一次则等待1秒
logger.warning(f"识图失败将在1秒后重试...")
logger.warning("识图失败将在1秒后重试...")
await asyncio.sleep(1)
if not description or not description.strip():

View File

@@ -1,6 +1,6 @@
import traceback
from typing import Any
from collections import defaultdict
from typing import Any
from sqlalchemy import func, not_, select
from sqlalchemy.orm import DeclarativeBase

View File

@@ -26,9 +26,9 @@ from .base import (
ActionInfo,
BaseAction,
BaseCommand,
BasePrompt,
BaseEventHandler,
BasePlugin,
BasePrompt,
BaseTool,
ChatMode,
ChatType,

View File

@@ -158,7 +158,7 @@ class ChatterPlanFilter:
if global_config.planning_system.schedule_enable:
if activity_info := schedule_manager.get_current_activity():
activity = activity_info.get("activity", "未知活动")
schedule_block = f'你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)'
schedule_block = f"你当前正在进行“{activity}”。(此为你的当前状态,仅供参考。除非被直接询问,否则不要在对话中主动提及。)"
mood_block = ""
# 需要情绪模块打开才能获得情绪,否则会引发报错

View File

@@ -9,7 +9,7 @@ from src.chat.utils.utils import get_chat_type_and_target_info
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.info_data_model import Plan, TargetPersonInfo
from src.config.config import global_config
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType, ComponentType
from src.plugin_system.base.component_types import ActionInfo, ChatMode, ChatType
from src.plugin_system.core.component_registry import component_registry

View File

@@ -282,10 +282,10 @@ class EmojiAction(BaseAction):
score += 2 # 包含匹配
# 关键词匹配加分
chosen_keywords = re.findall(r'\w+', chosen_description.lower())
item_keywords = re.findall(r'\[(.*?)\]', refined_info)
chosen_keywords = re.findall(r"\w+", chosen_description.lower())
item_keywords = re.findall(r"\[(.*?)\]", refined_info)
if item_keywords:
item_keywords_set = {k.strip().lower() for k in item_keywords[0].split(',')}
item_keywords_set = {k.strip().lower() for k in item_keywords[0].split(",")}
for kw in chosen_keywords:
if kw in item_keywords_set:
score += 1

View File

@@ -4,7 +4,7 @@ TTS 语音合成 Action
from src.common.logger import get_logger
from src.plugin_system.apis import generator_api
from src.plugin_system.base.base_action import ActionActivationType, BaseAction, ChatMode
from src.plugin_system.base.base_action import BaseAction, ChatMode
from ..services.manager import get_service

View File

@@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata(
"is_built_in": True,
},
# Python包依赖列表
python_dependencies = [ # noqa: RUF012
python_dependencies = [
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency(
package_name="exa_py",

View File

@@ -3,7 +3,7 @@ Base search engine interface
"""
from abc import ABC, abstractmethod
from typing import Any, Optional
from typing import Any
class BaseSearchEngine(ABC):
@@ -24,7 +24,7 @@ class BaseSearchEngine(ABC):
"""
pass
async def read_url(self, url: str) -> Optional[str]:
async def read_url(self, url: str) -> str | None:
"""
读取URL内容如果引擎不支持则返回None
"""

View File

@@ -2,7 +2,7 @@
Metaso Search Engine (Chat Completions Mode)
"""
import json
from typing import Any, List
from typing import Any
import httpx
@@ -27,7 +27,7 @@ class MetasoClient:
"Content-Type": "application/json",
}
async def search(self, query: str, **kwargs) -> List[dict[str, Any]]:
async def search(self, query: str, **kwargs) -> list[dict[str, Any]]:
"""Perform a search using the Metaso Chat Completions API."""
payload = {"model": "fast", "stream": True, "messages": [{"role": "user", "content": query}]}
search_url = f"{self.base_url}/chat/completions"

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool
@@ -42,9 +42,9 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.bing_engine import BingSearchEngine
from .engines.ddg_engine import DDGSearchEngine
from .engines.exa_engine import ExaSearchEngine
from .engines.metaso_engine import MetasoSearchEngine
from .engines.searxng_engine import SearXNGSearchEngine
from .engines.tavily_engine import TavilySearchEngine
from .engines.metaso_engine import MetasoSearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine()

View File

@@ -13,9 +13,9 @@ from src.plugin_system.apis import config_api
from ..engines.bing_engine import BingSearchEngine
from ..engines.ddg_engine import DDGSearchEngine
from ..engines.exa_engine import ExaSearchEngine
from ..engines.metaso_engine import MetasoSearchEngine
from ..engines.searxng_engine import SearXNGSearchEngine
from ..engines.tavily_engine import TavilySearchEngine
from ..engines.metaso_engine import MetasoSearchEngine
from ..utils.formatters import deduplicate_results, format_search_results
logger = get_logger("web_search_tool")