ruff归零

This commit is contained in:
明天好像没什么
2025-11-01 21:32:41 +08:00
parent ce9727bdb0
commit 30658afdb4
20 changed files with 106 additions and 48 deletions

6
bot.py
View File

@@ -612,9 +612,9 @@ async def wait_for_user_input():
# 在非生产环境下,使用异步方式等待输入
if os.getenv("ENVIRONMENT") != "production":
logger.info("程序执行完成,按 Ctrl+C 退出...")
# 简单的异步等待,避免阻塞事件循环
while True:
await asyncio.sleep(1)
# 使用 Event 替代 sleep 循环,避免阻塞事件循环
shutdown_event = asyncio.Event()
await shutdown_event.wait()
except KeyboardInterrupt:
logger.info("用户中断程序")
return True

View File

@@ -192,8 +192,8 @@ class BilibiliPlugin(BasePlugin):
# 插件基本信息
plugin_name: str = "bilibili_video_watcher"
enable_plugin: bool = False
dependencies: list[str] = []
python_dependencies: list[str] = []
dependencies: ClassVar[list[str]] = []
python_dependencies: ClassVar[list[str]] = []
config_file_name: str = "config.toml"
# 配置节描述

View File

@@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock
import aiofiles
import orjson
from json_repair import repair_json
@@ -158,8 +159,9 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
with file_lock:
if os.path.exists(temp_file_path):
try:
with open(temp_file_path, "rb") as f:
return orjson.loads(f.read()), None
async with aiofiles.open(temp_file_path, "rb") as f:
content = await f.read()
return orjson.loads(content), None
except orjson.JSONDecodeError:
os.remove(temp_file_path)
@@ -182,8 +184,8 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
"extracted_triples": extracted_data.get("triples", []),
}
with file_lock:
with open(temp_file_path, "wb") as f:
f.write(orjson.dumps(doc_item))
async with aiofiles.open(temp_file_path, "wb") as f:
await f.write(orjson.dumps(doc_item))
return doc_item, None
except Exception as e:
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")

View File

@@ -4,6 +4,7 @@ import time
from datetime import datetime
from typing import Any
import aiofiles
import orjson
from sqlalchemy import select
@@ -729,8 +730,9 @@ class ExpressionLearnerManager:
if not os.path.exists(expr_file):
continue
try:
with open(expr_file, encoding="utf-8") as f:
expressions = orjson.loads(f.read())
async with aiofiles.open(expr_file, encoding="utf-8") as f:
content = await f.read()
expressions = orjson.loads(content)
if not isinstance(expressions, list):
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
@@ -791,8 +793,8 @@ class ExpressionLearnerManager:
os.makedirs(done_parent_dir, exist_ok=True)
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
async with aiofiles.open(done_flag, "w", encoding="utf-8") as f:
await f.write("done\n")
logger.info(f"表达方式JSON迁移已完成共迁移 {migrated_count} 个表达方式已写入done.done标记文件")
except PermissionError as e:
logger.error(f"权限不足无法写入done.done标记文件: {e}")

View File

@@ -4,6 +4,7 @@ import os
from dataclasses import dataclass
# import tqdm
import aiofiles
import faiss
import numpy as np
import orjson
@@ -194,8 +195,8 @@ class EmbeddingStore:
test_vectors[str(idx)] = []
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
async with aiofiles.open(self.get_test_file_path(), "w", encoding="utf-8") as f:
await f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
logger.info("测试字符串嵌入向量保存完成")

View File

@@ -25,6 +25,9 @@ from src.llm_models.utils_model import LLMRequest
logger = get_logger(__name__)
# 全局背景任务集合
_background_tasks = set()
@dataclass
class HippocampusSampleConfig:
@@ -89,7 +92,9 @@ class HippocampusSampler:
task_config = getattr(model_config.model_task_config, "utils", None)
if task_config:
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
asyncio.create_task(self.start_background_sampling())
task = asyncio.create_task(self.start_background_sampling())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info("✅ 海马体采样器初始化成功")
else:
raise RuntimeError("未找到记忆构建模型配置")

View File

@@ -19,6 +19,9 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
# 全局背景任务集合
_background_tasks = set()
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
@@ -1611,7 +1614,9 @@ class MemorySystem:
def start_hippocampus_sampling(self):
"""启动海马体采样"""
if self.hippocampus_sampler:
asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
task = asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info("海马体后台采样已启动")
else:
logger.warning("海马体采样器未初始化,无法启动采样")

View File

@@ -19,6 +19,9 @@ from .distribution_manager import stream_loop_manager
logger = get_logger("context_manager")
# 全局背景任务集合
_background_tasks = set()
class SingleStreamContextManager:
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
@@ -42,7 +45,9 @@ class SingleStreamContextManager:
logger.debug(f"单流上下文管理器初始化: {stream_id}")
# 异步初始化历史消息(不阻塞构造函数)
asyncio.create_task(self._initialize_history_from_db())
task = asyncio.create_task(self._initialize_history_from_db())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
def get_context(self) -> StreamContext:
"""获取流上下文"""
@@ -93,7 +98,9 @@ class SingleStreamContextManager:
logger.debug(f"消息已缓存,等待当前处理完成: stream={self.stream_id}")
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.debug(f"添加消息到缓存系统: {self.stream_id}")
return True
else:
@@ -113,7 +120,9 @@ class SingleStreamContextManager:
self.total_messages += 1
self.last_access_time = time.time()
# 启动流的循环任务(如果还未启动)
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
except Exception as e:

View File

@@ -3,6 +3,8 @@ from collections import defaultdict
from datetime import datetime, timedelta
from typing import Any
import aiofiles
from src.common.database.compatibility import db_get, db_query
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
from src.common.logger import get_logger
@@ -1002,8 +1004,8 @@ class StatisticOutputTask(AsyncTask):
"""
)
with open(self.record_file_path, "w", encoding="utf-8") as f:
f.write(html_template)
async with aiofiles.open(self.record_file_path, "w", encoding="utf-8") as f:
await f.write(html_template)
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
"""生成图表数据 (异步)"""

View File

@@ -7,6 +7,7 @@ import time
import uuid
from typing import Any
import aiofiles
import numpy as np
from PIL import Image
from rich.traceback import install
@@ -198,8 +199,8 @@ class ImageManager:
os.makedirs(emoji_dir, exist_ok=True)
file_path = os.path.join(emoji_dir, filename)
with open(file_path, "wb") as f:
f.write(image_bytes)
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
logger.info(f"新表情包已保存至待注册目录: {file_path}")
except Exception as e:
logger.error(f"保存待注册表情包文件失败: {e!s}")
@@ -436,8 +437,8 @@ class ImageManager:
os.makedirs(image_dir, exist_ok=True)
file_path = os.path.join(image_dir, filename)
with open(file_path, "wb") as f:
f.write(image_bytes)
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
new_img = Images(
image_id=image_id,

View File

@@ -214,9 +214,9 @@ class AdaptiveBatchScheduler:
for priority in sorted(Priority, reverse=True):
queue = self.operation_queues[priority]
count = min(len(queue), self.current_batch_size - len(operations))
for _ in range(count):
if queue:
operations.append(queue.popleft())
if queue and count > 0:
# 使用 list.extend 代替循环 append
operations.extend(queue.popleft() for _ in range(count))
if not operations:
return

View File

@@ -9,7 +9,6 @@
from .decorators import (
cached,
db_operation,
generate_cache_key,
measure_time,
retry,
timeout,

View File

@@ -19,6 +19,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.logger import get_logger
from src.common.message import get_global_api
# 全局背景任务集合
_background_tasks = set()
from src.common.remote import TelemetryHeartBeatTask
from src.common.server import Server, get_global_server
from src.config.config import global_config
@@ -461,7 +464,9 @@ MoFox_Bot(第三方修改版)
logger.info("情绪管理器初始化成功")
# 启动聊天管理器的自动保存任务
asyncio.create_task(get_chat_manager()._auto_save_task())
task = asyncio.create_task(get_chat_manager()._auto_save_task())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
# 初始化增强记忆系统
if global_config.memory.enable_memory:

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, ClassVar
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger
@@ -33,7 +33,7 @@ class BaseCommand(PlusCommand):
"""命令匹配的正则表达式"""
# 用于存储正则匹配组
matched_groups: dict[str, str] = {}
matched_groups: ClassVar[dict[str, str]] = {}
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
"""初始化Command组件"""

View File

@@ -14,6 +14,9 @@ from .component_registry import component_registry
logger = get_logger("plugin_manager")
# 全局背景任务集合
_background_tasks = set()
class PluginManager:
"""
@@ -142,7 +145,9 @@ class PluginManager:
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
try:
# 使用 asyncio.create_task 确保它不会阻塞加载流程
asyncio.create_task(plugin_instance.on_plugin_loaded())
task = asyncio.create_task(plugin_instance.on_plugin_loaded())
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except Exception as e:
logger.error(f"调用插件 '{plugin_name}' 的 on_plugin_loaded 钩子时出错: {e}")

View File

@@ -14,6 +14,9 @@ from src.config.config import global_config
logger = get_logger("plan_executor")
# 全局背景任务集合
_background_tasks = set()
class ChatterPlanExecutor:
"""
@@ -89,7 +92,9 @@ class ChatterPlanExecutor:
# 将其他动作放入后台任务执行,避免阻塞主流程
if other_actions:
asyncio.create_task(self._execute_other_actions(other_actions, plan))
task = asyncio.create_task(self._execute_other_actions(other_actions, plan))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
# 注意:后台任务的结果不会立即计入本次返回的统计数据

View File

@@ -11,6 +11,9 @@ from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
from src.plugin_system.base.component_types import PermissionNodeField
from src.plugin_system.base.config_types import ConfigField
# 全局背景任务集合
_background_tasks = set()
from .actions.read_feed_action import ReadFeedAction
from .actions.send_feed_action import SendFeedAction
from .commands.send_feed_command import SendFeedCommand
@@ -117,8 +120,14 @@ class MaiZoneRefactoredPlugin(BasePlugin):
logger.info("MaiZone重构版插件服务已注册。")
# --- 启动后台任务 ---
asyncio.create_task(scheduler_service.start())
asyncio.create_task(monitor_service.start())
task1 = asyncio.create_task(scheduler_service.start())
_background_tasks.add(task1)
task1.add_done_callback(_background_tasks.discard)
task2 = asyncio.create_task(monitor_service.start())
_background_tasks.add(task2)
task2.add_done_callback(_background_tasks.discard)
logger.info("MaiZone后台监控和定时任务已启动。")
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:

View File

@@ -7,6 +7,7 @@ import base64
from collections.abc import Callable
from pathlib import Path
import aiofiles
import aiohttp
from src.common.logger import get_logger
@@ -86,8 +87,8 @@ class ImageService:
if b64_json:
image_bytes = base64.b64decode(b64_json)
file_path = Path(image_dir) / f"image_{i + 1}.png"
with open(file_path, "wb") as f:
f.write(image_bytes)
async with aiofiles.open(file_path, "wb") as f:
await f.write(image_bytes)
logger.info(f"成功保存AI图片到: {file_path}")
return True
else:

View File

@@ -12,6 +12,7 @@ from collections.abc import Callable
from pathlib import Path
from typing import Any
import aiofiles
import aiohttp
import bs4
import json5
@@ -397,8 +398,8 @@ class QZoneService:
}
# 成功获取后,异步写入本地文件作为备份
try:
with open(cookie_file_path, "wb") as f:
f.write(orjson.dumps(parsed_cookies))
async with aiofiles.open(cookie_file_path, "wb") as f:
await f.write(orjson.dumps(parsed_cookies))
logger.info(f"通过Napcat服务成功更新Cookie并已保存至: {cookie_file_path}")
except Exception as e:
logger.warning(f"保存Cookie到文件时出错: {e}")
@@ -413,8 +414,9 @@ class QZoneService:
logger.info("尝试从本地Cookie文件加载...")
if cookie_file_path.exists():
try:
with open(cookie_file_path, "rb") as f:
cookies = orjson.loads(f.read())
async with aiofiles.open(cookie_file_path, "rb") as f:
content = await f.read()
cookies = orjson.loads(content)
logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
return cookies
except Exception as e:

View File

@@ -13,6 +13,8 @@ logger = get_logger("stt_whisper_plugin")
# 全局变量来缓存模型,避免重复加载
_whisper_model = None
_is_loading = False
_model_ready_event = asyncio.Event()
_background_tasks = set() # 背景任务集合
class LocalASRTool(BaseTool):
"""
@@ -29,7 +31,7 @@ class LocalASRTool(BaseTool):
"""
一个类方法,用于在插件加载时触发一次模型加载。
"""
global _whisper_model, _is_loading
global _whisper_model, _is_loading, _model_ready_event
if _whisper_model is None and not _is_loading:
_is_loading = True
try:
@@ -47,6 +49,7 @@ class LocalASRTool(BaseTool):
_whisper_model = None
finally:
_is_loading = False
_model_ready_event.set() # 通知等待的任务
async def execute(self, function_args: dict) -> str:
audio_path = function_args.get("audio_path")
@@ -55,9 +58,9 @@ class LocalASRTool(BaseTool):
return "错误:缺少 audio_path 参数。"
global _whisper_model
# 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成
while _is_loading:
await asyncio.sleep(0.2)
# 使用 Event 等待模型加载完成
if _is_loading:
await _model_ready_event.wait()
if _whisper_model is None:
return "Whisper 模型加载失败,无法识别语音。"
@@ -90,7 +93,9 @@ class STTWhisperPlugin(BasePlugin):
from src.config.config import global_config
if global_config.voice.asr_provider == "local":
# 使用 create_task 在后台开始加载,不阻塞主流程
asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
task = asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
_background_tasks.add(task)
task.add_done_callback(_background_tasks.discard)
except Exception as e:
logger.error(f"触发 Whisper 模型预加载时出错: {e}")