ruff: 清理代码并规范导入顺序

对整个代码库进行了大规模的清理和重构,主要包括:
- 统一并修复了多个文件中的 `import` 语句顺序,使其符合 PEP 8 规范。
- 移除了大量未使用的导入和变量,减少了代码冗余。
- 修复了多处代码风格问题,例如多余的空行、不一致的引号使用等。
- 简化了异常处理逻辑,移除了不必要的 `noqa` 注释。
- 在多个文件中使用了更现代的类型注解语法(例如 `list[str]` 替代 `List[str]`)。
This commit is contained in:
minecraft1024a
2025-10-05 20:38:56 +08:00
parent 2908cfead1
commit 7a7f737f71
20 changed files with 163 additions and 171 deletions

152
bot.py
View File

@@ -1,22 +1,20 @@
# import asyncio
import asyncio
import os
import platform
import sys
import time
import platform
import traceback
from pathlib import Path
from contextlib import asynccontextmanager
import hashlib
from typing import Optional, Dict, Any
from pathlib import Path
# 初始化基础工具
from colorama import init, Fore
from colorama import Fore, init
from dotenv import load_dotenv
from rich.traceback import install
# 初始化日志系统
from src.common.logger import initialize_logging, get_logger, shutdown_logging
from src.common.logger import get_logger, initialize_logging, shutdown_logging
# 初始化日志和错误显示
initialize_logging()
@@ -24,7 +22,7 @@ logger = get_logger("main")
install(extra_lines=3)
# 常量定义
SUPPORTED_DATABASES = ['sqlite', 'mysql', 'postgresql']
SUPPORTED_DATABASES = ["sqlite", "mysql", "postgresql"]
SHUTDOWN_TIMEOUT = 10.0
EULA_CHECK_INTERVAL = 2
MAX_EULA_CHECK_ATTEMPTS = 30
@@ -37,18 +35,18 @@ logger.info("工作目录已设置")
class ConfigManager:
"""配置管理器"""
@staticmethod
def ensure_env_file():
"""确保.env文件存在如果不存在则从模板创建"""
env_file = Path(".env")
template_env = Path("template/template.env")
if not env_file.exists():
if template_env.exists():
logger.info("未找到.env文件正在从模板创建...")
try:
env_file.write_text(template_env.read_text(encoding='utf-8'), encoding='utf-8')
env_file.write_text(template_env.read_text(encoding="utf-8"), encoding="utf-8")
logger.info("已从template/template.env创建.env文件")
logger.warning("请编辑.env文件将EULA_CONFIRMED设置为true并配置其他必要参数")
except Exception as e:
@@ -64,23 +62,23 @@ class ConfigManager:
env_file = Path(".env")
if not env_file.exists():
return False
# 检查文件大小
file_size = env_file.stat().st_size
if file_size == 0 or file_size > MAX_ENV_FILE_SIZE:
logger.error(f".env文件大小异常: {file_size}字节")
return False
# 检查文件内容是否包含必要字段
try:
content = env_file.read_text(encoding='utf-8')
if 'EULA_CONFIRMED' not in content:
content = env_file.read_text(encoding="utf-8")
if "EULA_CONFIRMED" not in content:
logger.error(".env文件缺少EULA_CONFIRMED字段")
return False
except Exception as e:
logger.error(f"读取.env文件失败: {e}")
return False
return True
@staticmethod
@@ -90,7 +88,7 @@ class ConfigManager:
if not ConfigManager.verify_env_file_integrity():
logger.error(".env文件完整性验证失败")
return False
load_dotenv()
logger.info("环境变量加载成功")
return True
@@ -100,44 +98,44 @@ class ConfigManager:
class EULAManager:
"""EULA管理类"""
@staticmethod
async def check_eula():
"""检查EULA和隐私条款确认状态"""
confirm_logger = get_logger("confirm")
if not ConfigManager.safe_load_dotenv():
confirm_logger.error("无法加载环境变量EULA检查失败")
sys.exit(1)
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
if eula_confirmed == 'true':
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == "true":
logger.info("EULA已通过环境变量确认")
return
# 提示用户确认EULA
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
confirm_logger.critical("请阅读以下文件:")
confirm_logger.critical(" - EULA.md (用户许可协议)")
confirm_logger.critical(" - PRIVACY.md (隐私条款)")
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
attempts = 0
while attempts < MAX_EULA_CHECK_ATTEMPTS:
try:
await asyncio.sleep(EULA_CHECK_INTERVAL)
attempts += 1
# 重新加载环境变量
ConfigManager.safe_load_dotenv()
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
if eula_confirmed == 'true':
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
if eula_confirmed == "true":
confirm_logger.info("EULA确认成功感谢您的同意")
return
if attempts % 5 == 0:
confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})")
except KeyboardInterrupt:
confirm_logger.info("用户取消,程序退出")
sys.exit(0)
@@ -146,43 +144,43 @@ class EULAManager:
if attempts >= MAX_EULA_CHECK_ATTEMPTS:
confirm_logger.error("达到最大检查次数,程序退出")
sys.exit(1)
confirm_logger.error("EULA确认超时程序退出")
sys.exit(1)
class TaskManager:
"""任务管理器"""
@staticmethod
async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT):
"""取消所有待处理的任务"""
remaining_tasks = [
t for t in asyncio.all_tasks(loop)
t for t in asyncio.all_tasks(loop)
if t is not asyncio.current_task(loop) and not t.done()
]
if not remaining_tasks:
logger.info("没有待取消的任务")
return True
logger.info(f"正在取消 {len(remaining_tasks)} 个剩余任务...")
# 取消任务
for task in remaining_tasks:
task.cancel()
# 等待任务完成
try:
results = await asyncio.wait_for(
asyncio.gather(*remaining_tasks, return_exceptions=True),
asyncio.gather(*remaining_tasks, return_exceptions=True),
timeout=timeout
)
# 检查任务结果
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.warning(f"任务 {i} 取消时发生异常: {result}")
logger.info("所有剩余任务已成功取消")
return True
except asyncio.TimeoutError:
@@ -208,30 +206,30 @@ class TaskManager:
class ShutdownManager:
"""关闭管理器"""
@staticmethod
async def graceful_shutdown(loop=None):
"""优雅关闭程序"""
try:
logger.info("正在优雅关闭麦麦...")
start_time = time.time()
# 停止异步任务
tasks_stopped = await TaskManager.stop_async_tasks()
# 取消待处理任务
tasks_cancelled = True
if loop and not loop.is_closed():
tasks_cancelled = await TaskManager.cancel_pending_tasks(loop)
shutdown_time = time.time() - start_time
success = tasks_stopped and tasks_cancelled
if success:
logger.info(f"麦麦优雅关闭完成,耗时: {shutdown_time:.2f}")
else:
logger.warning(f"麦麦关闭完成,但部分操作未成功,耗时: {shutdown_time:.2f}")
return success
except Exception as e:
@@ -264,29 +262,29 @@ async def create_event_loop_context():
class DatabaseManager:
"""数据库连接管理器"""
def __init__(self):
self._connection = None
async def __aenter__(self):
"""异步上下文管理器入口"""
try:
from src.common.database.database import initialize_sql_database
from src.config.config import global_config
logger.info("正在初始化数据库连接...")
start_time = time.time()
# 使用线程执行器运行潜在的阻塞操作
await asyncio.to_thread(initialize_sql_database, global_config.database)
elapsed_time = time.time() - start_time
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}")
return self
except Exception as e:
logger.error(f"数据库连接初始化失败: {e}")
raise
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""异步上下文管理器出口"""
if exc_type:
@@ -295,34 +293,34 @@ class DatabaseManager:
class ConfigurationValidator:
"""配置验证器"""
@staticmethod
def validate_configuration():
"""验证关键配置"""
try:
from src.config.config import global_config
# 检查必要的配置节
required_sections = ['database', 'bot']
required_sections = ["database", "bot"]
for section in required_sections:
if not hasattr(global_config, section):
logger.error(f"配置中缺少{section}配置节")
return False
# 验证数据库配置
db_config = global_config.database
if not hasattr(db_config, 'database_type') or not db_config.database_type:
if not hasattr(db_config, "database_type") or not db_config.database_type:
logger.error("数据库配置缺少database_type字段")
return False
if db_config.database_type not in SUPPORTED_DATABASES:
logger.error(f"不支持的数据库类型: {db_config.database_type}")
logger.info(f"支持的数据库类型: {', '.join(SUPPORTED_DATABASES)}")
return False
logger.info("配置验证通过")
return True
except ImportError:
logger.error("无法导入全局配置模块")
return False
@@ -332,16 +330,16 @@ class ConfigurationValidator:
class EasterEgg:
"""彩蛋功能"""
_initialized = False
@classmethod
def show(cls):
"""显示彩色文本"""
if not cls._initialized:
init()
cls._initialized = True
text = "多年以后面对AI行刑队张三将会回想起他2023年在会议上讨论人工智能的那个下午"
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
rainbow_text = ""
@@ -394,27 +392,27 @@ class MaiBotMain:
"""执行同步初始化步骤"""
self.setup_timezone()
await EULAManager.check_eula()
if not ConfigurationValidator.validate_configuration():
raise RuntimeError("配置验证失败,请检查配置文件")
return self.create_main_system()
async def run_async_init(self, main_system):
"""执行异步初始化步骤"""
# 初始化数据库连接
await self.initialize_database()
# 初始化数据库表结构
await self.initialize_database_async()
# 初始化主系统
await main_system.initialize()
# 初始化知识库
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
initialize_lpmm_knowledge()
# 显示彩蛋
EasterEgg.show()
@@ -422,7 +420,7 @@ async def wait_for_user_input():
"""等待用户输入(异步方式)"""
try:
# 在非生产环境下,使用异步方式等待输入
if os.getenv('ENVIRONMENT') != 'production':
if os.getenv("ENVIRONMENT") != "production":
logger.info("程序执行完成,按 Ctrl+C 退出...")
# 简单的异步等待,避免阻塞事件循环
while True:
@@ -438,30 +436,30 @@ async def main_async():
"""主异步函数"""
exit_code = 0
main_task = None
async with create_event_loop_context() as loop:
try:
# 确保环境文件存在
ConfigManager.ensure_env_file()
# 创建主程序实例并执行初始化
maibot = MaiBotMain()
main_system = await maibot.run_sync_init()
await maibot.run_async_init(main_system)
# 运行主任务
main_task = asyncio.create_task(main_system.schedule_tasks())
logger.info("麦麦机器人启动完成,开始运行主任务...")
# 同时运行主任务和用户输入等待
user_input_done = asyncio.create_task(wait_for_user_input())
# 使用wait等待任意一个任务完成
done, pending = await asyncio.wait(
[main_task, user_input_done],
return_when=asyncio.FIRST_COMPLETED
)
# 如果用户输入任务完成用户按了Ctrl+C取消主任务
if user_input_done in done and main_task not in done:
logger.info("用户请求退出,正在取消主任务...")
@@ -472,7 +470,7 @@ async def main_async():
logger.info("主任务已取消")
except Exception as e:
logger.error(f"主任务取消时发生错误: {e}")
except KeyboardInterrupt:
logger.warning("收到中断信号,正在优雅关闭...")
if main_task and not main_task.done():
@@ -481,7 +479,7 @@ async def main_async():
logger.error(f"主程序发生异常: {e}")
logger.debug(f"异常详情: {traceback.format_exc()}")
exit_code = 1
return exit_code
if __name__ == "__main__":
@@ -500,5 +498,5 @@ if __name__ == "__main__":
shutdown_logging()
except Exception as e:
print(f"关闭日志系统时出错: {e}")
sys.exit(exit_code)

View File

@@ -21,6 +21,7 @@ from .memory_chunk import MemoryChunk as Memory
# 遗忘引擎
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
from .memory_formatter import format_memories_bracket_style
# 记忆管理器
from .memory_manager import MemoryManager, MemoryResult, memory_manager
@@ -30,7 +31,6 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
# Vector DB存储系统
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
from .memory_formatter import format_memories_bracket_style
__all__ = [
# 核心数据结构

View File

@@ -17,8 +17,9 @@
"""
from __future__ import annotations
from typing import Any, Iterable
import time
from collections.abc import Iterable
from typing import Any
def _format_timestamp(ts: Any) -> str:

View File

@@ -2,9 +2,8 @@
记忆元数据索引。
"""
from dataclasses import dataclass, asdict
from dataclasses import asdict, dataclass
from typing import Any
from time import time
from src.common.logger import get_logger
@@ -12,6 +11,7 @@ logger = get_logger(__name__)
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
@dataclass
class MemoryMetadataIndexEntry:
memory_id: str
@@ -51,7 +51,7 @@ class MemoryMetadataIndex:
if payload:
try:
self._rust.batch_add(payload)
except Exception as ex: # noqa: BLE001
except Exception as ex:
logger.error(f"Rust 元数据批量添加失败: {ex}")
def add_or_update(self, entry: MemoryMetadataIndexEntry):
@@ -88,7 +88,7 @@ class MemoryMetadataIndex:
if flexible_mode:
return list(self._rust.search_flexible(params))
return list(self._rust.search_strict(params))
except Exception as ex: # noqa: BLE001
except Exception as ex:
logger.error(f"Rust 搜索失败返回空: {ex}")
return []
@@ -105,18 +105,18 @@ class MemoryMetadataIndex:
"keywords_count": raw.get("keywords_indexed", 0),
"tags_count": raw.get("tags_indexed", 0),
}
except Exception as ex: # noqa: BLE001
except Exception as ex:
logger.warning(f"读取 Rust stats 失败: {ex}")
return {"total_memories": 0}
def save(self): # 仅调用 rust save
try:
self._rust.save()
except Exception as ex: # noqa: BLE001
except Exception as ex:
logger.warning(f"Rust save 失败: {ex}")
__all__ = [
"MemoryMetadataIndexEntry",
"MemoryMetadataIndex",
"MemoryMetadataIndexEntry",
]

View File

@@ -263,7 +263,7 @@ class MessageRecv(Message):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -277,7 +277,7 @@ class MessageRecv(Message):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"

View File

@@ -1,5 +1,4 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""纯 inkfox 视频关键帧分析工具
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
@@ -14,25 +13,25 @@
from __future__ import annotations
import os
import io
import asyncio
import base64
import tempfile
from pathlib import Path
from typing import List, Tuple, Optional, Dict, Any
import hashlib
import io
import os
import tempfile
import time
from pathlib import Path
from typing import Any
from PIL import Image
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
# 简易并发控制:同一 hash 只处理一次
_video_locks: Dict[str, asyncio.Lock] = {}
_video_locks: dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()
logger = get_logger("utils_video")
@@ -90,7 +89,7 @@ class VideoAnalyzer:
logger.debug(f"获取系统信息失败: {e}")
# ---- 关键帧提取 ----
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
with tempfile.TemporaryDirectory() as tmp:
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
@@ -105,7 +104,7 @@ class VideoAnalyzer:
)
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
total_ms = getattr(result, "total_time_ms", 0)
frames: List[Tuple[str, float]] = []
frames: list[tuple[str, float]] = []
for i, f in enumerate(files):
img = Image.open(f).convert("RGB")
if max(img.size) > self.max_image_size:
@@ -119,7 +118,7 @@ class VideoAnalyzer:
return frames
# ---- 批量分析 ----
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
from src.llm_models.payload_content.message import MessageBuilder, RoleType
from src.llm_models.utils_model import RequestType
prompt = self.batch_analysis_prompt.format(
@@ -149,8 +148,8 @@ class VideoAnalyzer:
return resp.content or "❌ 未获得响应"
# ---- 逐帧分析 ----
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
results: List[str] = []
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
results: list[str] = []
for i, (b64, ts) in enumerate(frames):
prompt = f"分析第{i+1}" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
if question:
@@ -174,7 +173,7 @@ class VideoAnalyzer:
return "\n".join(results)
# ---- 主入口 ----
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
if not os.path.exists(video_path):
return False, "❌ 文件不存在"
frames = await self.extract_keyframes(video_path)
@@ -189,10 +188,10 @@ class VideoAnalyzer:
async def analyze_video_from_bytes(
self,
video_bytes: bytes,
filename: Optional[str] = None,
prompt: Optional[str] = None,
question: Optional[str] = None,
) -> Dict[str, str]:
filename: str | None = None,
prompt: str | None = None,
question: str | None = None,
) -> dict[str, str]:
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
if not video_bytes:
return {"summary": "❌ 空视频数据"}
@@ -271,7 +270,7 @@ class VideoAnalyzer:
# ---- 外部接口 ----
_INSTANCE: Optional[VideoAnalyzer] = None
_INSTANCE: VideoAnalyzer | None = None
def get_video_analyzer() -> VideoAnalyzer:
@@ -285,7 +284,7 @@ def is_video_analysis_available() -> bool:
return True
def get_video_analysis_status() -> Dict[str, Any]:
def get_video_analysis_status() -> dict[str, Any]:
try:
info = video.get_system_info() # type: ignore[attr-defined]
except Exception as e: # pragma: no cover
@@ -297,4 +296,4 @@ def get_video_analysis_status() -> Dict[str, Any]:
"modes": ["auto", "batch", "sequential"],
"max_frames_default": inst.max_frames,
"implementation": "inkfox",
}
}

View File

@@ -53,8 +53,8 @@ class StreamContext(BaseDataModel):
priority_mode: str | None = None
priority_info: dict | None = None
def add_action_to_message(self, message_id: str, action: str):
"""
向指定消息添加执行的动作

View File

@@ -5,7 +5,6 @@ MCP (Model Context Protocol) SSE (Server-Sent Events) 客户端实现
import asyncio
import io
import json
from collections.abc import Callable
from typing import Any
@@ -20,7 +19,6 @@ from ..exceptions import (
NetworkConnectionError,
ReqAbortException,
RespNotOkException,
RespParseException,
)
from ..payload_content.message import Message, RoleType
from ..payload_content.resp_format import RespFormat

View File

@@ -6,7 +6,7 @@ import time
import traceback
from functools import partial
from random import choices
from typing import Any, List, Tuple
from typing import Any
from maim_message import MessageServer
from rich.traceback import install
@@ -36,7 +36,7 @@ install(extra_lines=3)
logger = get_logger("main")
# 预定义彩蛋短语,避免在每次初始化时重新创建
EGG_PHRASES: List[Tuple[str, int]] = [
EGG_PHRASES: list[tuple[str, int]] = [
("我们的代码里真的没有bug只有'特性'", 10),
("你知道吗?阿范喜欢被切成臊子😡", 10),
("你知道吗,雅诺狐的耳朵其实很好摸", 5),
@@ -69,22 +69,22 @@ def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float)
class MainSystem:
"""主系统类,负责协调所有组件"""
def __init__(self) -> None:
# 使用增强记忆系统
self.memory_manager = memory_manager
self.individuality: Individuality = get_individuality()
# 使用消息API替代直接的FastAPI实例
self.app: MessageServer = get_global_api()
self.server: Server = get_global_server()
# 设置信号处理器用于优雅退出
self._shutting_down = False
self._setup_signal_handlers()
# 存储清理任务的引用
self._cleanup_tasks: List[asyncio.Task] = []
self._cleanup_tasks: list[asyncio.Task] = []
def _setup_signal_handlers(self) -> None:
"""设置信号处理器"""
@@ -92,7 +92,7 @@ class MainSystem:
if self._shutting_down:
logger.warning("系统已经在关闭过程中,忽略重复信号")
return
self._shutting_down = True
logger.info("收到退出信号,正在优雅关闭系统...")
@@ -148,7 +148,7 @@ class MainSystem:
# 尝试注册所有可用的计算器
registered_calculators = []
for calc_name, calc_info in interest_calculators.items():
enabled = getattr(calc_info, "enabled", True)
default_enabled = getattr(calc_info, "enabled_by_default", True)
@@ -169,7 +169,7 @@ class MainSystem:
# 创建组件实例
calculator_instance = component_class()
# 初始化组件
if not await calculator_instance.initialize():
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
@@ -199,12 +199,12 @@ class MainSystem:
"""异步清理资源"""
if self._shutting_down:
return
self._shutting_down = True
logger.info("开始系统清理流程...")
cleanup_tasks = []
# 停止数据库服务
try:
from src.common.database.database import stop_database
@@ -236,14 +236,14 @@ class MainSystem:
# 触发停止事件
try:
from src.plugin_system.core.event_manager import event_manager
cleanup_tasks.append(("插件系统停止事件",
cleanup_tasks.append(("插件系统停止事件",
event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")))
except Exception as e:
logger.error(f"准备触发停止事件时出错: {e}")
# 停止表情管理器
try:
cleanup_tasks.append(("表情管理器",
cleanup_tasks.append(("表情管理器",
asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)))
except Exception as e:
logger.error(f"准备停止表情管理器时出错: {e}")
@@ -270,21 +270,21 @@ class MainSystem:
logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
tasks = [task for _, task in cleanup_tasks]
task_names = [name for name, _ in cleanup_tasks]
# 使用asyncio.gather并行执行设置超时防止卡死
try:
results = await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True),
timeout=30.0 # 30秒超时
)
# 记录结果
for i, (name, result) in enumerate(zip(task_names, results)):
if isinstance(result, Exception):
logger.error(f"停止 {name} 时出错: {result}")
else:
logger.info(f"🛑 {name} 已停止")
except asyncio.TimeoutError:
logger.error("清理任务超时,强制退出")
except Exception as e:
@@ -311,16 +311,16 @@ class MainSystem:
try:
start_time = time.time()
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
# 检查系统是否正在关闭
if self._shutting_down:
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
return
# 创建后台任务
task = asyncio.create_task(chat_bot.message_process(message_data))
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
# 添加一个回调函数,当任务完成时,它会被调用
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
except Exception:
@@ -330,19 +330,19 @@ class MainSystem:
async def initialize(self) -> None:
"""初始化系统组件"""
# 检查必要的配置
if not hasattr(global_config, 'bot') or not hasattr(global_config.bot, 'nickname'):
if not hasattr(global_config, "bot") or not hasattr(global_config.bot, "nickname"):
logger.error("缺少必要的bot配置")
raise ValueError("Bot配置不完整")
logger.info(f"正在唤醒{global_config.bot.nickname}......")
# 初始化组件
await self._init_components()
# 随机选择彩蛋
egg_texts, weights = zip(*EGG_PHRASES)
selected_egg = choices(egg_texts, weights=weights, k=1)[0]
logger.info(f"""
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
=========================================================
@@ -367,7 +367,7 @@ MoFox_Bot(第三方修改版)
async_task_manager.add_task(StatisticOutputTask()),
async_task_manager.add_task(TelemetryHeartBeatTask()),
]
await asyncio.gather(*base_init_tasks, return_exceptions=True)
logger.info("基础定时任务初始化成功")
@@ -399,7 +399,7 @@ MoFox_Bot(第三方修改版)
# 处理所有缓存的事件订阅(插件加载完成后)
event_manager.process_all_pending_subscriptions()
# 初始化MCP工具提供器
try:
mcp_config = global_config.get("mcp_servers", [])
@@ -412,24 +412,24 @@ MoFox_Bot(第三方修改版)
# 并行初始化其他管理器
manager_init_tasks = []
# 表情管理器
manager_init_tasks.append(self._safe_init("表情包管理器", get_emoji_manager().initialize))
# 情绪管理器
manager_init_tasks.append(self._safe_init("情绪管理器", mood_manager.start))
# 聊天管理器
manager_init_tasks.append(self._safe_init("聊天管理器", get_chat_manager()._initialize))
# 等待所有管理器初始化完成
results = await asyncio.gather(*manager_init_tasks, return_exceptions=True)
# 检查初始化结果
for i, result in enumerate(results):
if isinstance(result, Exception):
logger.error(f"组件初始化失败: {result}")
# 启动聊天管理器的自动保存任务
asyncio.create_task(get_chat_manager()._auto_save_task())
@@ -558,7 +558,7 @@ MoFox_Bot(第三方修改版)
"""关闭系统组件"""
if self._shutting_down:
return
logger.info("正在关闭MainSystem...")
await self._async_cleanup()
logger.info("MainSystem关闭完成")

View File

@@ -19,7 +19,7 @@ from src.common.logger import get_logger
from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING:
from src.chat.replyer.default_generator import DefaultReplyer
pass
install(extra_lines=3)

View File

@@ -19,7 +19,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
if tool_class:
return tool_class(plugin_config)
# 如果不是常规工具检查是否是MCP工具
# MCP工具不需要返回实例会在execute_tool_call中特殊处理
return None
@@ -35,7 +35,7 @@ def get_llm_available_tool_definitions():
llm_available_tools = component_registry.get_llm_available_tools()
tool_definitions = [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
# 添加MCP工具
try:
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
@@ -45,5 +45,5 @@ def get_llm_available_tool_definitions():
logger.debug(f"已添加 {len(mcp_tools)} 个MCP工具到可用工具列表")
except Exception as e:
logger.debug(f"获取MCP工具失败可能未配置: {e}")
return tool_definitions

View File

@@ -279,7 +279,7 @@ class ToolExecutor:
logger.info(
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
)
# 检查是否是MCP工具
try:
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
@@ -295,7 +295,7 @@ class ToolExecutor:
}
except Exception as e:
logger.debug(f"检查MCP工具时出错: {e}")
function_args["llm_called"] = True # 标记为LLM调用
# 检查是否是二步工具的第二步调用

View File

@@ -3,11 +3,9 @@ MCP (Model Context Protocol) 连接器
负责连接MCP服务器获取和执行工具
"""
import asyncio
from typing import Any
import aiohttp
import orjson
from src.common.logger import get_logger

View File

@@ -3,7 +3,6 @@ MCP工具提供器 - 简化版
直接集成到工具系统,无需复杂的插件架构
"""
import asyncio
from typing import Any
from src.common.logger import get_logger

View File

@@ -4,9 +4,10 @@
"""
import time
import orjson
from typing import TYPE_CHECKING
import orjson
from src.chat.interest_system import bot_interest_manager
from src.common.logger import get_logger
from src.config.config import global_config

View File

@@ -230,11 +230,11 @@ class ChatterPlanExecutor:
except Exception as e:
error_message = str(e)
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
'''
"""
# 记录用户关系追踪
if success and action_info.action_message:
await self._track_user_interaction(action_info, plan, reply_content)
'''
"""
execution_time = time.time() - start_time
self.execution_stats["execution_times"].append(execution_time)

View File

@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Any
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager
from src.plugin_system.base.component_types import ChatMode
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
from src.plugin_system.base.component_types import ChatMode
if TYPE_CHECKING:
from src.chat.planner_actions.action_manager import ChatterActionManager

View File

@@ -6,9 +6,7 @@ SearXNG search engine implementation
from __future__ import annotations
import asyncio
import functools
from typing import Any, List
from typing import Any
import httpx
@@ -39,13 +37,13 @@ class SearXNGSearchEngine(BaseSearchEngine):
instances = config_api.get_global_config("web_search.searxng_instances", None)
if isinstance(instances, list):
# 过滤空值
self.instances: List[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
self.instances: list[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
else:
self.instances = []
api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
if isinstance(api_keys, list):
self.api_keys: List[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
else:
self.api_keys = []
@@ -85,7 +83,7 @@ class SearXNGSearchEngine(BaseSearchEngine):
results.extend(instance_results)
if len(results) >= num_results:
break
except Exception as e: # noqa: BLE001
except Exception as e:
logger.warning(f"SearXNG 实例 {base_url} 调用失败: {e}")
continue
@@ -116,12 +114,12 @@ class SearXNGSearchEngine(BaseSearchEngine):
try:
resp = await self._client.get(url, params=params, headers=headers)
resp.raise_for_status()
except Exception as e: # noqa: BLE001
except Exception as e:
raise RuntimeError(f"请求失败: {e}") from e
try:
data = resp.json()
except Exception as e: # noqa: BLE001
except Exception as e:
raise RuntimeError(f"解析 JSON 失败: {e}") from e
raw_results = data.get("results", []) if isinstance(data, dict) else []
@@ -141,5 +139,5 @@ class SearXNGSearchEngine(BaseSearchEngine):
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
async def __aexit__(self, exc_type, exc, tb):
await self._client.aclose()

View File

@@ -41,8 +41,8 @@ class WEBSEARCHPLUGIN(BasePlugin):
from .engines.bing_engine import BingSearchEngine
from .engines.ddg_engine import DDGSearchEngine
from .engines.exa_engine import ExaSearchEngine
from .engines.tavily_engine import TavilySearchEngine
from .engines.searxng_engine import SearXNGSearchEngine
from .engines.tavily_engine import TavilySearchEngine
# 实例化所有搜索引擎这会触发API密钥管理器的初始化
exa_engine = ExaSearchEngine()

View File

@@ -13,8 +13,8 @@ from src.plugin_system.apis import config_api
from ..engines.bing_engine import BingSearchEngine
from ..engines.ddg_engine import DDGSearchEngine
from ..engines.exa_engine import ExaSearchEngine
from ..engines.tavily_engine import TavilySearchEngine
from ..engines.searxng_engine import SearXNGSearchEngine
from ..engines.tavily_engine import TavilySearchEngine
from ..utils.formatters import deduplicate_results, format_search_results
logger = get_logger("web_search_tool")