ruff归零
This commit is contained in:
6
bot.py
6
bot.py
@@ -612,9 +612,9 @@ async def wait_for_user_input():
|
|||||||
# 在非生产环境下,使用异步方式等待输入
|
# 在非生产环境下,使用异步方式等待输入
|
||||||
if os.getenv("ENVIRONMENT") != "production":
|
if os.getenv("ENVIRONMENT") != "production":
|
||||||
logger.info("程序执行完成,按 Ctrl+C 退出...")
|
logger.info("程序执行完成,按 Ctrl+C 退出...")
|
||||||
# 简单的异步等待,避免阻塞事件循环
|
# 使用 Event 替代 sleep 循环,避免阻塞事件循环
|
||||||
while True:
|
shutdown_event = asyncio.Event()
|
||||||
await asyncio.sleep(1)
|
await shutdown_event.wait()
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("用户中断程序")
|
logger.info("用户中断程序")
|
||||||
return True
|
return True
|
||||||
|
|||||||
@@ -192,8 +192,8 @@ class BilibiliPlugin(BasePlugin):
|
|||||||
# 插件基本信息
|
# 插件基本信息
|
||||||
plugin_name: str = "bilibili_video_watcher"
|
plugin_name: str = "bilibili_video_watcher"
|
||||||
enable_plugin: bool = False
|
enable_plugin: bool = False
|
||||||
dependencies: list[str] = []
|
dependencies: ClassVar[list[str]] = []
|
||||||
python_dependencies: list[str] = []
|
python_dependencies: ClassVar[list[str]] = []
|
||||||
config_file_name: str = "config.toml"
|
config_file_name: str = "config.toml"
|
||||||
|
|
||||||
# 配置节描述
|
# 配置节描述
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import orjson
|
import orjson
|
||||||
from json_repair import repair_json
|
from json_repair import repair_json
|
||||||
|
|
||||||
@@ -158,8 +159,9 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
|||||||
with file_lock:
|
with file_lock:
|
||||||
if os.path.exists(temp_file_path):
|
if os.path.exists(temp_file_path):
|
||||||
try:
|
try:
|
||||||
with open(temp_file_path, "rb") as f:
|
async with aiofiles.open(temp_file_path, "rb") as f:
|
||||||
return orjson.loads(f.read()), None
|
content = await f.read()
|
||||||
|
return orjson.loads(content), None
|
||||||
except orjson.JSONDecodeError:
|
except orjson.JSONDecodeError:
|
||||||
os.remove(temp_file_path)
|
os.remove(temp_file_path)
|
||||||
|
|
||||||
@@ -182,8 +184,8 @@ async def extract_info_async(pg_hash, paragraph, llm_api):
|
|||||||
"extracted_triples": extracted_data.get("triples", []),
|
"extracted_triples": extracted_data.get("triples", []),
|
||||||
}
|
}
|
||||||
with file_lock:
|
with file_lock:
|
||||||
with open(temp_file_path, "wb") as f:
|
async with aiofiles.open(temp_file_path, "wb") as f:
|
||||||
f.write(orjson.dumps(doc_item))
|
await f.write(orjson.dumps(doc_item))
|
||||||
return doc_item, None
|
return doc_item, None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
|
logger.error(f"提取信息失败:{pg_hash}, 错误:{e}")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import time
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import orjson
|
import orjson
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
@@ -729,8 +730,9 @@ class ExpressionLearnerManager:
|
|||||||
if not os.path.exists(expr_file):
|
if not os.path.exists(expr_file):
|
||||||
continue
|
continue
|
||||||
try:
|
try:
|
||||||
with open(expr_file, encoding="utf-8") as f:
|
async with aiofiles.open(expr_file, encoding="utf-8") as f:
|
||||||
expressions = orjson.loads(f.read())
|
content = await f.read()
|
||||||
|
expressions = orjson.loads(content)
|
||||||
|
|
||||||
if not isinstance(expressions, list):
|
if not isinstance(expressions, list):
|
||||||
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
logger.warning(f"表达方式文件格式错误,跳过: {expr_file}")
|
||||||
@@ -791,8 +793,8 @@ class ExpressionLearnerManager:
|
|||||||
os.makedirs(done_parent_dir, exist_ok=True)
|
os.makedirs(done_parent_dir, exist_ok=True)
|
||||||
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
logger.debug(f"为done.done创建父目录: {done_parent_dir}")
|
||||||
|
|
||||||
with open(done_flag, "w", encoding="utf-8") as f:
|
async with aiofiles.open(done_flag, "w", encoding="utf-8") as f:
|
||||||
f.write("done\n")
|
await f.write("done\n")
|
||||||
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
logger.info(f"表达方式JSON迁移已完成,共迁移 {migrated_count} 个表达方式,已写入done.done标记文件")
|
||||||
except PermissionError as e:
|
except PermissionError as e:
|
||||||
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
logger.error(f"权限不足,无法写入done.done标记文件: {e}")
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import os
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
# import tqdm
|
# import tqdm
|
||||||
|
import aiofiles
|
||||||
import faiss
|
import faiss
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import orjson
|
import orjson
|
||||||
@@ -194,8 +195,8 @@ class EmbeddingStore:
|
|||||||
test_vectors[str(idx)] = []
|
test_vectors[str(idx)] = []
|
||||||
|
|
||||||
|
|
||||||
with open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
async with aiofiles.open(self.get_test_file_path(), "w", encoding="utf-8") as f:
|
||||||
f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
await f.write(orjson.dumps(test_vectors, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
||||||
|
|
||||||
logger.info("测试字符串嵌入向量保存完成")
|
logger.info("测试字符串嵌入向量保存完成")
|
||||||
|
|
||||||
|
|||||||
@@ -25,6 +25,9 @@ from src.llm_models.utils_model import LLMRequest
|
|||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class HippocampusSampleConfig:
|
class HippocampusSampleConfig:
|
||||||
@@ -89,7 +92,9 @@ class HippocampusSampler:
|
|||||||
task_config = getattr(model_config.model_task_config, "utils", None)
|
task_config = getattr(model_config.model_task_config, "utils", None)
|
||||||
if task_config:
|
if task_config:
|
||||||
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
|
self.memory_builder_model = LLMRequest(model_set=task_config, request_type="memory.hippocampus_build")
|
||||||
asyncio.create_task(self.start_background_sampling())
|
task = asyncio.create_task(self.start_background_sampling())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
logger.info("✅ 海马体采样器初始化成功")
|
logger.info("✅ 海马体采样器初始化成功")
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("未找到记忆构建模型配置")
|
raise RuntimeError("未找到记忆构建模型配置")
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
|
|||||||
from src.chat.memory_system.memory_chunk import MemoryChunk
|
from src.chat.memory_system.memory_chunk import MemoryChunk
|
||||||
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
|
||||||
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
|
from src.chat.memory_system.message_collection_storage import MessageCollectionStorage
|
||||||
|
|
||||||
|
|
||||||
@@ -1611,7 +1614,9 @@ class MemorySystem:
|
|||||||
def start_hippocampus_sampling(self):
|
def start_hippocampus_sampling(self):
|
||||||
"""启动海马体采样"""
|
"""启动海马体采样"""
|
||||||
if self.hippocampus_sampler:
|
if self.hippocampus_sampler:
|
||||||
asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
|
task = asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
logger.info("海马体后台采样已启动")
|
logger.info("海马体后台采样已启动")
|
||||||
else:
|
else:
|
||||||
logger.warning("海马体采样器未初始化,无法启动采样")
|
logger.warning("海马体采样器未初始化,无法启动采样")
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ from .distribution_manager import stream_loop_manager
|
|||||||
|
|
||||||
logger = get_logger("context_manager")
|
logger = get_logger("context_manager")
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
class SingleStreamContextManager:
|
class SingleStreamContextManager:
|
||||||
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
||||||
@@ -42,7 +45,9 @@ class SingleStreamContextManager:
|
|||||||
logger.debug(f"单流上下文管理器初始化: {stream_id}")
|
logger.debug(f"单流上下文管理器初始化: {stream_id}")
|
||||||
|
|
||||||
# 异步初始化历史消息(不阻塞构造函数)
|
# 异步初始化历史消息(不阻塞构造函数)
|
||||||
asyncio.create_task(self._initialize_history_from_db())
|
task = asyncio.create_task(self._initialize_history_from_db())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
def get_context(self) -> StreamContext:
|
def get_context(self) -> StreamContext:
|
||||||
"""获取流上下文"""
|
"""获取流上下文"""
|
||||||
@@ -93,7 +98,9 @@ class SingleStreamContextManager:
|
|||||||
logger.debug(f"消息已缓存,等待当前处理完成: stream={self.stream_id}")
|
logger.debug(f"消息已缓存,等待当前处理完成: stream={self.stream_id}")
|
||||||
|
|
||||||
# 启动流的循环任务(如果还未启动)
|
# 启动流的循环任务(如果还未启动)
|
||||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
logger.debug(f"添加消息到缓存系统: {self.stream_id}")
|
logger.debug(f"添加消息到缓存系统: {self.stream_id}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
@@ -113,7 +120,9 @@ class SingleStreamContextManager:
|
|||||||
self.total_messages += 1
|
self.total_messages += 1
|
||||||
self.last_access_time = time.time()
|
self.last_access_time = time.time()
|
||||||
# 启动流的循环任务(如果还未启动)
|
# 启动流的循环任务(如果还未启动)
|
||||||
asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
task = asyncio.create_task(stream_loop_manager.start_stream_loop(self.stream_id))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ from collections import defaultdict
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
|
|
||||||
from src.common.database.compatibility import db_get, db_query
|
from src.common.database.compatibility import db_get, db_query
|
||||||
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
from src.common.database.core.models import LLMUsage, Messages, OnlineTime
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -1002,8 +1004,8 @@ class StatisticOutputTask(AsyncTask):
|
|||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
with open(self.record_file_path, "w", encoding="utf-8") as f:
|
async with aiofiles.open(self.record_file_path, "w", encoding="utf-8") as f:
|
||||||
f.write(html_template)
|
await f.write(html_template)
|
||||||
|
|
||||||
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
async def _generate_chart_data(self, stat: dict[str, Any]) -> dict:
|
||||||
"""生成图表数据 (异步)"""
|
"""生成图表数据 (异步)"""
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -198,8 +199,8 @@ class ImageManager:
|
|||||||
os.makedirs(emoji_dir, exist_ok=True)
|
os.makedirs(emoji_dir, exist_ok=True)
|
||||||
file_path = os.path.join(emoji_dir, filename)
|
file_path = os.path.join(emoji_dir, filename)
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
await f.write(image_bytes)
|
||||||
logger.info(f"新表情包已保存至待注册目录: {file_path}")
|
logger.info(f"新表情包已保存至待注册目录: {file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"保存待注册表情包文件失败: {e!s}")
|
logger.error(f"保存待注册表情包文件失败: {e!s}")
|
||||||
@@ -436,8 +437,8 @@ class ImageManager:
|
|||||||
os.makedirs(image_dir, exist_ok=True)
|
os.makedirs(image_dir, exist_ok=True)
|
||||||
file_path = os.path.join(image_dir, filename)
|
file_path = os.path.join(image_dir, filename)
|
||||||
|
|
||||||
with open(file_path, "wb") as f:
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
await f.write(image_bytes)
|
||||||
|
|
||||||
new_img = Images(
|
new_img = Images(
|
||||||
image_id=image_id,
|
image_id=image_id,
|
||||||
|
|||||||
@@ -214,9 +214,9 @@ class AdaptiveBatchScheduler:
|
|||||||
for priority in sorted(Priority, reverse=True):
|
for priority in sorted(Priority, reverse=True):
|
||||||
queue = self.operation_queues[priority]
|
queue = self.operation_queues[priority]
|
||||||
count = min(len(queue), self.current_batch_size - len(operations))
|
count = min(len(queue), self.current_batch_size - len(operations))
|
||||||
for _ in range(count):
|
if queue and count > 0:
|
||||||
if queue:
|
# 使用 list.extend 代替循环 append
|
||||||
operations.append(queue.popleft())
|
operations.extend(queue.popleft() for _ in range(count))
|
||||||
|
|
||||||
if not operations:
|
if not operations:
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -9,7 +9,6 @@
|
|||||||
from .decorators import (
|
from .decorators import (
|
||||||
cached,
|
cached,
|
||||||
db_operation,
|
db_operation,
|
||||||
generate_cache_key,
|
|
||||||
measure_time,
|
measure_time,
|
||||||
retry,
|
retry,
|
||||||
timeout,
|
timeout,
|
||||||
|
|||||||
@@ -19,6 +19,9 @@ from src.chat.message_receive.chat_stream import get_chat_manager
|
|||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message import get_global_api
|
from src.common.message import get_global_api
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
from src.common.remote import TelemetryHeartBeatTask
|
from src.common.remote import TelemetryHeartBeatTask
|
||||||
from src.common.server import Server, get_global_server
|
from src.common.server import Server, get_global_server
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -461,7 +464,9 @@ MoFox_Bot(第三方修改版)
|
|||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
|
|
||||||
# 启动聊天管理器的自动保存任务
|
# 启动聊天管理器的自动保存任务
|
||||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
task = asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
# 初始化增强记忆系统
|
# 初始化增强记忆系统
|
||||||
if global_config.memory.enable_memory:
|
if global_config.memory.enable_memory:
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING, ClassVar
|
||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -33,7 +33,7 @@ class BaseCommand(PlusCommand):
|
|||||||
"""命令匹配的正则表达式"""
|
"""命令匹配的正则表达式"""
|
||||||
|
|
||||||
# 用于存储正则匹配组
|
# 用于存储正则匹配组
|
||||||
matched_groups: dict[str, str] = {}
|
matched_groups: ClassVar[dict[str, str]] = {}
|
||||||
|
|
||||||
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
def __init__(self, message: DatabaseMessages, plugin_config: dict | None = None):
|
||||||
"""初始化Command组件"""
|
"""初始化Command组件"""
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from .component_registry import component_registry
|
|||||||
|
|
||||||
logger = get_logger("plugin_manager")
|
logger = get_logger("plugin_manager")
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager:
|
||||||
"""
|
"""
|
||||||
@@ -142,7 +145,9 @@ class PluginManager:
|
|||||||
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
logger.debug(f"为插件 '{plugin_name}' 调用 on_plugin_loaded 钩子")
|
||||||
try:
|
try:
|
||||||
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
# 使用 asyncio.create_task 确保它不会阻塞加载流程
|
||||||
asyncio.create_task(plugin_instance.on_plugin_loaded())
|
task = asyncio.create_task(plugin_instance.on_plugin_loaded())
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"调用插件 '{plugin_name}' 的 on_plugin_loaded 钩子时出错: {e}")
|
logger.error(f"调用插件 '{plugin_name}' 的 on_plugin_loaded 钩子时出错: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ from src.config.config import global_config
|
|||||||
|
|
||||||
logger = get_logger("plan_executor")
|
logger = get_logger("plan_executor")
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
|
|
||||||
class ChatterPlanExecutor:
|
class ChatterPlanExecutor:
|
||||||
"""
|
"""
|
||||||
@@ -89,7 +92,9 @@ class ChatterPlanExecutor:
|
|||||||
|
|
||||||
# 将其他动作放入后台任务执行,避免阻塞主流程
|
# 将其他动作放入后台任务执行,避免阻塞主流程
|
||||||
if other_actions:
|
if other_actions:
|
||||||
asyncio.create_task(self._execute_other_actions(other_actions, plan))
|
task = asyncio.create_task(self._execute_other_actions(other_actions, plan))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
|
logger.info(f"已将 {len(other_actions)} 个其他动作放入后台任务执行。")
|
||||||
# 注意:后台任务的结果不会立即计入本次返回的统计数据
|
# 注意:后台任务的结果不会立即计入本次返回的统计数据
|
||||||
|
|
||||||
|
|||||||
@@ -11,6 +11,9 @@ from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
|||||||
from src.plugin_system.base.component_types import PermissionNodeField
|
from src.plugin_system.base.component_types import PermissionNodeField
|
||||||
from src.plugin_system.base.config_types import ConfigField
|
from src.plugin_system.base.config_types import ConfigField
|
||||||
|
|
||||||
|
# 全局背景任务集合
|
||||||
|
_background_tasks = set()
|
||||||
|
|
||||||
from .actions.read_feed_action import ReadFeedAction
|
from .actions.read_feed_action import ReadFeedAction
|
||||||
from .actions.send_feed_action import SendFeedAction
|
from .actions.send_feed_action import SendFeedAction
|
||||||
from .commands.send_feed_command import SendFeedCommand
|
from .commands.send_feed_command import SendFeedCommand
|
||||||
@@ -117,8 +120,14 @@ class MaiZoneRefactoredPlugin(BasePlugin):
|
|||||||
logger.info("MaiZone重构版插件服务已注册。")
|
logger.info("MaiZone重构版插件服务已注册。")
|
||||||
|
|
||||||
# --- 启动后台任务 ---
|
# --- 启动后台任务 ---
|
||||||
asyncio.create_task(scheduler_service.start())
|
task1 = asyncio.create_task(scheduler_service.start())
|
||||||
asyncio.create_task(monitor_service.start())
|
_background_tasks.add(task1)
|
||||||
|
task1.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
|
task2 = asyncio.create_task(monitor_service.start())
|
||||||
|
_background_tasks.add(task2)
|
||||||
|
task2.add_done_callback(_background_tasks.discard)
|
||||||
|
|
||||||
logger.info("MaiZone后台监控和定时任务已启动。")
|
logger.info("MaiZone后台监控和定时任务已启动。")
|
||||||
|
|
||||||
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
def get_plugin_components(self) -> list[tuple[ComponentInfo, type]]:
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import base64
|
|||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -86,8 +87,8 @@ class ImageService:
|
|||||||
if b64_json:
|
if b64_json:
|
||||||
image_bytes = base64.b64decode(b64_json)
|
image_bytes = base64.b64decode(b64_json)
|
||||||
file_path = Path(image_dir) / f"image_{i + 1}.png"
|
file_path = Path(image_dir) / f"image_{i + 1}.png"
|
||||||
with open(file_path, "wb") as f:
|
async with aiofiles.open(file_path, "wb") as f:
|
||||||
f.write(image_bytes)
|
await f.write(image_bytes)
|
||||||
logger.info(f"成功保存AI图片到: {file_path}")
|
logger.info(f"成功保存AI图片到: {file_path}")
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from collections.abc import Callable
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import aiofiles
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import bs4
|
import bs4
|
||||||
import json5
|
import json5
|
||||||
@@ -397,8 +398,8 @@ class QZoneService:
|
|||||||
}
|
}
|
||||||
# 成功获取后,异步写入本地文件作为备份
|
# 成功获取后,异步写入本地文件作为备份
|
||||||
try:
|
try:
|
||||||
with open(cookie_file_path, "wb") as f:
|
async with aiofiles.open(cookie_file_path, "wb") as f:
|
||||||
f.write(orjson.dumps(parsed_cookies))
|
await f.write(orjson.dumps(parsed_cookies))
|
||||||
logger.info(f"通过Napcat服务成功更新Cookie,并已保存至: {cookie_file_path}")
|
logger.info(f"通过Napcat服务成功更新Cookie,并已保存至: {cookie_file_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"保存Cookie到文件时出错: {e}")
|
logger.warning(f"保存Cookie到文件时出错: {e}")
|
||||||
@@ -413,8 +414,9 @@ class QZoneService:
|
|||||||
logger.info("尝试从本地Cookie文件加载...")
|
logger.info("尝试从本地Cookie文件加载...")
|
||||||
if cookie_file_path.exists():
|
if cookie_file_path.exists():
|
||||||
try:
|
try:
|
||||||
with open(cookie_file_path, "rb") as f:
|
async with aiofiles.open(cookie_file_path, "rb") as f:
|
||||||
cookies = orjson.loads(f.read())
|
content = await f.read()
|
||||||
|
cookies = orjson.loads(content)
|
||||||
logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
|
logger.info(f"成功从本地文件加载Cookie: {cookie_file_path}")
|
||||||
return cookies
|
return cookies
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ logger = get_logger("stt_whisper_plugin")
|
|||||||
# 全局变量来缓存模型,避免重复加载
|
# 全局变量来缓存模型,避免重复加载
|
||||||
_whisper_model = None
|
_whisper_model = None
|
||||||
_is_loading = False
|
_is_loading = False
|
||||||
|
_model_ready_event = asyncio.Event()
|
||||||
|
_background_tasks = set() # 背景任务集合
|
||||||
|
|
||||||
class LocalASRTool(BaseTool):
|
class LocalASRTool(BaseTool):
|
||||||
"""
|
"""
|
||||||
@@ -29,7 +31,7 @@ class LocalASRTool(BaseTool):
|
|||||||
"""
|
"""
|
||||||
一个类方法,用于在插件加载时触发一次模型加载。
|
一个类方法,用于在插件加载时触发一次模型加载。
|
||||||
"""
|
"""
|
||||||
global _whisper_model, _is_loading
|
global _whisper_model, _is_loading, _model_ready_event
|
||||||
if _whisper_model is None and not _is_loading:
|
if _whisper_model is None and not _is_loading:
|
||||||
_is_loading = True
|
_is_loading = True
|
||||||
try:
|
try:
|
||||||
@@ -47,6 +49,7 @@ class LocalASRTool(BaseTool):
|
|||||||
_whisper_model = None
|
_whisper_model = None
|
||||||
finally:
|
finally:
|
||||||
_is_loading = False
|
_is_loading = False
|
||||||
|
_model_ready_event.set() # 通知等待的任务
|
||||||
|
|
||||||
async def execute(self, function_args: dict) -> str:
|
async def execute(self, function_args: dict) -> str:
|
||||||
audio_path = function_args.get("audio_path")
|
audio_path = function_args.get("audio_path")
|
||||||
@@ -55,9 +58,9 @@ class LocalASRTool(BaseTool):
|
|||||||
return "错误:缺少 audio_path 参数。"
|
return "错误:缺少 audio_path 参数。"
|
||||||
|
|
||||||
global _whisper_model
|
global _whisper_model
|
||||||
# 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成
|
# 使用 Event 等待模型加载完成
|
||||||
while _is_loading:
|
if _is_loading:
|
||||||
await asyncio.sleep(0.2)
|
await _model_ready_event.wait()
|
||||||
|
|
||||||
if _whisper_model is None:
|
if _whisper_model is None:
|
||||||
return "Whisper 模型加载失败,无法识别语音。"
|
return "Whisper 模型加载失败,无法识别语音。"
|
||||||
@@ -90,7 +93,9 @@ class STTWhisperPlugin(BasePlugin):
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
if global_config.voice.asr_provider == "local":
|
if global_config.voice.asr_provider == "local":
|
||||||
# 使用 create_task 在后台开始加载,不阻塞主流程
|
# 使用 create_task 在后台开始加载,不阻塞主流程
|
||||||
asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
|
task = asyncio.create_task(LocalASRTool.load_model_once(self.config or {}))
|
||||||
|
_background_tasks.add(task)
|
||||||
|
task.add_done_callback(_background_tasks.discard)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"触发 Whisper 模型预加载时出错: {e}")
|
logger.error(f"触发 Whisper 模型预加载时出错: {e}")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user