ruff归零 (bring the ruff error count down to zero)
bot.py
@@ -612,9 +612,9 @@ async def wait_for_user_input():
         # 在非生产环境下,使用异步方式等待输入
         if os.getenv("ENVIRONMENT") != "production":
             logger.info("程序执行完成,按 Ctrl+C 退出...")
-            # 简单的异步等待,避免阻塞事件循环
-            while True:
-                await asyncio.sleep(1)
+            # 使用 Event 替代 sleep 循环,避免阻塞事件循环
+            shutdown_event = asyncio.Event()
+            await shutdown_event.wait()
     except KeyboardInterrupt:
         logger.info("用户中断程序")
         return True
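The hunk above replaces a poll-every-second loop with an asyncio.Event that the coroutine simply awaits. A minimal, self-contained sketch of that pattern (the function and task names are illustrative, not taken from bot.py):

import asyncio


async def main() -> None:
    shutdown_event = asyncio.Event()

    async def request_shutdown_later() -> None:
        # Stand-in for whatever eventually asks the program to exit
        # (a signal handler, a KeyboardInterrupt path, etc.).
        await asyncio.sleep(1)
        shutdown_event.set()

    waker = asyncio.create_task(request_shutdown_later())
    # No polling: the coroutine is suspended here until .set() is called.
    await shutdown_event.wait()
    await waker
    print("shutdown requested, exiting")


asyncio.run(main())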
@@ -192,8 +192,8 @@ class BilibiliPlugin(BasePlugin):
     # 插件基本信息
     plugin_name: str = "bilibili_video_watcher"
     enable_plugin: bool = False
-    dependencies: list[str] = []
-    python_dependencies: list[str] = []
+    dependencies: ClassVar[list[str]] = []
+    python_dependencies: ClassVar[list[str]] = []
     config_file_name: str = "config.toml"

     # 配置节描述
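Annotating mutable class-level defaults with typing.ClassVar, as this hunk does, is what ruff's RUF012 rule asks for: the annotation states that the list belongs to the class rather than to each instance. A small illustration with a made-up plugin class, independent of the BasePlugin used in this repo:

from typing import ClassVar


class DemoPlugin:
    # Class-level and shared by every instance; ClassVar satisfies RUF012.
    dependencies: ClassVar[list[str]] = []
    python_dependencies: ClassVar[list[str]] = []

    # Immutable default, fine as an ordinary annotated attribute.
    config_file_name: str = "config.toml"


print(DemoPlugin.dependencies)  # [] -> one list object for the whole class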
@@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from threading import Lock

+import aiofiles
 import orjson
 from json_repair import repair_json
@@ -158,8 +159,9 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
     with file_lock:
         if os.path.exists(temp_file_path):
             try:
-                with open(temp_file_path, "rb") as f:
-                    return orjson.loads(f.read()), None
+                async with aiofiles.open(temp_file_path, "rb") as f:
+                    content = await f.read()
+                    return orjson.loads(content), None
             except orjson.JSONDecodeError:
                 os.remove(temp_file_path)
@@ -182,8 +184,8 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
             "extracted_triples": extracted_data.get("triples", []),
         }
         with file_lock:
-            with open(temp_file_path, "wb") as f:
-                f.write(orjson.dumps(doc_item))
+            async with aiofiles.open(temp_file_path, "wb") as f:
+                await f.write(orjson.dumps(doc_item))
         return doc_item, None
     except Exception as e:
         logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
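These two hunks swap blocking open() calls inside async functions for aiofiles, so reading and writing the cache file no longer stalls the event loop. A minimal round-trip sketch with aiofiles and orjson; the file name and payload are invented for the example:

import asyncio

import aiofiles
import orjson


async def write_then_read(path: str) -> dict:
    doc_item = {"paragraph_hash": "demo", "extracted_triples": []}

    # Non-blocking write: other tasks can run while the bytes are flushed.
    async with aiofiles.open(path, "wb") as f:
        await f.write(orjson.dumps(doc_item))

    # Non-blocking read of the same file.
    async with aiofiles.open(path, "rb") as f:
        content = await f.read()
    return orjson.loads(content)


print(asyncio.run(write_then_read("demo_cache.json")))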
@@ -4,6 +4,7 @@ import time
 from datetime import datetime
 from typing import Any

+import aiofiles
 import orjson
 from sqlalchemy import select
@@ -729,8 +730,9 @@ class ExpressionLearnerManager:
            if not os.path.exists(expr_file):
                continue
            try:
-               with open(expr_file, encoding="utf-8") as f:
-                   expressions = orjson.loads(f.read())
+               async with aiofiles.open(expr_file, encoding="utf-8") as f:
+                   content = await f.read()
+                   expressions = orjson.loads(content)

                if not isinstance(expressions, list):
                    logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
@@ -791,8 +793,8 @@ class ExpressionLearnerManager:
                os.makedirs(done_parent_dir, exist_ok=True)
                logger.debug(f"为done.done创建父目录: {done_parent_dir}")

-           with open(done_flag, "w", encoding="utf-8") as f:
-               f.write("done\n")
+           async with aiofiles.open(done_flag, "w", encoding="utf-8") as f:
+               await f.write("done\n")
            logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
        except PermissionError as e:
            logger.error(f"权限不足,无法写入done.done标记文件: {e}")
@@ -4,6 +4,7 @@ import os
 from dataclasses import dataclass

 # import tqdm
+import aiofiles
 import faiss
 import numpy as np
 import orjson
@@ -194,8 +195,8 @@ class EmbeddingStore:
            test_vectors[str(idx)] = []

-       with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
-           f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
+       async with aiofiles.open(self.get_test_file_path(), "w", encoding="utf-8") as f:
+           await f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))

        logger.info("测试字符串嵌入向量保存完成")
@@ -25,6 +25,9 @@ from src.llm_models.utils_model import LLMRequest

 logger = get_logger(__name__)

+# 全局背景任务集合
+_background_tasks = set()
+

 @dataclass
 class HippocampusSampleConfig:
@@ -89,7 +92,9 @@ class HippocampusSampler:
        task_config = getattr(model_config.model_task_config, "utils", None)
        if task_config:
            self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
-           asyncio.create_task(self.start_background_sampling())
+           task = asyncio.create_task(self.start_background_sampling())
+           _background_tasks.add(task)
+           task.add_done_callback(_background_tasks.discard)
            logger.info("✅ 海马体采样器初始化成功")
        else:
            raise RuntimeError("未找到记忆构建模型配置")
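The create_task change above is the pattern repeated throughout this commit: a bare asyncio.create_task() keeps no reference, so the scheduled task can be garbage-collected before it finishes (ruff RUF006). Holding the task in a module-level set and discarding it on completion is the usual fix. A standalone sketch, not tied to the sampler class:

import asyncio

# Strong references to every in-flight background task.
_background_tasks: set[asyncio.Task] = set()


def spawn(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    # Drop the reference once the task completes so the set does not grow forever.
    task.add_done_callback(_background_tasks.discard)
    return task


async def main() -> None:
    spawn(asyncio.sleep(0.1))
    spawn(asyncio.sleep(0.2))
    # Wait for whatever is still pending before shutting down.
    await asyncio.gather(*_background_tasks)


asyncio.run(main())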
@@ -19,6 +19,9 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
 from src.chat.memory_system.memory_chunk import MemoryChunk
 from src.chat.memory_system.memory_fusion import MemoryFusionEngine
 from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
+
+# 全局背景任务集合
+_background_tasks = set()
 from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
@@ -1611,7 +1614,9 @@ class MemorySystem:
    def start_hippocampus_sampling(self):
        """启动海马体采样"""
        if self.hippocampus_sampler:
-           asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
+           task = asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
+           _background_tasks.add(task)
+           task.add_done_callback(_background_tasks.discard)
            logger.info("海马体后台采样已启动")
        else:
            logger.warning("海马体采样器未初始化,无法启动采样")
@@ -19,6 +19,9 @@ from .distribution_manager import stream_loop_manager

 logger = get_logger("context_manager")

+# 全局背景任务集合
+_background_tasks = set()
+

 class SingleStreamContextManager:
     """单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
@@ -42,7 +45,9 @@ class SingleStreamContextManager:
        logger.debug(f"单流上下文管理器初始化: {stream_id}")

        # 异步初始化历史消息(不阻塞构造函数)
-       asyncio.create_task(self._initialize_history_from_db())
+       task = asyncio.create_task(self._initialize_history_from_db())
+       _background_tasks.add(task)
+       task.add_done_callback(_background_tasks.discard)

    def get_context(self) -> StreamContext:
        """获取流上下文"""
@@ -93,7 +98,9 @@ class SingleStreamContextManager:
                logger.debug(f"消息已缓存,等待当前处理完成: stream={self.stream_id}")

                # 启动流的循环任务(如果还未启动)
-               asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
+               task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
+               _background_tasks.add(task)
+               task.add_done_callback(_background_tasks.discard)
                logger.debug(f"添加消息到缓存系统: {self.stream_id}")
                return True
            else:
@@ -113,7 +120,9 @@ class SingleStreamContextManager:
                self.total_messages += 1
                self.last_access_time = time.time()
                # 启动流的循环任务(如果还未启动)
-               asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
+               task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
+               _background_tasks.add(task)
+               task.add_done_callback(_background_tasks.discard)
                logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
                return True
        except Exception as e:
@@ -3,6 +3,8 @@ from collections import defaultdict
 from datetime import datetime, timedelta
 from typing import Any

+import aiofiles
+
 from src.common.database.compatibility import db_get, db_query
 from src.common.database.core.models import LLMUsage, Messages, OnlineTime
 from src.common.logger import get_logger
@@ -1002,8 +1004,8 @@ class StatisticOutputTask(AsyncTask):
            """
        )

-       with open(self.record_file_path, "w", encoding="utf-8") as f:
-           f.write(html_template)
+       async with aiofiles.open(self.record_file_path, "w", encoding="utf-8") as f:
+           await f.write(html_template)

    async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
        """生成图表数据 (异步)"""
@@ -7,6 +7,7 @@ import time
 import uuid
 from typing import Any

+import aiofiles
 import numpy as np
 from PIL import Image
 from rich.traceback import install
@@ -198,8 +199,8 @@ class ImageManager:
            os.makedirs(emoji_dir, exist_ok=True)
            file_path = os.path.join(emoji_dir, filename)

-           with open(file_path, "wb") as f:
-               f.write(image_bytes)
+           async with aiofiles.open(file_path, "wb") as f:
+               await f.write(image_bytes)
            logger.info(f"新表情包已保存至待注册目录: {file_path}")
        except Exception as e:
            logger.error(f"保存待注册表情包文件失败: {e!s}")
@@ -436,8 +437,8 @@ class ImageManager:
            os.makedirs(image_dir, exist_ok=True)
            file_path = os.path.join(image_dir, filename)

-           with open(file_path, "wb") as f:
-               f.write(image_bytes)
+           async with aiofiles.open(file_path, "wb") as f:
+               await f.write(image_bytes)

            new_img = Images(
                image_id=image_id,
@@ -214,9 +214,9 @@ class AdaptiveBatchScheduler:
        for priority in sorted(Priority, reverse=True):
            queue = self.operation_queues[priority]
            count = min(len(queue), self.current_batch_size - len(operations))
-           for _ in range(count):
-               if queue:
-                   operations.append(queue.popleft())
+           if queue and count > 0:
+               # 使用 list.extend 代替循环 append
+               operations.extend(queue.popleft() for _ in range(count))

        if not operations:
            return
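The scheduler hunk above drains up to count items from a deque with a single list.extend over a generator instead of appending in a loop, which is the style ruff's PERF rules nudge toward. A toy version of that draining step:

from collections import deque

queue = deque(range(10))
operations: list[int] = []
batch_size = 4

count = min(len(queue), batch_size - len(operations))
if queue and count > 0:
    # popleft() runs lazily as extend consumes the generator,
    # removing exactly `count` items from the front of the deque.
    operations.extend(queue.popleft() for _ in range(count))

print(operations)   # [0, 1, 2, 3]
print(len(queue))   # 6 items left for the next batch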
@@ -9,7 +9,6 @@
 from .decorators import (
     cached,
     db_operation,
     generate_cache_key,
     measure_time,
     retry,
     timeout,
@@ -19,6 +19,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
 from src.common.logger import get_logger
 from src.common.message import get_global_api
+
+# 全局背景任务集合
+_background_tasks = set()
 from src.common.remote import TelemetryHeartBeatTask
 from src.common.server import Server, get_global_server
 from src.config.config import global_config
@@ -461,7 +464,9 @@ MoFox_Bot(第三方修改版)
        logger.info("情绪管理器初始化成功")

        # 启动聊天管理器的自动保存任务
-       asyncio.create_task(get_chat_manager()._auto_save_task())
+       task = asyncio.create_task(get_chat_manager()._auto_save_task())
+       _background_tasks.add(task)
+       task.add_done_callback(_background_tasks.discard)

        # 初始化增强记忆系统
        if global_config.memory.enable_memory:
@@ -1,5 +1,5 @@
 from abc import abstractmethod
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, ClassVar

 from src.common.data_models.database_data_model import DatabaseMessages
 from src.common.logger import get_logger
@@ -33,7 +33,7 @@ class BaseCommand(PlusCommand):
    """命令匹配的正则表达式"""

    # 用于存储正则匹配组
-   matched_groups: dict[str, str] = {}
+   matched_groups: ClassVar[dict[str, str]] = {}

    def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
        """初始化Command组件"""
@@ -14,6 +14,9 @@ from .component_registry import component_registry

 logger = get_logger("plugin_manager")

+# 全局背景任务集合
+_background_tasks = set()
+

 class PluginManager:
     """
@@ -142,7 +145,9 @@ class PluginManager:
                logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
                try:
                    # 使用 asyncio.create_task 确保它不会阻塞加载流程
-                   asyncio.create_task(plugin_instance.on_plugin_loaded())
+                   task = asyncio.create_task(plugin_instance.on_plugin_loaded())
+                   _background_tasks.add(task)
+                   task.add_done_callback(_background_tasks.discard)
                except Exception as e:
                    logger.error(f"调用插件 '{plugin_name}' 的 on_plugin_loaded 钩子时出错: {e}")
@@ -14,6 +14,9 @@ from src.config.config import global_config

 logger = get_logger("plan_executor")

+# 全局背景任务集合
+_background_tasks = set()
+

 class ChatterPlanExecutor:
     """
@@ -89,7 +92,9 @@ class ChatterPlanExecutor:

        # 将其他动作放入后台任务执行,避免阻塞主流程
        if other_actions:
-           asyncio.create_task(self._execute_other_actions(other_actions, plan))
+           task = asyncio.create_task(self._execute_other_actions(other_actions, plan))
+           _background_tasks.add(task)
+           task.add_done_callback(_background_tasks.discard)
            logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
            # 注意:后台任务的结果不会立即计入本次返回的统计数据
@@ -11,6 +11,9 @@ from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
 from src.plugin_system.base.component_types import PermissionNodeField
 from src.plugin_system.base.config_types import ConfigField

+# 全局背景任务集合
+_background_tasks = set()
+
 from .actions.read_feed_action import ReadFeedAction
 from .actions.send_feed_action import SendFeedAction
 from .commands.send_feed_command import SendFeedCommand
@@ -117,8 +120,14 @@ class MaiZoneRefactoredPlugin(BasePlugin):
        logger.info("MaiZone重构版插件服务已注册。")

        # --- 启动后台任务 ---
-       asyncio.create_task(scheduler_service.start())
-       asyncio.create_task(monitor_service.start())
+       task1 = asyncio.create_task(scheduler_service.start())
+       _background_tasks.add(task1)
+       task1.add_done_callback(_background_tasks.discard)
+
+       task2 = asyncio.create_task(monitor_service.start())
+       _background_tasks.add(task2)
+       task2.add_done_callback(_background_tasks.discard)

        logger.info("MaiZone后台监控和定时任务已启动。")

    def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
@@ -7,6 +7,7 @@ import base64
 from collections.abc import Callable
 from pathlib import Path

+import aiofiles
 import aiohttp

 from src.common.logger import get_logger
@@ -86,8 +87,8 @@ class ImageService:
            if b64_json:
                image_bytes = base64.b64decode(b64_json)
                file_path = Path(image_dir) / f"image_{i + 1}.png"
-               with open(file_path, "wb") as f:
-                   f.write(image_bytes)
+               async with aiofiles.open(file_path, "wb") as f:
+                   await f.write(image_bytes)
                logger.info(f"成功保存AI图片到: {file_path}")
                return True
            else:
@@ -12,6 +12,7 @@ from collections.abc import Callable
 from pathlib import Path
 from typing import Any

+import aiofiles
 import aiohttp
 import bs4
 import json5
@@ -397,8 +398,8 @@ class QZoneService:
                }
                # 成功获取后,异步写入本地文件作为备份
                try:
-                   with open(cookie_file_path, "wb") as f:
-                       f.write(orjson.dumps(parsed_cookies))
+                   async with aiofiles.open(cookie_file_path, "wb") as f:
+                       await f.write(orjson.dumps(parsed_cookies))
                    logger.info(f"通过Napcat服务成功更新Cookie,并已保存至: {cookie_file_path}")
                except Exception as e:
                    logger.warning(f"保存Cookie到文件时出错: {e}")
@@ -413,8 +414,9 @@ class QZoneService:
        logger.info("尝试从本地Cookie文件加载...")
        if cookie_file_path.exists():
            try:
-               with open(cookie_file_path, "rb") as f:
-                   cookies = orjson.loads(f.read())
+               async with aiofiles.open(cookie_file_path, "rb") as f:
+                   content = await f.read()
+                   cookies = orjson.loads(content)
                logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
                return cookies
            except Exception as e:
@@ -13,6 +13,8 @@ logger = get_logger("stt_whisper_plugin")
 # 全局变量来缓存模型,避免重复加载
 _whisper_model = None
 _is_loading = False
+_model_ready_event = asyncio.Event()
+_background_tasks = set()  # 背景任务集合

 class LocalASRTool(BaseTool):
     """
@@ -29,7 +31,7 @@ class LocalASRTool(BaseTool):
        """
        一个类方法,用于在插件加载时触发一次模型加载。
        """
-       global _whisper_model, _is_loading
+       global _whisper_model, _is_loading, _model_ready_event
        if _whisper_model is None and not _is_loading:
            _is_loading = True
            try:
@@ -47,6 +49,7 @@ class LocalASRTool(BaseTool):
                _whisper_model = None
            finally:
                _is_loading = False
+               _model_ready_event.set()  # 通知等待的任务

    async def execute(self, function_args: dict) -> str:
        audio_path = function_args.get("audio_path")
@@ -55,9 +58,9 @@ class LocalASRTool(BaseTool):
            return "错误:缺少 audio_path 参数。"

        global _whisper_model
-       # 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成
-       while _is_loading:
-           await asyncio.sleep(0.2)
+       # 使用 Event 等待模型加载完成
+       if _is_loading:
+           await _model_ready_event.wait()

        if _whisper_model is None:
            return "Whisper 模型加载失败,无法识别语音。"
@@ -90,7 +93,9 @@ class STTWhisperPlugin(BasePlugin):
            from src.config.config import global_config
            if global_config.voice.asr_provider == "local":
                # 使用 create_task 在后台开始加载,不阻塞主流程
-               asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
+               task = asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
+               _background_tasks.add(task)
+               task.add_done_callback(_background_tasks.discard)
        except Exception as e:
            logger.error(f"触发 Whisper 模型预加载时出错: {e}")