ruff: 清理代码并规范导入顺序
对整个代码库进行了大规模的清理和重构,主要包括: - 统一并修复了多个文件中的 `import` 语句顺序,使其符合 PEP 8 规范。 - 移除了大量未使用的导入和变量,减少了代码冗余。 - 修复了多处代码风格问题,例如多余的空行、不一致的引号使用等。 - 简化了异常处理逻辑,移除了不必要的 `noqa` 注释。 - 在多个文件中使用了更现代的类型注解语法(例如 `list[str]` 替代 `List[str]`)。
This commit is contained in:
152
bot.py
152
bot.py
@@ -1,22 +1,20 @@
|
|||||||
# import asyncio
|
# import asyncio
|
||||||
import asyncio
|
import asyncio
|
||||||
import os
|
import os
|
||||||
|
import platform
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import platform
|
|
||||||
import traceback
|
import traceback
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
import hashlib
|
from pathlib import Path
|
||||||
from typing import Optional, Dict, Any
|
|
||||||
|
|
||||||
# 初始化基础工具
|
# 初始化基础工具
|
||||||
from colorama import init, Fore
|
from colorama import Fore, init
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
# 初始化日志系统
|
# 初始化日志系统
|
||||||
from src.common.logger import initialize_logging, get_logger, shutdown_logging
|
from src.common.logger import get_logger, initialize_logging, shutdown_logging
|
||||||
|
|
||||||
# 初始化日志和错误显示
|
# 初始化日志和错误显示
|
||||||
initialize_logging()
|
initialize_logging()
|
||||||
@@ -24,7 +22,7 @@ logger = get_logger("main")
|
|||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
# 常量定义
|
# 常量定义
|
||||||
SUPPORTED_DATABASES = ['sqlite', 'mysql', 'postgresql']
|
SUPPORTED_DATABASES = ["sqlite", "mysql", "postgresql"]
|
||||||
SHUTDOWN_TIMEOUT = 10.0
|
SHUTDOWN_TIMEOUT = 10.0
|
||||||
EULA_CHECK_INTERVAL = 2
|
EULA_CHECK_INTERVAL = 2
|
||||||
MAX_EULA_CHECK_ATTEMPTS = 30
|
MAX_EULA_CHECK_ATTEMPTS = 30
|
||||||
@@ -37,18 +35,18 @@ logger.info("工作目录已设置")
|
|||||||
|
|
||||||
class ConfigManager:
|
class ConfigManager:
|
||||||
"""配置管理器"""
|
"""配置管理器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ensure_env_file():
|
def ensure_env_file():
|
||||||
"""确保.env文件存在,如果不存在则从模板创建"""
|
"""确保.env文件存在,如果不存在则从模板创建"""
|
||||||
env_file = Path(".env")
|
env_file = Path(".env")
|
||||||
template_env = Path("template/template.env")
|
template_env = Path("template/template.env")
|
||||||
|
|
||||||
if not env_file.exists():
|
if not env_file.exists():
|
||||||
if template_env.exists():
|
if template_env.exists():
|
||||||
logger.info("未找到.env文件,正在从模板创建...")
|
logger.info("未找到.env文件,正在从模板创建...")
|
||||||
try:
|
try:
|
||||||
env_file.write_text(template_env.read_text(encoding='utf-8'), encoding='utf-8')
|
env_file.write_text(template_env.read_text(encoding="utf-8"), encoding="utf-8")
|
||||||
logger.info("已从template/template.env创建.env文件")
|
logger.info("已从template/template.env创建.env文件")
|
||||||
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
logger.warning("请编辑.env文件,将EULA_CONFIRMED设置为true并配置其他必要参数")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -64,23 +62,23 @@ class ConfigManager:
|
|||||||
env_file = Path(".env")
|
env_file = Path(".env")
|
||||||
if not env_file.exists():
|
if not env_file.exists():
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查文件大小
|
# 检查文件大小
|
||||||
file_size = env_file.stat().st_size
|
file_size = env_file.stat().st_size
|
||||||
if file_size == 0 or file_size > MAX_ENV_FILE_SIZE:
|
if file_size == 0 or file_size > MAX_ENV_FILE_SIZE:
|
||||||
logger.error(f".env文件大小异常: {file_size}字节")
|
logger.error(f".env文件大小异常: {file_size}字节")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 检查文件内容是否包含必要字段
|
# 检查文件内容是否包含必要字段
|
||||||
try:
|
try:
|
||||||
content = env_file.read_text(encoding='utf-8')
|
content = env_file.read_text(encoding="utf-8")
|
||||||
if 'EULA_CONFIRMED' not in content:
|
if "EULA_CONFIRMED" not in content:
|
||||||
logger.error(".env文件缺少EULA_CONFIRMED字段")
|
logger.error(".env文件缺少EULA_CONFIRMED字段")
|
||||||
return False
|
return False
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"读取.env文件失败: {e}")
|
logger.error(f"读取.env文件失败: {e}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -90,7 +88,7 @@ class ConfigManager:
|
|||||||
if not ConfigManager.verify_env_file_integrity():
|
if not ConfigManager.verify_env_file_integrity():
|
||||||
logger.error(".env文件完整性验证失败")
|
logger.error(".env文件完整性验证失败")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
logger.info("环境变量加载成功")
|
logger.info("环境变量加载成功")
|
||||||
return True
|
return True
|
||||||
@@ -100,44 +98,44 @@ class ConfigManager:
|
|||||||
|
|
||||||
class EULAManager:
|
class EULAManager:
|
||||||
"""EULA管理类"""
|
"""EULA管理类"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def check_eula():
|
async def check_eula():
|
||||||
"""检查EULA和隐私条款确认状态"""
|
"""检查EULA和隐私条款确认状态"""
|
||||||
confirm_logger = get_logger("confirm")
|
confirm_logger = get_logger("confirm")
|
||||||
|
|
||||||
if not ConfigManager.safe_load_dotenv():
|
if not ConfigManager.safe_load_dotenv():
|
||||||
confirm_logger.error("无法加载环境变量,EULA检查失败")
|
confirm_logger.error("无法加载环境变量,EULA检查失败")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||||
if eula_confirmed == 'true':
|
if eula_confirmed == "true":
|
||||||
logger.info("EULA已通过环境变量确认")
|
logger.info("EULA已通过环境变量确认")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 提示用户确认EULA
|
# 提示用户确认EULA
|
||||||
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
|
confirm_logger.critical("您需要同意EULA和隐私条款才能使用MoFox_Bot")
|
||||||
confirm_logger.critical("请阅读以下文件:")
|
confirm_logger.critical("请阅读以下文件:")
|
||||||
confirm_logger.critical(" - EULA.md (用户许可协议)")
|
confirm_logger.critical(" - EULA.md (用户许可协议)")
|
||||||
confirm_logger.critical(" - PRIVACY.md (隐私条款)")
|
confirm_logger.critical(" - PRIVACY.md (隐私条款)")
|
||||||
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
|
confirm_logger.critical("然后编辑 .env 文件,将 'EULA_CONFIRMED=false' 改为 'EULA_CONFIRMED=true'")
|
||||||
|
|
||||||
attempts = 0
|
attempts = 0
|
||||||
while attempts < MAX_EULA_CHECK_ATTEMPTS:
|
while attempts < MAX_EULA_CHECK_ATTEMPTS:
|
||||||
try:
|
try:
|
||||||
await asyncio.sleep(EULA_CHECK_INTERVAL)
|
await asyncio.sleep(EULA_CHECK_INTERVAL)
|
||||||
attempts += 1
|
attempts += 1
|
||||||
|
|
||||||
# 重新加载环境变量
|
# 重新加载环境变量
|
||||||
ConfigManager.safe_load_dotenv()
|
ConfigManager.safe_load_dotenv()
|
||||||
eula_confirmed = os.getenv('EULA_CONFIRMED', '').lower()
|
eula_confirmed = os.getenv("EULA_CONFIRMED", "").lower()
|
||||||
if eula_confirmed == 'true':
|
if eula_confirmed == "true":
|
||||||
confirm_logger.info("EULA确认成功,感谢您的同意")
|
confirm_logger.info("EULA确认成功,感谢您的同意")
|
||||||
return
|
return
|
||||||
|
|
||||||
if attempts % 5 == 0:
|
if attempts % 5 == 0:
|
||||||
confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})")
|
confirm_logger.critical(f"请修改 .env 文件中的 EULA_CONFIRMED=true (尝试 {attempts}/{MAX_EULA_CHECK_ATTEMPTS})")
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
confirm_logger.info("用户取消,程序退出")
|
confirm_logger.info("用户取消,程序退出")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
@@ -146,43 +144,43 @@ class EULAManager:
|
|||||||
if attempts >= MAX_EULA_CHECK_ATTEMPTS:
|
if attempts >= MAX_EULA_CHECK_ATTEMPTS:
|
||||||
confirm_logger.error("达到最大检查次数,程序退出")
|
confirm_logger.error("达到最大检查次数,程序退出")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
confirm_logger.error("EULA确认超时,程序退出")
|
confirm_logger.error("EULA确认超时,程序退出")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
class TaskManager:
|
class TaskManager:
|
||||||
"""任务管理器"""
|
"""任务管理器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT):
|
async def cancel_pending_tasks(loop, timeout=SHUTDOWN_TIMEOUT):
|
||||||
"""取消所有待处理的任务"""
|
"""取消所有待处理的任务"""
|
||||||
remaining_tasks = [
|
remaining_tasks = [
|
||||||
t for t in asyncio.all_tasks(loop)
|
t for t in asyncio.all_tasks(loop)
|
||||||
if t is not asyncio.current_task(loop) and not t.done()
|
if t is not asyncio.current_task(loop) and not t.done()
|
||||||
]
|
]
|
||||||
|
|
||||||
if not remaining_tasks:
|
if not remaining_tasks:
|
||||||
logger.info("没有待取消的任务")
|
logger.info("没有待取消的任务")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
logger.info(f"正在取消 {len(remaining_tasks)} 个剩余任务...")
|
logger.info(f"正在取消 {len(remaining_tasks)} 个剩余任务...")
|
||||||
|
|
||||||
# 取消任务
|
# 取消任务
|
||||||
for task in remaining_tasks:
|
for task in remaining_tasks:
|
||||||
task.cancel()
|
task.cancel()
|
||||||
|
|
||||||
# 等待任务完成
|
# 等待任务完成
|
||||||
try:
|
try:
|
||||||
results = await asyncio.wait_for(
|
results = await asyncio.wait_for(
|
||||||
asyncio.gather(*remaining_tasks, return_exceptions=True),
|
asyncio.gather(*remaining_tasks, return_exceptions=True),
|
||||||
timeout=timeout
|
timeout=timeout
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查任务结果
|
# 检查任务结果
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.warning(f"任务 {i} 取消时发生异常: {result}")
|
logger.warning(f"任务 {i} 取消时发生异常: {result}")
|
||||||
|
|
||||||
logger.info("所有剩余任务已成功取消")
|
logger.info("所有剩余任务已成功取消")
|
||||||
return True
|
return True
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
@@ -208,30 +206,30 @@ class TaskManager:
|
|||||||
|
|
||||||
class ShutdownManager:
|
class ShutdownManager:
|
||||||
"""关闭管理器"""
|
"""关闭管理器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def graceful_shutdown(loop=None):
|
async def graceful_shutdown(loop=None):
|
||||||
"""优雅关闭程序"""
|
"""优雅关闭程序"""
|
||||||
try:
|
try:
|
||||||
logger.info("正在优雅关闭麦麦...")
|
logger.info("正在优雅关闭麦麦...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 停止异步任务
|
# 停止异步任务
|
||||||
tasks_stopped = await TaskManager.stop_async_tasks()
|
tasks_stopped = await TaskManager.stop_async_tasks()
|
||||||
|
|
||||||
# 取消待处理任务
|
# 取消待处理任务
|
||||||
tasks_cancelled = True
|
tasks_cancelled = True
|
||||||
if loop and not loop.is_closed():
|
if loop and not loop.is_closed():
|
||||||
tasks_cancelled = await TaskManager.cancel_pending_tasks(loop)
|
tasks_cancelled = await TaskManager.cancel_pending_tasks(loop)
|
||||||
|
|
||||||
shutdown_time = time.time() - start_time
|
shutdown_time = time.time() - start_time
|
||||||
success = tasks_stopped and tasks_cancelled
|
success = tasks_stopped and tasks_cancelled
|
||||||
|
|
||||||
if success:
|
if success:
|
||||||
logger.info(f"麦麦优雅关闭完成,耗时: {shutdown_time:.2f}秒")
|
logger.info(f"麦麦优雅关闭完成,耗时: {shutdown_time:.2f}秒")
|
||||||
else:
|
else:
|
||||||
logger.warning(f"麦麦关闭完成,但部分操作未成功,耗时: {shutdown_time:.2f}秒")
|
logger.warning(f"麦麦关闭完成,但部分操作未成功,耗时: {shutdown_time:.2f}秒")
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -264,29 +262,29 @@ async def create_event_loop_context():
|
|||||||
|
|
||||||
class DatabaseManager:
|
class DatabaseManager:
|
||||||
"""数据库连接管理器"""
|
"""数据库连接管理器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._connection = None
|
self._connection = None
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
"""异步上下文管理器入口"""
|
"""异步上下文管理器入口"""
|
||||||
try:
|
try:
|
||||||
from src.common.database.database import initialize_sql_database
|
from src.common.database.database import initialize_sql_database
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
logger.info("正在初始化数据库连接...")
|
logger.info("正在初始化数据库连接...")
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
# 使用线程执行器运行潜在的阻塞操作
|
# 使用线程执行器运行潜在的阻塞操作
|
||||||
await asyncio.to_thread(initialize_sql_database, global_config.database)
|
await asyncio.to_thread(initialize_sql_database, global_config.database)
|
||||||
elapsed_time = time.time() - start_time
|
elapsed_time = time.time() - start_time
|
||||||
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒")
|
logger.info(f"数据库连接初始化成功,使用 {global_config.database.database_type} 数据库,耗时: {elapsed_time:.2f}秒")
|
||||||
|
|
||||||
return self
|
return self
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"数据库连接初始化失败: {e}")
|
logger.error(f"数据库连接初始化失败: {e}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||||
"""异步上下文管理器出口"""
|
"""异步上下文管理器出口"""
|
||||||
if exc_type:
|
if exc_type:
|
||||||
@@ -295,34 +293,34 @@ class DatabaseManager:
|
|||||||
|
|
||||||
class ConfigurationValidator:
|
class ConfigurationValidator:
|
||||||
"""配置验证器"""
|
"""配置验证器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def validate_configuration():
|
def validate_configuration():
|
||||||
"""验证关键配置"""
|
"""验证关键配置"""
|
||||||
try:
|
try:
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|
||||||
# 检查必要的配置节
|
# 检查必要的配置节
|
||||||
required_sections = ['database', 'bot']
|
required_sections = ["database", "bot"]
|
||||||
for section in required_sections:
|
for section in required_sections:
|
||||||
if not hasattr(global_config, section):
|
if not hasattr(global_config, section):
|
||||||
logger.error(f"配置中缺少{section}配置节")
|
logger.error(f"配置中缺少{section}配置节")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
# 验证数据库配置
|
# 验证数据库配置
|
||||||
db_config = global_config.database
|
db_config = global_config.database
|
||||||
if not hasattr(db_config, 'database_type') or not db_config.database_type:
|
if not hasattr(db_config, "database_type") or not db_config.database_type:
|
||||||
logger.error("数据库配置缺少database_type字段")
|
logger.error("数据库配置缺少database_type字段")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if db_config.database_type not in SUPPORTED_DATABASES:
|
if db_config.database_type not in SUPPORTED_DATABASES:
|
||||||
logger.error(f"不支持的数据库类型: {db_config.database_type}")
|
logger.error(f"不支持的数据库类型: {db_config.database_type}")
|
||||||
logger.info(f"支持的数据库类型: {', '.join(SUPPORTED_DATABASES)}")
|
logger.info(f"支持的数据库类型: {', '.join(SUPPORTED_DATABASES)}")
|
||||||
return False
|
return False
|
||||||
|
|
||||||
logger.info("配置验证通过")
|
logger.info("配置验证通过")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
except ImportError:
|
except ImportError:
|
||||||
logger.error("无法导入全局配置模块")
|
logger.error("无法导入全局配置模块")
|
||||||
return False
|
return False
|
||||||
@@ -332,16 +330,16 @@ class ConfigurationValidator:
|
|||||||
|
|
||||||
class EasterEgg:
|
class EasterEgg:
|
||||||
"""彩蛋功能"""
|
"""彩蛋功能"""
|
||||||
|
|
||||||
_initialized = False
|
_initialized = False
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def show(cls):
|
def show(cls):
|
||||||
"""显示彩色文本"""
|
"""显示彩色文本"""
|
||||||
if not cls._initialized:
|
if not cls._initialized:
|
||||||
init()
|
init()
|
||||||
cls._initialized = True
|
cls._initialized = True
|
||||||
|
|
||||||
text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
text = "多年以后,面对AI行刑队,张三将会回想起他2023年在会议上讨论人工智能的那个下午"
|
||||||
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
rainbow_colors = [Fore.RED, Fore.YELLOW, Fore.GREEN, Fore.CYAN, Fore.BLUE, Fore.MAGENTA]
|
||||||
rainbow_text = ""
|
rainbow_text = ""
|
||||||
@@ -394,27 +392,27 @@ class MaiBotMain:
|
|||||||
"""执行同步初始化步骤"""
|
"""执行同步初始化步骤"""
|
||||||
self.setup_timezone()
|
self.setup_timezone()
|
||||||
await EULAManager.check_eula()
|
await EULAManager.check_eula()
|
||||||
|
|
||||||
if not ConfigurationValidator.validate_configuration():
|
if not ConfigurationValidator.validate_configuration():
|
||||||
raise RuntimeError("配置验证失败,请检查配置文件")
|
raise RuntimeError("配置验证失败,请检查配置文件")
|
||||||
|
|
||||||
return self.create_main_system()
|
return self.create_main_system()
|
||||||
|
|
||||||
async def run_async_init(self, main_system):
|
async def run_async_init(self, main_system):
|
||||||
"""执行异步初始化步骤"""
|
"""执行异步初始化步骤"""
|
||||||
# 初始化数据库连接
|
# 初始化数据库连接
|
||||||
await self.initialize_database()
|
await self.initialize_database()
|
||||||
|
|
||||||
# 初始化数据库表结构
|
# 初始化数据库表结构
|
||||||
await self.initialize_database_async()
|
await self.initialize_database_async()
|
||||||
|
|
||||||
# 初始化主系统
|
# 初始化主系统
|
||||||
await main_system.initialize()
|
await main_system.initialize()
|
||||||
|
|
||||||
# 初始化知识库
|
# 初始化知识库
|
||||||
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
from src.chat.knowledge.knowledge_lib import initialize_lpmm_knowledge
|
||||||
initialize_lpmm_knowledge()
|
initialize_lpmm_knowledge()
|
||||||
|
|
||||||
# 显示彩蛋
|
# 显示彩蛋
|
||||||
EasterEgg.show()
|
EasterEgg.show()
|
||||||
|
|
||||||
@@ -422,7 +420,7 @@ async def wait_for_user_input():
|
|||||||
"""等待用户输入(异步方式)"""
|
"""等待用户输入(异步方式)"""
|
||||||
try:
|
try:
|
||||||
# 在非生产环境下,使用异步方式等待输入
|
# 在非生产环境下,使用异步方式等待输入
|
||||||
if os.getenv('ENVIRONMENT') != 'production':
|
if os.getenv("ENVIRONMENT") != "production":
|
||||||
logger.info("程序执行完成,按 Ctrl+C 退出...")
|
logger.info("程序执行完成,按 Ctrl+C 退出...")
|
||||||
# 简单的异步等待,避免阻塞事件循环
|
# 简单的异步等待,避免阻塞事件循环
|
||||||
while True:
|
while True:
|
||||||
@@ -438,30 +436,30 @@ async def main_async():
|
|||||||
"""主异步函数"""
|
"""主异步函数"""
|
||||||
exit_code = 0
|
exit_code = 0
|
||||||
main_task = None
|
main_task = None
|
||||||
|
|
||||||
async with create_event_loop_context() as loop:
|
async with create_event_loop_context() as loop:
|
||||||
try:
|
try:
|
||||||
# 确保环境文件存在
|
# 确保环境文件存在
|
||||||
ConfigManager.ensure_env_file()
|
ConfigManager.ensure_env_file()
|
||||||
|
|
||||||
# 创建主程序实例并执行初始化
|
# 创建主程序实例并执行初始化
|
||||||
maibot = MaiBotMain()
|
maibot = MaiBotMain()
|
||||||
main_system = await maibot.run_sync_init()
|
main_system = await maibot.run_sync_init()
|
||||||
await maibot.run_async_init(main_system)
|
await maibot.run_async_init(main_system)
|
||||||
|
|
||||||
# 运行主任务
|
# 运行主任务
|
||||||
main_task = asyncio.create_task(main_system.schedule_tasks())
|
main_task = asyncio.create_task(main_system.schedule_tasks())
|
||||||
logger.info("麦麦机器人启动完成,开始运行主任务...")
|
logger.info("麦麦机器人启动完成,开始运行主任务...")
|
||||||
|
|
||||||
# 同时运行主任务和用户输入等待
|
# 同时运行主任务和用户输入等待
|
||||||
user_input_done = asyncio.create_task(wait_for_user_input())
|
user_input_done = asyncio.create_task(wait_for_user_input())
|
||||||
|
|
||||||
# 使用wait等待任意一个任务完成
|
# 使用wait等待任意一个任务完成
|
||||||
done, pending = await asyncio.wait(
|
done, pending = await asyncio.wait(
|
||||||
[main_task, user_input_done],
|
[main_task, user_input_done],
|
||||||
return_when=asyncio.FIRST_COMPLETED
|
return_when=asyncio.FIRST_COMPLETED
|
||||||
)
|
)
|
||||||
|
|
||||||
# 如果用户输入任务完成(用户按了Ctrl+C),取消主任务
|
# 如果用户输入任务完成(用户按了Ctrl+C),取消主任务
|
||||||
if user_input_done in done and main_task not in done:
|
if user_input_done in done and main_task not in done:
|
||||||
logger.info("用户请求退出,正在取消主任务...")
|
logger.info("用户请求退出,正在取消主任务...")
|
||||||
@@ -472,7 +470,7 @@ async def main_async():
|
|||||||
logger.info("主任务已取消")
|
logger.info("主任务已取消")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"主任务取消时发生错误: {e}")
|
logger.error(f"主任务取消时发生错误: {e}")
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.warning("收到中断信号,正在优雅关闭...")
|
logger.warning("收到中断信号,正在优雅关闭...")
|
||||||
if main_task and not main_task.done():
|
if main_task and not main_task.done():
|
||||||
@@ -481,7 +479,7 @@ async def main_async():
|
|||||||
logger.error(f"主程序发生异常: {e}")
|
logger.error(f"主程序发生异常: {e}")
|
||||||
logger.debug(f"异常详情: {traceback.format_exc()}")
|
logger.debug(f"异常详情: {traceback.format_exc()}")
|
||||||
exit_code = 1
|
exit_code = 1
|
||||||
|
|
||||||
return exit_code
|
return exit_code
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
@@ -500,5 +498,5 @@ if __name__ == "__main__":
|
|||||||
shutdown_logging()
|
shutdown_logging()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"关闭日志系统时出错: {e}")
|
print(f"关闭日志系统时出错: {e}")
|
||||||
|
|
||||||
sys.exit(exit_code)
|
sys.exit(exit_code)
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from .memory_chunk import MemoryChunk as Memory
|
|||||||
|
|
||||||
# 遗忘引擎
|
# 遗忘引擎
|
||||||
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
|
from .memory_forgetting_engine import ForgettingConfig, MemoryForgettingEngine, get_memory_forgetting_engine
|
||||||
|
from .memory_formatter import format_memories_bracket_style
|
||||||
|
|
||||||
# 记忆管理器
|
# 记忆管理器
|
||||||
from .memory_manager import MemoryManager, MemoryResult, memory_manager
|
from .memory_manager import MemoryManager, MemoryResult, memory_manager
|
||||||
@@ -30,7 +31,6 @@ from .memory_system import MemorySystem, MemorySystemConfig, get_memory_system,
|
|||||||
|
|
||||||
# Vector DB存储系统
|
# Vector DB存储系统
|
||||||
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
from .vector_memory_storage_v2 import VectorMemoryStorage, VectorStorageConfig, get_vector_memory_storage
|
||||||
from .memory_formatter import format_memories_bracket_style
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# 核心数据结构
|
# 核心数据结构
|
||||||
|
|||||||
@@ -17,8 +17,9 @@
|
|||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, Iterable
|
|
||||||
import time
|
import time
|
||||||
|
from collections.abc import Iterable
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def _format_timestamp(ts: Any) -> str:
|
def _format_timestamp(ts: Any) -> str:
|
||||||
|
|||||||
@@ -2,9 +2,8 @@
|
|||||||
记忆元数据索引。
|
记忆元数据索引。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, asdict
|
from dataclasses import asdict, dataclass
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from time import time
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
@@ -12,6 +11,7 @@ logger = get_logger(__name__)
|
|||||||
|
|
||||||
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
|
from inkfox.memory import PyMetadataIndex as _RustIndex # type: ignore
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MemoryMetadataIndexEntry:
|
class MemoryMetadataIndexEntry:
|
||||||
memory_id: str
|
memory_id: str
|
||||||
@@ -51,7 +51,7 @@ class MemoryMetadataIndex:
|
|||||||
if payload:
|
if payload:
|
||||||
try:
|
try:
|
||||||
self._rust.batch_add(payload)
|
self._rust.batch_add(payload)
|
||||||
except Exception as ex: # noqa: BLE001
|
except Exception as ex:
|
||||||
logger.error(f"Rust 元数据批量添加失败: {ex}")
|
logger.error(f"Rust 元数据批量添加失败: {ex}")
|
||||||
|
|
||||||
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
def add_or_update(self, entry: MemoryMetadataIndexEntry):
|
||||||
@@ -88,7 +88,7 @@ class MemoryMetadataIndex:
|
|||||||
if flexible_mode:
|
if flexible_mode:
|
||||||
return list(self._rust.search_flexible(params))
|
return list(self._rust.search_flexible(params))
|
||||||
return list(self._rust.search_strict(params))
|
return list(self._rust.search_strict(params))
|
||||||
except Exception as ex: # noqa: BLE001
|
except Exception as ex:
|
||||||
logger.error(f"Rust 搜索失败返回空: {ex}")
|
logger.error(f"Rust 搜索失败返回空: {ex}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
@@ -105,18 +105,18 @@ class MemoryMetadataIndex:
|
|||||||
"keywords_count": raw.get("keywords_indexed", 0),
|
"keywords_count": raw.get("keywords_indexed", 0),
|
||||||
"tags_count": raw.get("tags_indexed", 0),
|
"tags_count": raw.get("tags_indexed", 0),
|
||||||
}
|
}
|
||||||
except Exception as ex: # noqa: BLE001
|
except Exception as ex:
|
||||||
logger.warning(f"读取 Rust stats 失败: {ex}")
|
logger.warning(f"读取 Rust stats 失败: {ex}")
|
||||||
return {"total_memories": 0}
|
return {"total_memories": 0}
|
||||||
|
|
||||||
def save(self): # 仅调用 rust save
|
def save(self): # 仅调用 rust save
|
||||||
try:
|
try:
|
||||||
self._rust.save()
|
self._rust.save()
|
||||||
except Exception as ex: # noqa: BLE001
|
except Exception as ex:
|
||||||
logger.warning(f"Rust save 失败: {ex}")
|
logger.warning(f"Rust save 失败: {ex}")
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MemoryMetadataIndexEntry",
|
|
||||||
"MemoryMetadataIndex",
|
"MemoryMetadataIndex",
|
||||||
|
"MemoryMetadataIndexEntry",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ class MessageRecv(Message):
|
|||||||
logger.warning("视频消息中没有base64数据")
|
logger.warning("视频消息中没有base64数据")
|
||||||
return "[收到视频消息,但数据异常]"
|
return "[收到视频消息,但数据异常]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"视频处理失败: {str(e)}")
|
logger.error(f"视频处理失败: {e!s}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||||
@@ -277,7 +277,7 @@ class MessageRecv(Message):
|
|||||||
logger.info("未启用视频识别")
|
logger.info("未启用视频识别")
|
||||||
return "[视频]"
|
return "[视频]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||||
return f"[处理失败的{segment.type}消息]"
|
return f"[处理失败的{segment.type}消息]"
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
"""纯 inkfox 视频关键帧分析工具
|
"""纯 inkfox 视频关键帧分析工具
|
||||||
|
|
||||||
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
仅依赖 `inkfox.video` 提供的 Rust 扩展能力:
|
||||||
@@ -14,25 +13,25 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
|
||||||
import io
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
import tempfile
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import List, Tuple, Optional, Dict, Any
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import io
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
|
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config, model_config
|
from src.config.config import global_config, model_config
|
||||||
from src.llm_models.utils_model import LLMRequest
|
from src.llm_models.utils_model import LLMRequest
|
||||||
from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore
|
|
||||||
|
|
||||||
# 简易并发控制:同一 hash 只处理一次
|
# 简易并发控制:同一 hash 只处理一次
|
||||||
_video_locks: Dict[str, asyncio.Lock] = {}
|
_video_locks: dict[str, asyncio.Lock] = {}
|
||||||
_locks_guard = asyncio.Lock()
|
_locks_guard = asyncio.Lock()
|
||||||
|
|
||||||
logger = get_logger("utils_video")
|
logger = get_logger("utils_video")
|
||||||
@@ -90,7 +89,7 @@ class VideoAnalyzer:
|
|||||||
logger.debug(f"获取系统信息失败: {e}")
|
logger.debug(f"获取系统信息失败: {e}")
|
||||||
|
|
||||||
# ---- 关键帧提取 ----
|
# ---- 关键帧提取 ----
|
||||||
async def extract_keyframes(self, video_path: str) -> List[Tuple[str, float]]:
|
async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]:
|
||||||
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
|
"""提取关键帧并返回 (base64, timestamp_seconds) 列表"""
|
||||||
with tempfile.TemporaryDirectory() as tmp:
|
with tempfile.TemporaryDirectory() as tmp:
|
||||||
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
|
result = video.extract_keyframes_from_video( # type: ignore[attr-defined]
|
||||||
@@ -105,7 +104,7 @@ class VideoAnalyzer:
|
|||||||
)
|
)
|
||||||
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames]
|
||||||
total_ms = getattr(result, "total_time_ms", 0)
|
total_ms = getattr(result, "total_time_ms", 0)
|
||||||
frames: List[Tuple[str, float]] = []
|
frames: list[tuple[str, float]] = []
|
||||||
for i, f in enumerate(files):
|
for i, f in enumerate(files):
|
||||||
img = Image.open(f).convert("RGB")
|
img = Image.open(f).convert("RGB")
|
||||||
if max(img.size) > self.max_image_size:
|
if max(img.size) > self.max_image_size:
|
||||||
@@ -119,7 +118,7 @@ class VideoAnalyzer:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
# ---- 批量分析 ----
|
# ---- 批量分析 ----
|
||||||
async def _analyze_batch(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||||
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
from src.llm_models.payload_content.message import MessageBuilder, RoleType
|
||||||
from src.llm_models.utils_model import RequestType
|
from src.llm_models.utils_model import RequestType
|
||||||
prompt = self.batch_analysis_prompt.format(
|
prompt = self.batch_analysis_prompt.format(
|
||||||
@@ -149,8 +148,8 @@ class VideoAnalyzer:
|
|||||||
return resp.content or "❌ 未获得响应"
|
return resp.content or "❌ 未获得响应"
|
||||||
|
|
||||||
# ---- 逐帧分析 ----
|
# ---- 逐帧分析 ----
|
||||||
async def _analyze_sequential(self, frames: List[Tuple[str, float]], question: Optional[str]) -> str:
|
async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
|
||||||
results: List[str] = []
|
results: list[str] = []
|
||||||
for i, (b64, ts) in enumerate(frames):
|
for i, (b64, ts) in enumerate(frames):
|
||||||
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
|
||||||
if question:
|
if question:
|
||||||
@@ -174,7 +173,7 @@ class VideoAnalyzer:
|
|||||||
return "\n".join(results)
|
return "\n".join(results)
|
||||||
|
|
||||||
# ---- 主入口 ----
|
# ---- 主入口 ----
|
||||||
async def analyze_video(self, video_path: str, question: Optional[str] = None) -> Tuple[bool, str]:
|
async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
|
||||||
if not os.path.exists(video_path):
|
if not os.path.exists(video_path):
|
||||||
return False, "❌ 文件不存在"
|
return False, "❌ 文件不存在"
|
||||||
frames = await self.extract_keyframes(video_path)
|
frames = await self.extract_keyframes(video_path)
|
||||||
@@ -189,10 +188,10 @@ class VideoAnalyzer:
|
|||||||
async def analyze_video_from_bytes(
|
async def analyze_video_from_bytes(
|
||||||
self,
|
self,
|
||||||
video_bytes: bytes,
|
video_bytes: bytes,
|
||||||
filename: Optional[str] = None,
|
filename: str | None = None,
|
||||||
prompt: Optional[str] = None,
|
prompt: str | None = None,
|
||||||
question: Optional[str] = None,
|
question: str | None = None,
|
||||||
) -> Dict[str, str]:
|
) -> dict[str, str]:
|
||||||
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
|
"""从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
|
||||||
if not video_bytes:
|
if not video_bytes:
|
||||||
return {"summary": "❌ 空视频数据"}
|
return {"summary": "❌ 空视频数据"}
|
||||||
@@ -271,7 +270,7 @@ class VideoAnalyzer:
|
|||||||
|
|
||||||
|
|
||||||
# ---- 外部接口 ----
|
# ---- 外部接口 ----
|
||||||
_INSTANCE: Optional[VideoAnalyzer] = None
|
_INSTANCE: VideoAnalyzer | None = None
|
||||||
|
|
||||||
|
|
||||||
def get_video_analyzer() -> VideoAnalyzer:
|
def get_video_analyzer() -> VideoAnalyzer:
|
||||||
@@ -285,7 +284,7 @@ def is_video_analysis_available() -> bool:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
def get_video_analysis_status() -> Dict[str, Any]:
|
def get_video_analysis_status() -> dict[str, Any]:
|
||||||
try:
|
try:
|
||||||
info = video.get_system_info() # type: ignore[attr-defined]
|
info = video.get_system_info() # type: ignore[attr-defined]
|
||||||
except Exception as e: # pragma: no cover
|
except Exception as e: # pragma: no cover
|
||||||
@@ -297,4 +296,4 @@ def get_video_analysis_status() -> Dict[str, Any]:
|
|||||||
"modes": ["auto", "batch", "sequential"],
|
"modes": ["auto", "batch", "sequential"],
|
||||||
"max_frames_default": inst.max_frames,
|
"max_frames_default": inst.max_frames,
|
||||||
"implementation": "inkfox",
|
"implementation": "inkfox",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,8 +53,8 @@ class StreamContext(BaseDataModel):
|
|||||||
priority_mode: str | None = None
|
priority_mode: str | None = None
|
||||||
priority_info: dict | None = None
|
priority_info: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def add_action_to_message(self, message_id: str, action: str):
|
def add_action_to_message(self, message_id: str, action: str):
|
||||||
"""
|
"""
|
||||||
向指定消息添加执行的动作
|
向指定消息添加执行的动作
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ MCP (Model Context Protocol) SSE (Server-Sent Events) 客户端实现
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import io
|
import io
|
||||||
import json
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -20,7 +19,6 @@ from ..exceptions import (
|
|||||||
NetworkConnectionError,
|
NetworkConnectionError,
|
||||||
ReqAbortException,
|
ReqAbortException,
|
||||||
RespNotOkException,
|
RespNotOkException,
|
||||||
RespParseException,
|
|
||||||
)
|
)
|
||||||
from ..payload_content.message import Message, RoleType
|
from ..payload_content.message import Message, RoleType
|
||||||
from ..payload_content.resp_format import RespFormat
|
from ..payload_content.resp_format import RespFormat
|
||||||
|
|||||||
68
src/main.py
68
src/main.py
@@ -6,7 +6,7 @@ import time
|
|||||||
import traceback
|
import traceback
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from random import choices
|
from random import choices
|
||||||
from typing import Any, List, Tuple
|
from typing import Any
|
||||||
|
|
||||||
from maim_message import MessageServer
|
from maim_message import MessageServer
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
@@ -36,7 +36,7 @@ install(extra_lines=3)
|
|||||||
logger = get_logger("main")
|
logger = get_logger("main")
|
||||||
|
|
||||||
# 预定义彩蛋短语,避免在每次初始化时重新创建
|
# 预定义彩蛋短语,避免在每次初始化时重新创建
|
||||||
EGG_PHRASES: List[Tuple[str, int]] = [
|
EGG_PHRASES: list[tuple[str, int]] = [
|
||||||
("我们的代码里真的没有bug,只有'特性'。", 10),
|
("我们的代码里真的没有bug,只有'特性'。", 10),
|
||||||
("你知道吗?阿范喜欢被切成臊子😡", 10),
|
("你知道吗?阿范喜欢被切成臊子😡", 10),
|
||||||
("你知道吗,雅诺狐的耳朵其实很好摸", 5),
|
("你知道吗,雅诺狐的耳朵其实很好摸", 5),
|
||||||
@@ -69,22 +69,22 @@ def _task_done_callback(task: asyncio.Task, message_id: str, start_time: float)
|
|||||||
|
|
||||||
class MainSystem:
|
class MainSystem:
|
||||||
"""主系统类,负责协调所有组件"""
|
"""主系统类,负责协调所有组件"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
# 使用增强记忆系统
|
# 使用增强记忆系统
|
||||||
self.memory_manager = memory_manager
|
self.memory_manager = memory_manager
|
||||||
self.individuality: Individuality = get_individuality()
|
self.individuality: Individuality = get_individuality()
|
||||||
|
|
||||||
# 使用消息API替代直接的FastAPI实例
|
# 使用消息API替代直接的FastAPI实例
|
||||||
self.app: MessageServer = get_global_api()
|
self.app: MessageServer = get_global_api()
|
||||||
self.server: Server = get_global_server()
|
self.server: Server = get_global_server()
|
||||||
|
|
||||||
# 设置信号处理器用于优雅退出
|
# 设置信号处理器用于优雅退出
|
||||||
self._shutting_down = False
|
self._shutting_down = False
|
||||||
self._setup_signal_handlers()
|
self._setup_signal_handlers()
|
||||||
|
|
||||||
# 存储清理任务的引用
|
# 存储清理任务的引用
|
||||||
self._cleanup_tasks: List[asyncio.Task] = []
|
self._cleanup_tasks: list[asyncio.Task] = []
|
||||||
|
|
||||||
def _setup_signal_handlers(self) -> None:
|
def _setup_signal_handlers(self) -> None:
|
||||||
"""设置信号处理器"""
|
"""设置信号处理器"""
|
||||||
@@ -92,7 +92,7 @@ class MainSystem:
|
|||||||
if self._shutting_down:
|
if self._shutting_down:
|
||||||
logger.warning("系统已经在关闭过程中,忽略重复信号")
|
logger.warning("系统已经在关闭过程中,忽略重复信号")
|
||||||
return
|
return
|
||||||
|
|
||||||
self._shutting_down = True
|
self._shutting_down = True
|
||||||
logger.info("收到退出信号,正在优雅关闭系统...")
|
logger.info("收到退出信号,正在优雅关闭系统...")
|
||||||
|
|
||||||
@@ -148,7 +148,7 @@ class MainSystem:
|
|||||||
|
|
||||||
# 尝试注册所有可用的计算器
|
# 尝试注册所有可用的计算器
|
||||||
registered_calculators = []
|
registered_calculators = []
|
||||||
|
|
||||||
for calc_name, calc_info in interest_calculators.items():
|
for calc_name, calc_info in interest_calculators.items():
|
||||||
enabled = getattr(calc_info, "enabled", True)
|
enabled = getattr(calc_info, "enabled", True)
|
||||||
default_enabled = getattr(calc_info, "enabled_by_default", True)
|
default_enabled = getattr(calc_info, "enabled_by_default", True)
|
||||||
@@ -169,7 +169,7 @@ class MainSystem:
|
|||||||
|
|
||||||
# 创建组件实例
|
# 创建组件实例
|
||||||
calculator_instance = component_class()
|
calculator_instance = component_class()
|
||||||
|
|
||||||
# 初始化组件
|
# 初始化组件
|
||||||
if not await calculator_instance.initialize():
|
if not await calculator_instance.initialize():
|
||||||
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
|
logger.error(f"兴趣计算器 {calc_name} 初始化失败")
|
||||||
@@ -199,12 +199,12 @@ class MainSystem:
|
|||||||
"""异步清理资源"""
|
"""异步清理资源"""
|
||||||
if self._shutting_down:
|
if self._shutting_down:
|
||||||
return
|
return
|
||||||
|
|
||||||
self._shutting_down = True
|
self._shutting_down = True
|
||||||
logger.info("开始系统清理流程...")
|
logger.info("开始系统清理流程...")
|
||||||
|
|
||||||
cleanup_tasks = []
|
cleanup_tasks = []
|
||||||
|
|
||||||
# 停止数据库服务
|
# 停止数据库服务
|
||||||
try:
|
try:
|
||||||
from src.common.database.database import stop_database
|
from src.common.database.database import stop_database
|
||||||
@@ -236,14 +236,14 @@ class MainSystem:
|
|||||||
# 触发停止事件
|
# 触发停止事件
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.core.event_manager import event_manager
|
from src.plugin_system.core.event_manager import event_manager
|
||||||
cleanup_tasks.append(("插件系统停止事件",
|
cleanup_tasks.append(("插件系统停止事件",
|
||||||
event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")))
|
event_manager.trigger_event(EventType.ON_STOP, permission_group="SYSTEM")))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"准备触发停止事件时出错: {e}")
|
logger.error(f"准备触发停止事件时出错: {e}")
|
||||||
|
|
||||||
# 停止表情管理器
|
# 停止表情管理器
|
||||||
try:
|
try:
|
||||||
cleanup_tasks.append(("表情管理器",
|
cleanup_tasks.append(("表情管理器",
|
||||||
asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)))
|
asyncio.get_event_loop().run_in_executor(None, get_emoji_manager().shutdown)))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"准备停止表情管理器时出错: {e}")
|
logger.error(f"准备停止表情管理器时出错: {e}")
|
||||||
@@ -270,21 +270,21 @@ class MainSystem:
|
|||||||
logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
|
logger.info(f"开始并行执行 {len(cleanup_tasks)} 个清理任务...")
|
||||||
tasks = [task for _, task in cleanup_tasks]
|
tasks = [task for _, task in cleanup_tasks]
|
||||||
task_names = [name for name, _ in cleanup_tasks]
|
task_names = [name for name, _ in cleanup_tasks]
|
||||||
|
|
||||||
# 使用asyncio.gather并行执行,设置超时防止卡死
|
# 使用asyncio.gather并行执行,设置超时防止卡死
|
||||||
try:
|
try:
|
||||||
results = await asyncio.wait_for(
|
results = await asyncio.wait_for(
|
||||||
asyncio.gather(*tasks, return_exceptions=True),
|
asyncio.gather(*tasks, return_exceptions=True),
|
||||||
timeout=30.0 # 30秒超时
|
timeout=30.0 # 30秒超时
|
||||||
)
|
)
|
||||||
|
|
||||||
# 记录结果
|
# 记录结果
|
||||||
for i, (name, result) in enumerate(zip(task_names, results)):
|
for i, (name, result) in enumerate(zip(task_names, results)):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"停止 {name} 时出错: {result}")
|
logger.error(f"停止 {name} 时出错: {result}")
|
||||||
else:
|
else:
|
||||||
logger.info(f"🛑 {name} 已停止")
|
logger.info(f"🛑 {name} 已停止")
|
||||||
|
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
logger.error("清理任务超时,强制退出")
|
logger.error("清理任务超时,强制退出")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -311,16 +311,16 @@ class MainSystem:
|
|||||||
try:
|
try:
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
message_id = message_data.get("message_info", {}).get("message_id", "UNKNOWN")
|
||||||
|
|
||||||
# 检查系统是否正在关闭
|
# 检查系统是否正在关闭
|
||||||
if self._shutting_down:
|
if self._shutting_down:
|
||||||
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
logger.warning(f"系统正在关闭,拒绝处理消息 {message_id}")
|
||||||
return
|
return
|
||||||
|
|
||||||
# 创建后台任务
|
# 创建后台任务
|
||||||
task = asyncio.create_task(chat_bot.message_process(message_data))
|
task = asyncio.create_task(chat_bot.message_process(message_data))
|
||||||
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
logger.debug(f"已为消息 {message_id} 创建后台处理任务 (ID: {id(task)})")
|
||||||
|
|
||||||
# 添加一个回调函数,当任务完成时,它会被调用
|
# 添加一个回调函数,当任务完成时,它会被调用
|
||||||
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
|
task.add_done_callback(partial(_task_done_callback, message_id=message_id, start_time=start_time))
|
||||||
except Exception:
|
except Exception:
|
||||||
@@ -330,19 +330,19 @@ class MainSystem:
|
|||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
# 检查必要的配置
|
# 检查必要的配置
|
||||||
if not hasattr(global_config, 'bot') or not hasattr(global_config.bot, 'nickname'):
|
if not hasattr(global_config, "bot") or not hasattr(global_config.bot, "nickname"):
|
||||||
logger.error("缺少必要的bot配置")
|
logger.error("缺少必要的bot配置")
|
||||||
raise ValueError("Bot配置不完整")
|
raise ValueError("Bot配置不完整")
|
||||||
|
|
||||||
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
logger.info(f"正在唤醒{global_config.bot.nickname}......")
|
||||||
|
|
||||||
# 初始化组件
|
# 初始化组件
|
||||||
await self._init_components()
|
await self._init_components()
|
||||||
|
|
||||||
# 随机选择彩蛋
|
# 随机选择彩蛋
|
||||||
egg_texts, weights = zip(*EGG_PHRASES)
|
egg_texts, weights = zip(*EGG_PHRASES)
|
||||||
selected_egg = choices(egg_texts, weights=weights, k=1)[0]
|
selected_egg = choices(egg_texts, weights=weights, k=1)[0]
|
||||||
|
|
||||||
logger.info(f"""
|
logger.info(f"""
|
||||||
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
|
全部系统初始化完成,{global_config.bot.nickname}已成功唤醒
|
||||||
=========================================================
|
=========================================================
|
||||||
@@ -367,7 +367,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
async_task_manager.add_task(StatisticOutputTask()),
|
async_task_manager.add_task(StatisticOutputTask()),
|
||||||
async_task_manager.add_task(TelemetryHeartBeatTask()),
|
async_task_manager.add_task(TelemetryHeartBeatTask()),
|
||||||
]
|
]
|
||||||
|
|
||||||
await asyncio.gather(*base_init_tasks, return_exceptions=True)
|
await asyncio.gather(*base_init_tasks, return_exceptions=True)
|
||||||
logger.info("基础定时任务初始化成功")
|
logger.info("基础定时任务初始化成功")
|
||||||
|
|
||||||
@@ -399,7 +399,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
|
|
||||||
# 处理所有缓存的事件订阅(插件加载完成后)
|
# 处理所有缓存的事件订阅(插件加载完成后)
|
||||||
event_manager.process_all_pending_subscriptions()
|
event_manager.process_all_pending_subscriptions()
|
||||||
|
|
||||||
# 初始化MCP工具提供器
|
# 初始化MCP工具提供器
|
||||||
try:
|
try:
|
||||||
mcp_config = global_config.get("mcp_servers", [])
|
mcp_config = global_config.get("mcp_servers", [])
|
||||||
@@ -412,24 +412,24 @@ MoFox_Bot(第三方修改版)
|
|||||||
|
|
||||||
# 并行初始化其他管理器
|
# 并行初始化其他管理器
|
||||||
manager_init_tasks = []
|
manager_init_tasks = []
|
||||||
|
|
||||||
# 表情管理器
|
# 表情管理器
|
||||||
manager_init_tasks.append(self._safe_init("表情包管理器", get_emoji_manager().initialize))
|
manager_init_tasks.append(self._safe_init("表情包管理器", get_emoji_manager().initialize))
|
||||||
|
|
||||||
# 情绪管理器
|
# 情绪管理器
|
||||||
manager_init_tasks.append(self._safe_init("情绪管理器", mood_manager.start))
|
manager_init_tasks.append(self._safe_init("情绪管理器", mood_manager.start))
|
||||||
|
|
||||||
# 聊天管理器
|
# 聊天管理器
|
||||||
manager_init_tasks.append(self._safe_init("聊天管理器", get_chat_manager()._initialize))
|
manager_init_tasks.append(self._safe_init("聊天管理器", get_chat_manager()._initialize))
|
||||||
|
|
||||||
# 等待所有管理器初始化完成
|
# 等待所有管理器初始化完成
|
||||||
results = await asyncio.gather(*manager_init_tasks, return_exceptions=True)
|
results = await asyncio.gather(*manager_init_tasks, return_exceptions=True)
|
||||||
|
|
||||||
# 检查初始化结果
|
# 检查初始化结果
|
||||||
for i, result in enumerate(results):
|
for i, result in enumerate(results):
|
||||||
if isinstance(result, Exception):
|
if isinstance(result, Exception):
|
||||||
logger.error(f"组件初始化失败: {result}")
|
logger.error(f"组件初始化失败: {result}")
|
||||||
|
|
||||||
# 启动聊天管理器的自动保存任务
|
# 启动聊天管理器的自动保存任务
|
||||||
asyncio.create_task(get_chat_manager()._auto_save_task())
|
asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
|
|
||||||
@@ -558,7 +558,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
"""关闭系统组件"""
|
"""关闭系统组件"""
|
||||||
if self._shutting_down:
|
if self._shutting_down:
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("正在关闭MainSystem...")
|
logger.info("正在关闭MainSystem...")
|
||||||
await self._async_cleanup()
|
await self._async_cleanup()
|
||||||
logger.info("MainSystem关闭完成")
|
logger.info("MainSystem关闭完成")
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ from src.common.logger import get_logger
|
|||||||
from src.plugin_system.base.component_types import ActionInfo
|
from src.plugin_system.base.component_types import ActionInfo
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
pass
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
|
|||||||
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
tool_class: type[BaseTool] = component_registry.get_component_class(tool_name, ComponentType.TOOL) # type: ignore
|
||||||
if tool_class:
|
if tool_class:
|
||||||
return tool_class(plugin_config)
|
return tool_class(plugin_config)
|
||||||
|
|
||||||
# 如果不是常规工具,检查是否是MCP工具
|
# 如果不是常规工具,检查是否是MCP工具
|
||||||
# MCP工具不需要返回实例,会在execute_tool_call中特殊处理
|
# MCP工具不需要返回实例,会在execute_tool_call中特殊处理
|
||||||
return None
|
return None
|
||||||
@@ -35,7 +35,7 @@ def get_llm_available_tool_definitions():
|
|||||||
|
|
||||||
llm_available_tools = component_registry.get_llm_available_tools()
|
llm_available_tools = component_registry.get_llm_available_tools()
|
||||||
tool_definitions = [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
tool_definitions = [(name, tool_class.get_tool_definition()) for name, tool_class in llm_available_tools.items()]
|
||||||
|
|
||||||
# 添加MCP工具
|
# 添加MCP工具
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
||||||
@@ -45,5 +45,5 @@ def get_llm_available_tool_definitions():
|
|||||||
logger.debug(f"已添加 {len(mcp_tools)} 个MCP工具到可用工具列表")
|
logger.debug(f"已添加 {len(mcp_tools)} 个MCP工具到可用工具列表")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"获取MCP工具失败(可能未配置): {e}")
|
logger.debug(f"获取MCP工具失败(可能未配置): {e}")
|
||||||
|
|
||||||
return tool_definitions
|
return tool_definitions
|
||||||
|
|||||||
@@ -279,7 +279,7 @@ class ToolExecutor:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 检查是否是MCP工具
|
# 检查是否是MCP工具
|
||||||
try:
|
try:
|
||||||
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
from src.plugin_system.utils.mcp_tool_provider import mcp_tool_provider
|
||||||
@@ -295,7 +295,7 @@ class ToolExecutor:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"检查MCP工具时出错: {e}")
|
logger.debug(f"检查MCP工具时出错: {e}")
|
||||||
|
|
||||||
function_args["llm_called"] = True # 标记为LLM调用
|
function_args["llm_called"] = True # 标记为LLM调用
|
||||||
|
|
||||||
# 检查是否是二步工具的第二步调用
|
# 检查是否是二步工具的第二步调用
|
||||||
|
|||||||
@@ -3,11 +3,9 @@ MCP (Model Context Protocol) 连接器
|
|||||||
负责连接MCP服务器,获取和执行工具
|
负责连接MCP服务器,获取和执行工具
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import orjson
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ MCP工具提供器 - 简化版
|
|||||||
直接集成到工具系统,无需复杂的插件架构
|
直接集成到工具系统,无需复杂的插件架构
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import asyncio
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|||||||
@@ -4,9 +4,10 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
import orjson
|
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import orjson
|
||||||
|
|
||||||
from src.chat.interest_system import bot_interest_manager
|
from src.chat.interest_system import bot_interest_manager
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
|
|||||||
@@ -230,11 +230,11 @@ class ChatterPlanExecutor:
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_message = str(e)
|
error_message = str(e)
|
||||||
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
logger.error(f"执行回复动作失败: {action_info.action_type}, 错误: {error_message}")
|
||||||
'''
|
"""
|
||||||
# 记录用户关系追踪
|
# 记录用户关系追踪
|
||||||
if success and action_info.action_message:
|
if success and action_info.action_message:
|
||||||
await self._track_user_interaction(action_info, plan, reply_content)
|
await self._track_user_interaction(action_info, plan, reply_content)
|
||||||
'''
|
"""
|
||||||
execution_time = time.time() - start_time
|
execution_time = time.time() - start_time
|
||||||
self.execution_stats["execution_times"].append(execution_time)
|
self.execution_stats["execution_times"].append(execution_time)
|
||||||
|
|
||||||
|
|||||||
@@ -10,10 +10,10 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.mood.mood_manager import mood_manager
|
from src.mood.mood_manager import mood_manager
|
||||||
|
from src.plugin_system.base.component_types import ChatMode
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
from src.plugins.built_in.affinity_flow_chatter.plan_executor import ChatterPlanExecutor
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
from src.plugins.built_in.affinity_flow_chatter.plan_filter import ChatterPlanFilter
|
||||||
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
from src.plugins.built_in.affinity_flow_chatter.plan_generator import ChatterPlanGenerator
|
||||||
from src.plugin_system.base.component_types import ChatMode
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
|
|||||||
@@ -6,9 +6,7 @@ SearXNG search engine implementation
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from typing import Any
|
||||||
import functools
|
|
||||||
from typing import Any, List
|
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
@@ -39,13 +37,13 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
|||||||
instances = config_api.get_global_config("web_search.searxng_instances", None)
|
instances = config_api.get_global_config("web_search.searxng_instances", None)
|
||||||
if isinstance(instances, list):
|
if isinstance(instances, list):
|
||||||
# 过滤空值
|
# 过滤空值
|
||||||
self.instances: List[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
|
self.instances: list[str] = [u.rstrip("/") for u in instances if isinstance(u, str) and u.strip()]
|
||||||
else:
|
else:
|
||||||
self.instances = []
|
self.instances = []
|
||||||
|
|
||||||
api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
|
api_keys = config_api.get_global_config("web_search.searxng_api_keys", None)
|
||||||
if isinstance(api_keys, list):
|
if isinstance(api_keys, list):
|
||||||
self.api_keys: List[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
|
self.api_keys: list[str | None] = [k.strip() if isinstance(k, str) and k.strip() else None for k in api_keys]
|
||||||
else:
|
else:
|
||||||
self.api_keys = []
|
self.api_keys = []
|
||||||
|
|
||||||
@@ -85,7 +83,7 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
|||||||
results.extend(instance_results)
|
results.extend(instance_results)
|
||||||
if len(results) >= num_results:
|
if len(results) >= num_results:
|
||||||
break
|
break
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
logger.warning(f"SearXNG 实例 {base_url} 调用失败: {e}")
|
logger.warning(f"SearXNG 实例 {base_url} 调用失败: {e}")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -116,12 +114,12 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
|||||||
try:
|
try:
|
||||||
resp = await self._client.get(url, params=params, headers=headers)
|
resp = await self._client.get(url, params=params, headers=headers)
|
||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
raise RuntimeError(f"请求失败: {e}") from e
|
raise RuntimeError(f"请求失败: {e}") from e
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
except Exception as e: # noqa: BLE001
|
except Exception as e:
|
||||||
raise RuntimeError(f"解析 JSON 失败: {e}") from e
|
raise RuntimeError(f"解析 JSON 失败: {e}") from e
|
||||||
|
|
||||||
raw_results = data.get("results", []) if isinstance(data, dict) else []
|
raw_results = data.get("results", []) if isinstance(data, dict) else []
|
||||||
@@ -141,5 +139,5 @@ class SearXNGSearchEngine(BaseSearchEngine):
|
|||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb): # noqa: D401
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
await self._client.aclose()
|
await self._client.aclose()
|
||||||
|
|||||||
@@ -41,8 +41,8 @@ class WEBSEARCHPLUGIN(BasePlugin):
|
|||||||
from .engines.bing_engine import BingSearchEngine
|
from .engines.bing_engine import BingSearchEngine
|
||||||
from .engines.ddg_engine import DDGSearchEngine
|
from .engines.ddg_engine import DDGSearchEngine
|
||||||
from .engines.exa_engine import ExaSearchEngine
|
from .engines.exa_engine import ExaSearchEngine
|
||||||
from .engines.tavily_engine import TavilySearchEngine
|
|
||||||
from .engines.searxng_engine import SearXNGSearchEngine
|
from .engines.searxng_engine import SearXNGSearchEngine
|
||||||
|
from .engines.tavily_engine import TavilySearchEngine
|
||||||
|
|
||||||
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
# 实例化所有搜索引擎,这会触发API密钥管理器的初始化
|
||||||
exa_engine = ExaSearchEngine()
|
exa_engine = ExaSearchEngine()
|
||||||
|
|||||||
@@ -13,8 +13,8 @@ from src.plugin_system.apis import config_api
|
|||||||
from ..engines.bing_engine import BingSearchEngine
|
from ..engines.bing_engine import BingSearchEngine
|
||||||
from ..engines.ddg_engine import DDGSearchEngine
|
from ..engines.ddg_engine import DDGSearchEngine
|
||||||
from ..engines.exa_engine import ExaSearchEngine
|
from ..engines.exa_engine import ExaSearchEngine
|
||||||
from ..engines.tavily_engine import TavilySearchEngine
|
|
||||||
from ..engines.searxng_engine import SearXNGSearchEngine
|
from ..engines.searxng_engine import SearXNGSearchEngine
|
||||||
|
from ..engines.tavily_engine import TavilySearchEngine
|
||||||
from ..utils.formatters import deduplicate_results, format_search_results
|
from ..utils.formatters import deduplicate_results, format_search_results
|
||||||
|
|
||||||
logger = get_logger("web_search_tool")
|
logger = get_logger("web_search_tool")
|
||||||
|
|||||||
Reference in New Issue
Block a user