From f308adcf5b5e14aa7a7aab94cda347679666a734 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=85=E8=AF=BA=E7=8B=90?= <212194964+foxcyber907@users.noreply.github.com> Date: Sun, 24 Aug 2025 22:11:20 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E6=B8=85=E7=90=86=E4=BB=A3?= =?UTF-8?q?=E7=A0=81=E8=B4=A8=E9=87=8F=E5=92=8C=E7=A7=BB=E9=99=A4=E6=9C=AA?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 移除未使用的导入语句和变量 - 修复代码风格问题(空格、格式化等) - 删除备份文件和测试文件 - 改进异常处理链式调用 - 添加权限系统数据库模型和配置 - 更新版本号至6.4.4 - 优化SQL查询使用正确的布尔表达式 --- .gitignore | 2 + bot.py | 13 +- docs/PERMISSION_SYSTEM.md | 196 ++++++++ src/chat/chat_loop/heartFC_chat.py | 1 - src/chat/chat_loop/response_handler.py | 3 +- src/chat/memory_system/action_diagnostics.py | 2 +- .../memory_system/async_memory_optimizer.py | 2 +- .../memory_system/vector_instant_memory.py | 4 +- src/chat/message_receive/bot.py | 4 +- src/chat/utils/utils_video.py | 6 +- src/common/cache_manager_backup.py | 344 ------------- src/common/database/monthly_plan_db.py | 4 +- src/common/database/sqlalchemy_models.py | 36 ++ src/config/config.py | 4 +- src/config/official_configs.py | 7 + src/llm_models/model_client/gemini_client.py | 28 +- src/main.py | 7 + src/plugin_system/apis/__init__.py | 2 + src/plugin_system/apis/permission_api.py | 339 +++++++++++++ src/plugin_system/core/permission_manager.py | 452 ++++++++++++++++++ .../utils/permission_decorators.py | 274 +++++++++++ .../services/qzone_service.py | 5 +- .../built_in/permission_management/plugin.py | 302 ++++++++++++ .../web_search_tool/engines/bing_engine.py | 2 +- template/bot_config_template.toml | 10 +- tests/test_wakeup_system.py | 292 ----------- 26 files changed, 1664 insertions(+), 677 deletions(-) create mode 100644 docs/PERMISSION_SYSTEM.md delete mode 100644 src/common/cache_manager_backup.py create mode 100644 src/plugin_system/apis/permission_api.py create mode 100644 src/plugin_system/core/permission_manager.py create mode 100644 src/plugin_system/utils/permission_decorators.py create mode 100644 src/plugins/built_in/permission_management/plugin.py delete mode 100644 tests/test_wakeup_system.py diff --git a/.gitignore b/.gitignore index c8aa3bec3..c0db9b4c5 100644 --- a/.gitignore +++ b/.gitignore @@ -330,3 +330,5 @@ config.toml interested_rates.txt MaiBot.code-workspace +/tests +/tests diff --git a/bot.py b/bot.py index 6904a3c5a..80b3f9f1b 100644 --- a/bot.py +++ b/bot.py @@ -22,14 +22,15 @@ else: # 最早期初始化日志系统,确保所有后续模块都使用正确的日志格式 from src.common.logger import initialize_logging, get_logger, shutdown_logging + initialize_logging() -from src.main import MainSystem #noqa -from src import BaseMain -from src.manager.async_task_manager import async_task_manager #noqa -from src.config.config import global_config -from src.common.database.database import initialize_sql_database -from src.common.database.sqlalchemy_models import initialize_database as init_db +from src.main import MainSystem # noqa +from src import BaseMain # noqa +from src.manager.async_task_manager import async_task_manager # noqa +from src.config.config import global_config # noqa +from src.common.database.database import initialize_sql_database # noqa +from src.common.database.sqlalchemy_models import initialize_database as init_db # noqa logger = get_logger("main") diff --git a/docs/PERMISSION_SYSTEM.md b/docs/PERMISSION_SYSTEM.md new file mode 100644 index 000000000..38687463c --- /dev/null +++ b/docs/PERMISSION_SYSTEM.md @@ -0,0 +1,196 @@ +# 权限系统使用说明 + +## 概述 + +MaiBot的权限系统提供了完整的权限管理功能,支持权限等级和权限节点配置。系统包含以下核心概念: + +- **Master用户**:拥有最高权限,无视所有权限节点,在配置文件中设置 +- **权限节点**:细粒度的权限控制单元,由插件自行创建和管理 +- **权限管理**:统一的权限授权、撤销和查询功能 + +## 配置文件设置 + +在 `config/bot_config.toml` 中添加权限配置: + +```toml +[permission] # 权限系统配置 +# Master用户配置(拥有最高权限,无视所有权限节点) +# 格式:[[platform, user_id], ...] +master_users = [ + ["qq", "123456789"], # QQ平台的Master用户 + ["qq", "987654321"], # 可以配置多个Master用户 +] +``` + +## 插件开发中使用权限系统 + +### 1. 注册权限节点 + +在插件的 `on_load()` 方法中注册权限节点: + +```python +from src.plugin_system.apis.permission_api import permission_api + +class MyPlugin(BasePlugin): + def on_load(self): + # 注册权限节点 + permission_api.register_permission_node( + "plugin.myplugin.admin", # 权限节点名称 + "我的插件管理员权限", # 权限描述 + "myplugin", # 插件名称 + False # 默认是否授权(False=默认拒绝) + ) + + permission_api.register_permission_node( + "plugin.myplugin.user", + "我的插件用户权限", + "myplugin", + True # 默认授权 + ) +``` + +### 2. 使用权限装饰器 + +最简单的权限检查方式是使用装饰器: + +```python +from src.plugin_system.utils.permission_decorators import require_permission, require_master + +class MyCommand(BaseCommand): + @require_permission("plugin.myplugin.admin") + async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]): + await send_message(chat_stream, "你有管理员权限!") + + @require_master("只有Master可以执行此操作") + async def master_only_function(self, message: Message, chat_stream: ChatStream): + await send_message(chat_stream, "Master专用功能") +``` + +### 3. 手动权限检查 + +对于更复杂的权限逻辑,可以手动检查权限: + +```python +from src.plugin_system.utils.permission_decorators import PermissionChecker + +class MyCommand(BaseCommand): + async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]): + # 检查是否为Master用户 + if PermissionChecker.is_master(chat_stream): + await send_message(chat_stream, "Master用户可以执行所有操作") + return + + # 检查特定权限 + if PermissionChecker.check_permission(chat_stream, "plugin.myplugin.read"): + await send_message(chat_stream, "你可以读取数据") + + # 使用 ensure_permission 自动发送权限不足消息 + if await PermissionChecker.ensure_permission(chat_stream, "plugin.myplugin.write"): + await send_message(chat_stream, "你可以写入数据") +``` + +### 4. 直接使用权限API + +```python +from src.plugin_system.apis.permission_api import permission_api + +# 检查权限 +has_permission = permission_api.check_permission("qq", "123456", "plugin.myplugin.admin") + +# 检查是否为Master +is_master = permission_api.is_master("qq", "123456") + +# 授权用户 +success = permission_api.grant_permission("qq", "123456", "plugin.myplugin.admin") + +# 撤销权限 +success = permission_api.revoke_permission("qq", "123456", "plugin.myplugin.admin") + +# 获取用户的所有权限 +permissions = permission_api.get_user_permissions("qq", "123456") + +# 获取所有权限节点 +all_nodes = permission_api.get_all_permission_nodes() + +# 获取指定插件的权限节点 +plugin_nodes = permission_api.get_plugin_permission_nodes("myplugin") +``` + +## 权限管理命令 + +系统提供了内置的权限管理命令,需要相应权限才能使用: + +### 管理员命令(需要 `plugin.permission.manage` 权限) + +``` +# 授权用户权限 +/permission grant @用户 plugin.example.admin +/permission grant 123456789 plugin.example.admin + +# 撤销用户权限 +/permission revoke @用户 plugin.example.admin +/permission revoke 123456789 plugin.example.admin +``` + +### 查看命令(需要 `plugin.permission.view` 权限) + +``` +# 查看用户权限列表 +/permission list @用户 +/permission list 123456789 +/permission list # 查看自己的权限 + +# 检查用户是否拥有权限 +/permission check @用户 plugin.example.admin +/permission check 123456789 plugin.example.admin + +# 查看权限节点列表 +/permission nodes # 查看所有权限节点 +/permission nodes example_plugin # 查看指定插件的权限节点 +``` + +### 帮助命令 + +``` +/permission help # 显示帮助信息 +``` + +## 权限节点命名规范 + +建议使用以下命名规范: + +``` +plugin.<插件名>.<功能类别>.<具体权限> +``` + +示例: +- `plugin.music.play` - 音乐插件播放权限 +- `plugin.music.admin` - 音乐插件管理权限 +- `plugin.game.user` - 游戏插件用户权限 +- `plugin.game.room.create` - 游戏插件房间创建权限 + +## 权限系统数据库表 + +系统会自动创建以下数据库表: + +1. **permission_nodes** - 存储权限节点信息 +2. **user_permissions** - 存储用户权限授权记录 + +## 最佳实践 + +1. **细粒度权限**:为不同功能创建独立的权限节点 +2. **默认权限设置**:谨慎设置默认权限,敏感操作应默认拒绝 +3. **权限描述**:为每个权限节点提供清晰的描述 +4. **Master用户**:只为真正的管理员分配Master权限 +5. **权限检查**:在执行敏感操作前始终检查权限 + +## 示例插件 + +查看 `plugins/permission_example.py` 了解完整的权限系统使用示例。 + +## 故障排除 + +1. **权限检查失败**:确保权限节点已正确注册 +2. **Master用户配置**:检查配置文件中的用户ID格式是否正确 +3. **权限不生效**:重启机器人以重新加载配置 +4. **数据库问题**:检查数据库连接和表结构是否正确 diff --git a/src/chat/chat_loop/heartFC_chat.py b/src/chat/chat_loop/heartFC_chat.py index 09481f18b..da67eac81 100644 --- a/src/chat/chat_loop/heartFC_chat.py +++ b/src/chat/chat_loop/heartFC_chat.py @@ -1,7 +1,6 @@ import asyncio import time import traceback -import re from typing import Optional from src.common.logger import get_logger diff --git a/src/chat/chat_loop/response_handler.py b/src/chat/chat_loop/response_handler.py index 2703c4611..9c0b4976a 100644 --- a/src/chat/chat_loop/response_handler.py +++ b/src/chat/chat_loop/response_handler.py @@ -142,11 +142,10 @@ class ResponseHandler: # 修正:正确处理元组格式 (格式为: (type, content)) if isinstance(reply_seg, tuple) and len(reply_seg) >= 2: - reply_type, data = reply_seg + _, data = reply_seg else: # 向下兼容:如果已经是字符串,则直接使用 data = str(reply_seg) - reply_type = "text" reply_text += data diff --git a/src/chat/memory_system/action_diagnostics.py b/src/chat/memory_system/action_diagnostics.py index ba14e5222..e0d0650db 100644 --- a/src/chat/memory_system/action_diagnostics.py +++ b/src/chat/memory_system/action_diagnostics.py @@ -214,7 +214,7 @@ class ActionDiagnostics: raise Exception("注册失败") except Exception as e: - raise Exception(f"手动注册no_reply Action失败: {e}") + raise Exception(f"手动注册no_reply Action失败: {e}") from e def run_full_diagnosis(self) -> Dict[str, Any]: """运行完整诊断""" diff --git a/src/chat/memory_system/async_memory_optimizer.py b/src/chat/memory_system/async_memory_optimizer.py index 77855ba56..61311ff5c 100644 --- a/src/chat/memory_system/async_memory_optimizer.py +++ b/src/chat/memory_system/async_memory_optimizer.py @@ -131,7 +131,7 @@ class AsyncMemoryQueue: await task.callback(None) else: task.callback(None) - except: + except Exception: pass async def _handle_store_task(self, task: MemoryTask) -> Any: diff --git a/src/chat/memory_system/vector_instant_memory.py b/src/chat/memory_system/vector_instant_memory.py index a95115d78..201076ffb 100644 --- a/src/chat/memory_system/vector_instant_memory.py +++ b/src/chat/memory_system/vector_instant_memory.py @@ -271,7 +271,7 @@ class VectorInstantMemoryV2: return f"{int(diff/3600)}小时前" else: return f"{int(diff/86400)}天前" - except: + except Exception: return "时间格式错误" async def get_memory_for_context(self, current_message: str, context_size: int = 3) -> str: @@ -318,7 +318,7 @@ class VectorInstantMemoryV2: try: result = self.collection.count() stats["total_messages"] = result - except: + except Exception: stats["total_messages"] = "查询失败" return stats diff --git a/src/chat/message_receive/bot.py b/src/chat/message_receive/bot.py index e6a36fe24..226f2ff7d 100644 --- a/src/chat/message_receive/bot.py +++ b/src/chat/message_receive/bot.py @@ -16,11 +16,9 @@ from src.chat.utils.prompt_builder import Prompt, global_prompt_manager from src.plugin_system.core import component_registry, events_manager, global_announcement_manager from src.plugin_system.base import BaseCommand, EventType from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor -from src.plugin_system.apis import send_api # 导入反注入系统 -from src.chat.antipromptinjector import get_anti_injector, initialize_anti_injector -from src.chat.antipromptinjector.types import ProcessResult +from src.chat.antipromptinjector import initialize_anti_injector # 定义日志配置 diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 179231f84..f68118580 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -458,7 +458,7 @@ class VideoAnalyzer: try: # 等待处理完成的事件信号,最多等待60秒 await asyncio.wait_for(video_event.wait(), timeout=60.0) - logger.info(f"✅ 等待结束,检查是否有处理结果") + logger.info("✅ 等待结束,检查是否有处理结果") # 检查是否有结果了 existing_video = self._check_video_exists(video_hash) @@ -466,9 +466,9 @@ class VideoAnalyzer: logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})") return {"summary": existing_video.description} else: - logger.warning(f"⚠️ 等待完成但未找到结果,可能处理失败") + logger.warning("⚠️ 等待完成但未找到结果,可能处理失败") except asyncio.TimeoutError: - logger.warning(f"⚠️ 等待超时(60秒),放弃等待") + logger.warning("⚠️ 等待超时(60秒),放弃等待") # 获取锁开始处理 async with video_lock: diff --git a/src/common/cache_manager_backup.py b/src/common/cache_manager_backup.py deleted file mode 100644 index ecaff3458..000000000 --- a/src/common/cache_manager_backup.py +++ /dev/null @@ -1,344 +0,0 @@ -import json -import hashlib -import re -from typing import Any, Dict, Optional -from datetime import datetime, timedelta -from pathlib import Path -from difflib import SequenceMatcher - -from src.common.logger import get_logger - -logger = get_logger("cache_manager") - - -class ToolCache: - """工具缓存管理器,用于缓存工具调用结果,支持近似匹配""" - - def __init__( - self, - cache_dir: str = "data/tool_cache", - max_age_hours: int = 24, - similarity_threshold: float = 0.65, - ): - """ - 初始化缓存管理器 - - Args: - cache_dir: 缓存目录路径 - max_age_hours: 缓存最大存活时间(小时) - similarity_threshold: 近似匹配的相似度阈值 (0-1) - """ - self.cache_dir = Path(cache_dir) - self.max_age = timedelta(hours=max_age_hours) - self.max_age_seconds = max_age_hours * 3600 - self.similarity_threshold = similarity_threshold - self.cache_dir.mkdir(parents=True, exist_ok=True) - - @staticmethod - def _normalize_query(query: str) -> str: - """ - 标准化查询文本,用于相似度比较 - - Args: - query: 原始查询文本 - - Returns: - 标准化后的查询文本 - """ - if not query: - return "" - - # 纯 Python 实现 - normalized = query.lower() - normalized = re.sub(r"[^\w\s]", " ", normalized) - normalized = " ".join(normalized.split()) - return normalized - - def _calculate_similarity(self, text1: str, text2: str) -> float: - """ - 计算两个文本的相似度 - - Args: - text1: 文本1 - text2: 文本2 - - Returns: - 相似度分数 (0-1) - """ - if not text1 or not text2: - return 0.0 - - # 纯 Python 实现 - norm_text1 = self._normalize_query(text1) - norm_text2 = self._normalize_query(text2) - - if norm_text1 == norm_text2: - return 1.0 - - return SequenceMatcher(None, norm_text1, norm_text2).ratio() - - @staticmethod - def _generate_cache_key(tool_name: str, function_args: Dict[str, Any]) -> str: - """ - 生成缓存键 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - - Returns: - 缓存键字符串 - """ - # 将参数排序后序列化,确保相同参数产生相同的键 - sorted_args = json.dumps(function_args, sort_keys=True, ensure_ascii=False) - - # 纯 Python 实现 - cache_string = f"{tool_name}:{sorted_args}" - return hashlib.md5(cache_string.encode("utf-8")).hexdigest() - - def _get_cache_file_path(self, cache_key: str) -> Path: - """获取缓存文件路径""" - return self.cache_dir / f"{cache_key}.json" - - def _is_cache_expired(self, cached_time: datetime) -> bool: - """检查缓存是否过期""" - return datetime.now() - cached_time > self.max_age - - def _find_similar_cache( - self, tool_name: str, function_args: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """ - 查找相似的缓存条目 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - - Returns: - 相似的缓存结果,如果不存在则返回None - """ - query = function_args.get("query", "") - if not query: - return None - - candidates = [] - cache_data_list = [] - - # 遍历所有缓存文件,收集候选项 - for cache_file in self.cache_dir.glob("*.json"): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - # 检查是否是同一个工具 - if cache_data.get("tool_name") != tool_name: - continue - - # 检查缓存是否过期 - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - continue - - # 检查其他参数是否匹配(除了query) - cached_args = cache_data.get("function_args", {}) - args_match = True - for key, value in function_args.items(): - if key != "query" and cached_args.get(key) != value: - args_match = False - break - - if not args_match: - continue - - # 收集候选项 - cached_query = cached_args.get("query", "") - candidates.append((cached_query, len(cache_data_list))) - cache_data_list.append(cache_data) - - except Exception as e: - logger.warning(f"检查缓存文件时出错: {cache_file}, 错误: {e}") - continue - - if not candidates: - logger.debug( - f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" - ) - return None - - # 纯 Python 实现 - best_match = None - best_similarity = 0.0 - - for cached_query, index in candidates: - similarity = self._calculate_similarity(query, cached_query) - if similarity > best_similarity and similarity >= self.similarity_threshold: - best_similarity = similarity - best_match = cache_data_list[index] - - if best_match is not None: - cached_query = best_match["function_args"].get("query", "") - logger.info( - f"相似缓存命中,相似度: {best_similarity:.2f}, 原查询: '{cached_query}', 当前查询: '{query}'" - ) - return best_match["result"] - - logger.debug( - f"未找到相似缓存: {tool_name}, 查询: '{query}',相似度阈值: {self.similarity_threshold}" - ) - return None - - def get( - self, tool_name: str, function_args: Dict[str, Any] - ) -> Optional[Dict[str, Any]]: - """ - 从缓存获取结果,支持精确匹配和近似匹配 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - - Returns: - 缓存的结果,如果不存在或已过期则返回None - """ - # 首先尝试精确匹配 - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) - - if cache_file.exists(): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - # 检查缓存是否过期 - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - logger.debug(f"缓存已过期: {cache_key}") - cache_file.unlink() # 删除过期缓存 - else: - logger.debug(f"精确匹配缓存: {tool_name}") - return cache_data["result"] - - except (json.JSONDecodeError, KeyError, ValueError) as e: - logger.warning(f"读取缓存文件失败: {cache_file}, 错误: {e}") - # 删除损坏的缓存文件 - if cache_file.exists(): - cache_file.unlink() - - # 如果精确匹配失败,尝试近似匹配 - return self._find_similar_cache(tool_name, function_args) - - def set( - self, tool_name: str, function_args: Dict[str, Any], result: Dict[str, Any] - ) -> None: - """ - 将结果保存到缓存 - - Args: - tool_name: 工具名称 - function_args: 函数参数 - result: 缓存结果 - """ - cache_key = self._generate_cache_key(tool_name, function_args) - cache_file = self._get_cache_file_path(cache_key) - - cache_data = { - "tool_name": tool_name, - "function_args": function_args, - "result": result, - "timestamp": datetime.now().isoformat(), - } - - try: - with open(cache_file, "w", encoding="utf-8") as f: - json.dump(cache_data, f, ensure_ascii=False, indent=2) - logger.debug(f"缓存已保存: {tool_name} -> {cache_key}") - except Exception as e: - logger.error(f"保存缓存失败: {cache_file}, 错误: {e}") - - def clear_expired(self) -> int: - """ - 清理过期缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - cache_file.unlink() - removed_count += 1 - logger.debug(f"删除过期缓存: {cache_file}") - - except Exception as e: - logger.warning(f"清理缓存文件时出错: {cache_file}, 错误: {e}") - # 删除损坏的文件 - try: - cache_file.unlink() - removed_count += 1 - except (OSError, json.JSONDecodeError, KeyError, ValueError): - logger.warning(f"删除损坏的缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清理完成,删除了 {removed_count} 个过期缓存文件") - return removed_count - - def clear_all(self) -> int: - """ - 清空所有缓存 - - Returns: - 删除的文件数量 - """ - removed_count = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - cache_file.unlink() - removed_count += 1 - except Exception as e: - logger.warning(f"删除缓存文件失败: {cache_file}, 错误: {e}") - - logger.info(f"清空缓存完成,删除了 {removed_count} 个文件") - return removed_count - - def get_stats(self) -> Dict[str, Any]: - """ - 获取缓存统计信息 - - Returns: - 缓存统计信息字典 - """ - total_files = 0 - expired_files = 0 - total_size = 0 - - for cache_file in self.cache_dir.glob("*.json"): - try: - total_files += 1 - total_size += cache_file.stat().st_size - - with open(cache_file, "r", encoding="utf-8") as f: - cache_data = json.load(f) - - cached_time = datetime.fromisoformat(cache_data["timestamp"]) - if self._is_cache_expired(cached_time): - expired_files += 1 - - except (OSError, json.JSONDecodeError, KeyError, ValueError): - expired_files += 1 # 损坏的文件也算作过期 - - return { - "total_files": total_files, - "expired_files": expired_files, - "total_size_bytes": total_size, - "cache_dir": str(self.cache_dir), - "max_age_hours": self.max_age.total_seconds() / 3600, - "similarity_threshold": self.similarity_threshold, - } - -tool_cache = ToolCache() \ No newline at end of file diff --git a/src/common/database/monthly_plan_db.py b/src/common/database/monthly_plan_db.py index 439cc2044..2e6142d16 100644 --- a/src/common/database/monthly_plan_db.py +++ b/src/common/database/monthly_plan_db.py @@ -19,7 +19,7 @@ def add_new_plans(plans: List[str], month: str): # 1. 获取当前有效计划数量 current_plan_count = session.query(MonthlyPlan).filter( MonthlyPlan.target_month == month, - MonthlyPlan.is_deleted == False + not MonthlyPlan.is_deleted ).count() # 2. 从配置获取上限 @@ -62,7 +62,7 @@ def get_active_plans_for_month(month: str) -> List[MonthlyPlan]: try: plans = session.query(MonthlyPlan).filter( MonthlyPlan.target_month == month, - MonthlyPlan.is_deleted == False + not MonthlyPlan.is_deleted ).all() return plans except Exception as e: diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 76a8289bd..95ecb4e41 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -650,3 +650,39 @@ def get_engine(): """获取数据库引擎""" engine, _ = initialize_database() return engine + + +class PermissionNodes(Base): + """权限节点模型""" + __tablename__ = 'permission_nodes' + + id = Column(Integer, primary_key=True, autoincrement=True) + node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称 + description = Column(Text, nullable=False) # 权限描述 + plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件 + default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权 + created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 + + __table_args__ = ( + Index('idx_permission_plugin', 'plugin_name'), + Index('idx_permission_node', 'node_name'), + ) + + +class UserPermissions(Base): + """用户权限模型""" + __tablename__ = 'user_permissions' + + id = Column(Integer, primary_key=True, autoincrement=True) + platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型 + user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID + permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称 + granted = Column(Boolean, default=True, nullable=False) # 是否授权 + granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间 + granted_by = Column(get_string_field(100), nullable=True) # 授权者信息 + + __table_args__ = ( + Index('idx_user_platform_id', 'platform', 'user_id'), + Index('idx_user_permission', 'platform', 'user_id', 'permission_node'), + Index('idx_permission_granted', 'permission_node', 'granted'), + ) diff --git a/src/config/config.py b/src/config/config.py index 5272acf3c..3ea9eaa08 100644 --- a/src/config/config.py +++ b/src/config/config.py @@ -46,7 +46,8 @@ from src.config.official_configs import ( PluginsConfig, WakeUpSystemConfig, MonthlyPlanSystemConfig, - CrossContextConfig + CrossContextConfig, + PermissionConfig ) from .api_ada_configs import ( @@ -382,6 +383,7 @@ class Config(ValidatedConfigBase): custom_prompt: CustomPromptConfig = Field(..., description="自定义提示配置") voice: VoiceConfig = Field(..., description="语音配置") schedule: ScheduleConfig = Field(..., description="调度配置") + permission: PermissionConfig = Field(..., description="权限配置") # 有默认值的字段放在后面 anti_prompt_injection: AntiPromptInjectionConfig = Field(default_factory=lambda: AntiPromptInjectionConfig(), description="反提示注入配置") diff --git a/src/config/official_configs.py b/src/config/official_configs.py index a447a9d95..1ee7cd305 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -697,3 +697,10 @@ class CrossContextConfig(ValidatedConfigBase): """跨群聊上下文共享配置""" enable: bool = Field(default=False, description="是否启用跨群聊上下文共享功能") groups: List[ContextGroup] = Field(default_factory=list, description="上下文共享组列表") + + +class PermissionConfig(ValidatedConfigBase): + """权限系统配置类""" + + # Master用户配置(拥有最高权限,无视所有权限节点) + master_users: List[List[str]] = Field(default_factory=list, description="Master用户列表,格式: [[platform, user_id], ...]") diff --git a/src/llm_models/model_client/gemini_client.py b/src/llm_models/model_client/gemini_client.py index 0b5375935..a14cabb9e 100644 --- a/src/llm_models/model_client/gemini_client.py +++ b/src/llm_models/model_client/gemini_client.py @@ -12,13 +12,25 @@ from google.generativeai.types import ( try: # 尝试从较新的API导入 - from google.generativeai import configure from google.generativeai.types import SafetySetting, GenerationConfig except ImportError: # 回退到基本类型 SafetySetting = Dict GenerationConfig = Dict +from src.config.api_ada_configs import ModelInfo, APIProvider +from src.common.logger import get_logger +from .base_client import APIResponse, UsageRecord, BaseClient, client_registry +from ..exceptions import ( + RespParseException, + NetworkConnectionError, + RespNotOkException, + ReqAbortException, +) +from ..payload_content.message import Message, RoleType +from ..payload_content.resp_format import RespFormat, RespFormatType +from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall + # 定义兼容性类型 ContentDict = Dict PartDict = Dict @@ -50,20 +62,6 @@ class UnsupportedFunctionError(Exception): class FunctionInvocationError(Exception): pass -from src.config.api_ada_configs import ModelInfo, APIProvider -from src.common.logger import get_logger - -from .base_client import APIResponse, UsageRecord, BaseClient, client_registry -from ..exceptions import ( - RespParseException, - NetworkConnectionError, - RespNotOkException, - ReqAbortException, -) -from ..payload_content.message import Message, RoleType -from ..payload_content.resp_format import RespFormat, RespFormatType -from ..payload_content.tool_option import ToolOption, ToolParam, ToolCall - logger = get_logger("Gemini客户端") SAFETY_SETTINGS = [ diff --git a/src/main.py b/src/main.py index 1e43f4d49..ceeec4940 100644 --- a/src/main.py +++ b/src/main.py @@ -147,6 +147,13 @@ MaiMbot-Pro-Max(第三方修改版) # 添加统计信息输出任务 await async_task_manager.add_task(StatisticOutputTask()) + # 初始化权限管理器 + from src.plugin_system.core.permission_manager import PermissionManager + from src.plugin_system.apis.permission_api import permission_api + permission_manager = PermissionManager() + permission_api.set_permission_manager(permission_manager) + logger.info("权限管理器初始化成功") + # 启动API服务器 # start_api_server() # logger.info("API服务器启动成功") diff --git a/src/plugin_system/apis/__init__.py b/src/plugin_system/apis/__init__.py index 362c98581..f00f18f30 100644 --- a/src/plugin_system/apis/__init__.py +++ b/src/plugin_system/apis/__init__.py @@ -18,6 +18,7 @@ from src.plugin_system.apis import ( plugin_manage_api, send_api, tool_api, + permission_api, ) from .logging_api import get_logger from .plugin_register_api import register_plugin @@ -38,4 +39,5 @@ __all__ = [ "get_logger", "register_plugin", "tool_api", + "permission_api", ] diff --git a/src/plugin_system/apis/permission_api.py b/src/plugin_system/apis/permission_api.py new file mode 100644 index 000000000..fd25c63dd --- /dev/null +++ b/src/plugin_system/apis/permission_api.py @@ -0,0 +1,339 @@ +""" +权限系统API - 提供权限管理相关的API接口 + +这个模块提供了权限系统的核心API,包括权限检查、权限节点管理等功能。 +插件可以通过这些API来检查用户权限和管理权限节点。 +""" + +from typing import Optional, List, Dict, Any +from enum import Enum +from dataclasses import dataclass +from abc import ABC, abstractmethod + +from src.common.logger import get_logger + +logger = get_logger(__name__) + + +class PermissionLevel(Enum): + """权限等级枚举""" + MASTER = "master" # 最高权限,无视所有权限节点 + + +@dataclass +class PermissionNode: + """权限节点数据类""" + node_name: str # 权限节点名称,如 "plugin.example.command.test" + description: str # 权限节点描述 + plugin_name: str # 所属插件名称 + default_granted: bool = False # 默认是否授权 + + +@dataclass +class UserInfo: + """用户信息数据类""" + platform: str # 平台类型,如 "qq" + user_id: str # 用户ID + + def __post_init__(self): + """确保user_id是字符串类型""" + self.user_id = str(self.user_id) + + def to_tuple(self) -> tuple[str, str]: + """转换为元组格式""" + return (self.platform, self.user_id) + + +class IPermissionManager(ABC): + """权限管理器接口""" + + @abstractmethod + def check_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 检查用户是否拥有指定权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 是否拥有权限 + """ + pass + + @abstractmethod + def is_master(self, user: UserInfo) -> bool: + """ + 检查用户是否为Master用户 + + Args: + user: 用户信息 + + Returns: + bool: 是否为Master用户 + """ + pass + + @abstractmethod + def register_permission_node(self, node: PermissionNode) -> bool: + """ + 注册权限节点 + + Args: + node: 权限节点 + + Returns: + bool: 注册是否成功 + """ + pass + + @abstractmethod + def grant_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 授权用户权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 授权是否成功 + """ + pass + + @abstractmethod + def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 撤销用户权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 撤销是否成功 + """ + pass + + @abstractmethod + def get_user_permissions(self, user: UserInfo) -> List[str]: + """ + 获取用户拥有的所有权限节点 + + Args: + user: 用户信息 + + Returns: + List[str]: 权限节点列表 + """ + pass + + @abstractmethod + def get_all_permission_nodes(self) -> List[PermissionNode]: + """ + 获取所有已注册的权限节点 + + Returns: + List[PermissionNode]: 权限节点列表 + """ + pass + + @abstractmethod + def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + """ + 获取指定插件的所有权限节点 + + Args: + plugin_name: 插件名称 + + Returns: + List[PermissionNode]: 权限节点列表 + """ + pass + + +class PermissionAPI: + """权限系统API类""" + + def __init__(self): + self._permission_manager: Optional[IPermissionManager] = None + + def set_permission_manager(self, manager: IPermissionManager): + """设置权限管理器实例""" + self._permission_manager = manager + logger.info("权限管理器已设置") + + def _ensure_manager(self): + """确保权限管理器已设置""" + if self._permission_manager is None: + raise RuntimeError("权限管理器未设置,请先调用 set_permission_manager") + + def check_permission(self, platform: str, user_id: str, permission_node: str) -> bool: + """ + 检查用户是否拥有指定权限节点 + + Args: + platform: 平台类型,如 "qq" + user_id: 用户ID + permission_node: 权限节点名称 + + Returns: + bool: 是否拥有权限 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + user = UserInfo(platform=platform, user_id=str(user_id)) + return self._permission_manager.check_permission(user, permission_node) + + def is_master(self, platform: str, user_id: str) -> bool: + """ + 检查用户是否为Master用户 + + Args: + platform: 平台类型,如 "qq" + user_id: 用户ID + + Returns: + bool: 是否为Master用户 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + user = UserInfo(platform=platform, user_id=str(user_id)) + return self._permission_manager.is_master(user) + + def register_permission_node(self, node_name: str, description: str, plugin_name: str, + default_granted: bool = False) -> bool: + """ + 注册权限节点 + + Args: + node_name: 权限节点名称,如 "plugin.example.command.test" + description: 权限节点描述 + plugin_name: 所属插件名称 + default_granted: 默认是否授权 + + Returns: + bool: 注册是否成功 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + node = PermissionNode( + node_name=node_name, + description=description, + plugin_name=plugin_name, + default_granted=default_granted + ) + return self._permission_manager.register_permission_node(node) + + def grant_permission(self, platform: str, user_id: str, permission_node: str) -> bool: + """ + 授权用户权限节点 + + Args: + platform: 平台类型,如 "qq" + user_id: 用户ID + permission_node: 权限节点名称 + + Returns: + bool: 授权是否成功 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + user = UserInfo(platform=platform, user_id=str(user_id)) + return self._permission_manager.grant_permission(user, permission_node) + + def revoke_permission(self, platform: str, user_id: str, permission_node: str) -> bool: + """ + 撤销用户权限节点 + + Args: + platform: 平台类型,如 "qq" + user_id: 用户ID + permission_node: 权限节点名称 + + Returns: + bool: 撤销是否成功 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + user = UserInfo(platform=platform, user_id=str(user_id)) + return self._permission_manager.revoke_permission(user, permission_node) + + def get_user_permissions(self, platform: str, user_id: str) -> List[str]: + """ + 获取用户拥有的所有权限节点 + + Args: + platform: 平台类型,如 "qq" + user_id: 用户ID + + Returns: + List[str]: 权限节点列表 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + user = UserInfo(platform=platform, user_id=str(user_id)) + return self._permission_manager.get_user_permissions(user) + + def get_all_permission_nodes(self) -> List[Dict[str, Any]]: + """ + 获取所有已注册的权限节点 + + Returns: + List[Dict[str, Any]]: 权限节点列表,每个节点包含 node_name, description, plugin_name, default_granted + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + nodes = self._permission_manager.get_all_permission_nodes() + return [ + { + "node_name": node.node_name, + "description": node.description, + "plugin_name": node.plugin_name, + "default_granted": node.default_granted + } + for node in nodes + ] + + def get_plugin_permission_nodes(self, plugin_name: str) -> List[Dict[str, Any]]: + """ + 获取指定插件的所有权限节点 + + Args: + plugin_name: 插件名称 + + Returns: + List[Dict[str, Any]]: 权限节点列表 + + Raises: + RuntimeError: 权限管理器未设置时抛出 + """ + self._ensure_manager() + nodes = self._permission_manager.get_plugin_permission_nodes(plugin_name) + return [ + { + "node_name": node.node_name, + "description": node.description, + "plugin_name": node.plugin_name, + "default_granted": node.default_granted + } + for node in nodes + ] + + +# 全局权限API实例 +permission_api = PermissionAPI() diff --git a/src/plugin_system/core/permission_manager.py b/src/plugin_system/core/permission_manager.py new file mode 100644 index 000000000..c860851fb --- /dev/null +++ b/src/plugin_system/core/permission_manager.py @@ -0,0 +1,452 @@ +""" +权限管理器实现 + +这个模块提供了权限系统的核心实现,包括权限检查、权限节点管理、用户权限管理等功能。 +""" + +from typing import List, Set, Tuple +from sqlalchemy.orm import sessionmaker +from sqlalchemy.exc import IntegrityError, SQLAlchemyError +from datetime import datetime + +from src.common.logger import get_logger +from src.common.database.sqlalchemy_models import get_database_engine, PermissionNodes, UserPermissions +from src.plugin_system.apis.permission_api import IPermissionManager, PermissionNode, UserInfo +from src.config.config import global_config + +logger = get_logger(__name__) + + +class PermissionManager(IPermissionManager): + """权限管理器实现类""" + + def __init__(self): + self.engine = get_database_engine() + self.SessionLocal = sessionmaker(bind=self.engine) + self._master_users: Set[Tuple[str, str]] = set() + self._load_master_users() + logger.info("权限管理器初始化完成") + + def _load_master_users(self): + """从配置文件加载Master用户列表""" + try: + master_users_config = global_config.permission.master_users + self._master_users = set() + for user_info in master_users_config: + if isinstance(user_info, list) and len(user_info) == 2: + platform, user_id = user_info + self._master_users.add((str(platform), str(user_id))) + logger.info(f"已加载 {len(self._master_users)} 个Master用户") + except Exception as e: + logger.warning(f"加载Master用户配置失败: {e}") + self._master_users = set() + + def reload_master_users(self): + """重新加载Master用户配置""" + self._load_master_users() + logger.info("Master用户配置已重新加载") + + def is_master(self, user: UserInfo) -> bool: + """ + 检查用户是否为Master用户 + + Args: + user: 用户信息 + + Returns: + bool: 是否为Master用户 + """ + user_tuple = (user.platform, user.user_id) + is_master = user_tuple in self._master_users + if is_master: + logger.debug(f"用户 {user.platform}:{user.user_id} 是Master用户") + return is_master + + def check_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 检查用户是否拥有指定权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 是否拥有权限 + """ + try: + # Master用户拥有所有权限 + if self.is_master(user): + logger.debug(f"Master用户 {user.platform}:{user.user_id} 拥有权限节点 {permission_node}") + return True + + with self.SessionLocal() as session: + # 检查权限节点是否存在 + node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + if not node: + logger.warning(f"权限节点 {permission_node} 不存在") + return False + + # 检查用户是否有明确的权限设置 + user_perm = session.query(UserPermissions).filter_by( + platform=user.platform, + user_id=user.user_id, + permission_node=permission_node + ).first() + + if user_perm: + # 有明确设置,返回设置的值 + result = user_perm.granted + logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 的明确设置: {result}") + return result + else: + # 没有明确设置,使用默认值 + result = node.default_granted + logger.debug(f"用户 {user.platform}:{user.user_id} 对权限节点 {permission_node} 使用默认设置: {result}") + return result + + except SQLAlchemyError as e: + logger.error(f"检查权限时数据库错误: {e}") + return False + except Exception as e: + logger.error(f"检查权限时发生未知错误: {e}") + return False + + def register_permission_node(self, node: PermissionNode) -> bool: + """ + 注册权限节点 + + Args: + node: 权限节点 + + Returns: + bool: 注册是否成功 + """ + try: + with self.SessionLocal() as session: + # 检查节点是否已存在 + existing_node = session.query(PermissionNodes).filter_by(node_name=node.node_name).first() + if existing_node: + # 更新现有节点的信息 + existing_node.description = node.description + existing_node.plugin_name = node.plugin_name + existing_node.default_granted = node.default_granted + session.commit() + logger.debug(f"更新权限节点: {node.node_name}") + return True + + # 创建新节点 + new_node = PermissionNodes( + node_name=node.node_name, + description=node.description, + plugin_name=node.plugin_name, + default_granted=node.default_granted, + created_at=datetime.utcnow() + ) + session.add(new_node) + session.commit() + logger.info(f"注册新权限节点: {node.node_name} (插件: {node.plugin_name})") + return True + + except IntegrityError as e: + logger.error(f"注册权限节点时发生完整性错误: {e}") + return False + except SQLAlchemyError as e: + logger.error(f"注册权限节点时数据库错误: {e}") + return False + except Exception as e: + logger.error(f"注册权限节点时发生未知错误: {e}") + return False + + def grant_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 授权用户权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 授权是否成功 + """ + try: + with self.SessionLocal() as session: + # 检查权限节点是否存在 + node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + if not node: + logger.error(f"尝试授权不存在的权限节点: {permission_node}") + return False + + # 检查是否已有权限记录 + existing_perm = session.query(UserPermissions).filter_by( + platform=user.platform, + user_id=user.user_id, + permission_node=permission_node + ).first() + + if existing_perm: + # 更新现有记录 + existing_perm.granted = True + existing_perm.granted_at = datetime.utcnow() + else: + # 创建新记录 + new_perm = UserPermissions( + platform=user.platform, + user_id=user.user_id, + permission_node=permission_node, + granted=True, + granted_at=datetime.utcnow() + ) + session.add(new_perm) + + session.commit() + logger.info(f"已授权用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") + return True + + except SQLAlchemyError as e: + logger.error(f"授权权限时数据库错误: {e}") + return False + except Exception as e: + logger.error(f"授权权限时发生未知错误: {e}") + return False + + def revoke_permission(self, user: UserInfo, permission_node: str) -> bool: + """ + 撤销用户权限节点 + + Args: + user: 用户信息 + permission_node: 权限节点名称 + + Returns: + bool: 撤销是否成功 + """ + try: + with self.SessionLocal() as session: + # 检查权限节点是否存在 + node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + if not node: + logger.error(f"尝试撤销不存在的权限节点: {permission_node}") + return False + + # 检查是否已有权限记录 + existing_perm = session.query(UserPermissions).filter_by( + platform=user.platform, + user_id=user.user_id, + permission_node=permission_node + ).first() + + if existing_perm: + # 更新现有记录 + existing_perm.granted = False + existing_perm.granted_at = datetime.utcnow() + else: + # 创建新记录(明确撤销) + new_perm = UserPermissions( + platform=user.platform, + user_id=user.user_id, + permission_node=permission_node, + granted=False, + granted_at=datetime.utcnow() + ) + session.add(new_perm) + + session.commit() + logger.info(f"已撤销用户 {user.platform}:{user.user_id} 权限节点 {permission_node}") + return True + + except SQLAlchemyError as e: + logger.error(f"撤销权限时数据库错误: {e}") + return False + except Exception as e: + logger.error(f"撤销权限时发生未知错误: {e}") + return False + + def get_user_permissions(self, user: UserInfo) -> List[str]: + """ + 获取用户拥有的所有权限节点 + + Args: + user: 用户信息 + + Returns: + List[str]: 权限节点列表 + """ + try: + # Master用户拥有所有权限 + if self.is_master(user): + with self.SessionLocal() as session: + all_nodes = session.query(PermissionNodes.node_name).all() + return [node.node_name for node in all_nodes] + + permissions = [] + + with self.SessionLocal() as session: + # 获取所有权限节点 + all_nodes = session.query(PermissionNodes).all() + + for node in all_nodes: + # 检查用户是否有明确的权限设置 + user_perm = session.query(UserPermissions).filter_by( + platform=user.platform, + user_id=user.user_id, + permission_node=node.node_name + ).first() + + if user_perm: + # 有明确设置,使用设置的值 + if user_perm.granted: + permissions.append(node.node_name) + else: + # 没有明确设置,使用默认值 + if node.default_granted: + permissions.append(node.node_name) + + return permissions + + except SQLAlchemyError as e: + logger.error(f"获取用户权限时数据库错误: {e}") + return [] + except Exception as e: + logger.error(f"获取用户权限时发生未知错误: {e}") + return [] + + def get_all_permission_nodes(self) -> List[PermissionNode]: + """ + 获取所有已注册的权限节点 + + Returns: + List[PermissionNode]: 权限节点列表 + """ + try: + with self.SessionLocal() as session: + nodes = session.query(PermissionNodes).all() + return [ + PermissionNode( + node_name=node.node_name, + description=node.description, + plugin_name=node.plugin_name, + default_granted=node.default_granted + ) + for node in nodes + ] + + except SQLAlchemyError as e: + logger.error(f"获取所有权限节点时数据库错误: {e}") + return [] + except Exception as e: + logger.error(f"获取所有权限节点时发生未知错误: {e}") + return [] + + def get_plugin_permission_nodes(self, plugin_name: str) -> List[PermissionNode]: + """ + 获取指定插件的所有权限节点 + + Args: + plugin_name: 插件名称 + + Returns: + List[PermissionNode]: 权限节点列表 + """ + try: + with self.SessionLocal() as session: + nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + return [ + PermissionNode( + node_name=node.node_name, + description=node.description, + plugin_name=node.plugin_name, + default_granted=node.default_granted + ) + for node in nodes + ] + + except SQLAlchemyError as e: + logger.error(f"获取插件权限节点时数据库错误: {e}") + return [] + except Exception as e: + logger.error(f"获取插件权限节点时发生未知错误: {e}") + return [] + + def delete_plugin_permissions(self, plugin_name: str) -> bool: + """ + 删除指定插件的所有权限节点(用于插件卸载时清理) + + Args: + plugin_name: 插件名称 + + Returns: + bool: 删除是否成功 + """ + try: + with self.SessionLocal() as session: + # 获取插件的所有权限节点 + plugin_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).all() + node_names = [node.node_name for node in plugin_nodes] + + if not node_names: + logger.info(f"插件 {plugin_name} 没有注册任何权限节点") + return True + + # 删除用户权限记录 + deleted_user_perms = session.query(UserPermissions).filter( + UserPermissions.permission_node.in_(node_names) + ).delete(synchronize_session=False) + + # 删除权限节点 + deleted_nodes = session.query(PermissionNodes).filter_by(plugin_name=plugin_name).delete() + + session.commit() + logger.info(f"已删除插件 {plugin_name} 的 {deleted_nodes} 个权限节点和 {deleted_user_perms} 条用户权限记录") + return True + + except SQLAlchemyError as e: + logger.error(f"删除插件权限时数据库错误: {e}") + return False + except Exception as e: + logger.error(f"删除插件权限时发生未知错误: {e}") + return False + + def get_users_with_permission(self, permission_node: str) -> List[Tuple[str, str]]: + """ + 获取拥有指定权限的所有用户 + + Args: + permission_node: 权限节点名称 + + Returns: + List[Tuple[str, str]]: 用户列表,格式为 [(platform, user_id), ...] + """ + try: + users = [] + + with self.SessionLocal() as session: + # 检查权限节点是否存在 + node = session.query(PermissionNodes).filter_by(node_name=permission_node).first() + if not node: + logger.warning(f"权限节点 {permission_node} 不存在") + return users + + # 获取明确授权的用户 + granted_users = session.query(UserPermissions).filter_by( + permission_node=permission_node, + granted=True + ).all() + + for user_perm in granted_users: + users.append((user_perm.platform, user_perm.user_id)) + + # 如果是默认授权的权限节点,还需要考虑没有明确设置的用户 + # 但这里我们只返回明确授权的用户,避免返回所有用户 + + # 添加Master用户(他们拥有所有权限) + users.extend(list(self._master_users)) + + # 去重 + return list(set(users)) + + except SQLAlchemyError as e: + logger.error(f"获取拥有权限的用户时数据库错误: {e}") + return [] + except Exception as e: + logger.error(f"获取拥有权限的用户时发生未知错误: {e}") + return [] diff --git a/src/plugin_system/utils/permission_decorators.py b/src/plugin_system/utils/permission_decorators.py new file mode 100644 index 000000000..ae5b48e0e --- /dev/null +++ b/src/plugin_system/utils/permission_decorators.py @@ -0,0 +1,274 @@ +""" +权限装饰器 + +提供方便的权限检查装饰器,用于插件命令和其他需要权限验证的地方。 +""" + +from functools import wraps +from typing import Callable, Optional +from inspect import iscoroutinefunction + +from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.apis.send_api import send_message +from src.plugin_system.apis.logging_api import get_logger +from src.common.message import ChatStream + +logger = get_logger(__name__) + + +def require_permission(permission_node: str, deny_message: Optional[str] = None): + """ + 权限检查装饰器 + + 用于装饰需要特定权限才能执行的函数。如果用户没有权限,会发送拒绝消息并阻止函数执行。 + + Args: + permission_node: 所需的权限节点名称 + deny_message: 权限不足时的提示消息,如果为None则使用默认消息 + + Example: + @require_permission("plugin.example.admin") + async def admin_command(message: Message, chat_stream: ChatStream): + # 只有拥有 plugin.example.admin 权限的用户才能执行 + pass + """ + def decorator(func: Callable): + @wraps(func) + async def async_wrapper(*args, **kwargs): + # 尝试从参数中提取 ChatStream 对象 + chat_stream = None + for arg in args: + if isinstance(arg, ChatStream): + chat_stream = arg + break + + # 如果在位置参数中没找到,尝试从关键字参数中查找 + if chat_stream is None: + chat_stream = kwargs.get('chat_stream') + + if chat_stream is None: + logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + return + + # 检查权限 + has_permission = permission_api.check_permission( + chat_stream.user_platform, + chat_stream.user_id, + permission_node + ) + + if not has_permission: + # 权限不足,发送拒绝消息 + message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}" + await send_message(chat_stream, message) + return + + # 权限检查通过,执行原函数 + return await func(*args, **kwargs) + + def sync_wrapper(*args, **kwargs): + # 对于同步函数,我们不能发送异步消息,只能记录日志 + chat_stream = None + for arg in args: + if isinstance(arg, ChatStream): + chat_stream = arg + break + + if chat_stream is None: + chat_stream = kwargs.get('chat_stream') + + if chat_stream is None: + logger.error(f"权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + return + + # 检查权限 + has_permission = permission_api.check_permission( + chat_stream.user_platform, + chat_stream.user_id, + permission_node + ) + + if not has_permission: + logger.warning(f"用户 {chat_stream.user_platform}:{chat_stream.user_id} 没有权限 {permission_node}") + return + + # 权限检查通过,执行原函数 + return func(*args, **kwargs) + + # 根据函数类型选择包装器 + if iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +def require_master(deny_message: Optional[str] = None): + """ + Master权限检查装饰器 + + 用于装饰只有Master用户才能执行的函数。 + + Args: + deny_message: 权限不足时的提示消息,如果为None则使用默认消息 + + Example: + @require_master() + async def master_only_command(message: Message, chat_stream: ChatStream): + # 只有Master用户才能执行 + pass + """ + def decorator(func: Callable): + @wraps(func) + async def async_wrapper(*args, **kwargs): + # 尝试从参数中提取 ChatStream 对象 + chat_stream = None + for arg in args: + if isinstance(arg, ChatStream): + chat_stream = arg + break + + # 如果在位置参数中没找到,尝试从关键字参数中查找 + if chat_stream is None: + chat_stream = kwargs.get('chat_stream') + + if chat_stream is None: + logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + return + + # 检查是否为Master用户 + is_master = permission_api.is_master( + chat_stream.user_platform, + chat_stream.user_id + ) + + if not is_master: + # 权限不足,发送拒绝消息 + message = deny_message or "❌ 此操作仅限Master用户执行" + await send_message(chat_stream, message) + return + + # 权限检查通过,执行原函数 + return await func(*args, **kwargs) + + def sync_wrapper(*args, **kwargs): + # 对于同步函数,我们不能发送异步消息,只能记录日志 + chat_stream = None + for arg in args: + if isinstance(arg, ChatStream): + chat_stream = arg + break + + if chat_stream is None: + chat_stream = kwargs.get('chat_stream') + + if chat_stream is None: + logger.error(f"Master权限装饰器无法找到 ChatStream 对象,函数: {func.__name__}") + return + + # 检查是否为Master用户 + is_master = permission_api.is_master( + chat_stream.user_platform, + chat_stream.user_id + ) + + if not is_master: + logger.warning(f"用户 {chat_stream.user_platform}:{chat_stream.user_id} 不是Master用户") + return + + # 权限检查通过,执行原函数 + return func(*args, **kwargs) + + # 根据函数类型选择包装器 + if iscoroutinefunction(func): + return async_wrapper + else: + return sync_wrapper + + return decorator + + +class PermissionChecker: + """ + 权限检查工具类 + + 提供一些便捷的权限检查方法,用于在代码中进行权限验证。 + """ + + @staticmethod + def check_permission(chat_stream: ChatStream, permission_node: str) -> bool: + """ + 检查用户是否拥有指定权限 + + Args: + chat_stream: 聊天流对象 + permission_node: 权限节点名称 + + Returns: + bool: 是否拥有权限 + """ + return permission_api.check_permission( + chat_stream.user_platform, + chat_stream.user_id, + permission_node + ) + + @staticmethod + def is_master(chat_stream: ChatStream) -> bool: + """ + 检查用户是否为Master用户 + + Args: + chat_stream: 聊天流对象 + + Returns: + bool: 是否为Master用户 + """ + return permission_api.is_master( + chat_stream.user_platform, + chat_stream.user_id + ) + + @staticmethod + async def ensure_permission(chat_stream: ChatStream, permission_node: str, + deny_message: Optional[str] = None) -> bool: + """ + 确保用户拥有指定权限,如果没有权限会发送消息并返回False + + Args: + chat_stream: 聊天流对象 + permission_node: 权限节点名称 + deny_message: 权限不足时的提示消息 + + Returns: + bool: 是否拥有权限 + """ + has_permission = PermissionChecker.check_permission(chat_stream, permission_node) + + if not has_permission: + message = deny_message or f"❌ 你没有执行此操作的权限\n需要权限: {permission_node}" + await send_message(chat_stream, message) + + return has_permission + + @staticmethod + async def ensure_master(chat_stream: ChatStream, + deny_message: Optional[str] = None) -> bool: + """ + 确保用户为Master用户,如果不是会发送消息并返回False + + Args: + chat_stream: 聊天流对象 + deny_message: 权限不足时的提示消息 + + Returns: + bool: 是否为Master用户 + """ + is_master = PermissionChecker.is_master(chat_stream) + + if not is_master: + message = deny_message or "❌ 此操作仅限Master用户执行" + await send_message(chat_stream, message) + + return is_master diff --git a/src/plugins/built_in/maizone_refactored/services/qzone_service.py b/src/plugins/built_in/maizone_refactored/services/qzone_service.py index f1c4ac5d5..181ac9dd7 100644 --- a/src/plugins/built_in/maizone_refactored/services/qzone_service.py +++ b/src/plugins/built_in/maizone_refactored/services/qzone_service.py @@ -340,7 +340,7 @@ class QZoneService: retry_delay *= 2 continue logger.error(f"无法连接到Napcat服务(最终尝试): {url},错误: {str(e)}") - raise RuntimeError(f"无法连接到Napcat服务: {url}") + raise RuntimeError(f"无法连接到Napcat服务: {url}") from e except Exception as e: logger.error(f"获取cookie异常: {str(e)}") raise @@ -718,7 +718,8 @@ class QZoneService: feeds_list = [] for feed in feeds_data: - if not feed: continue + if not feed: + continue # 过滤非说说动态 if str(feed.get('appid', '')) != '311': diff --git a/src/plugins/built_in/permission_management/plugin.py b/src/plugins/built_in/permission_management/plugin.py new file mode 100644 index 000000000..e6b02f5f5 --- /dev/null +++ b/src/plugins/built_in/permission_management/plugin.py @@ -0,0 +1,302 @@ +""" +权限管理插件 + +提供权限系统的管理命令,包括权限授权、撤销、查询等功能。 +""" + +import re +from typing import List, Optional, Tuple, Type + +from src.plugin_system.apis.plugin_register_api import register_plugin +from src.plugin_system.base.base_plugin import BasePlugin +from src.plugin_system.base.base_command import BaseCommand +from src.plugin_system.apis.permission_api import permission_api +from src.plugin_system.apis.logging_api import get_logger +from src.common.message import ChatStream, Message +from src.plugin_system.base.component_types import CommandInfo +from src.plugin_system.base.config_types import ConfigField + + +logger = get_logger("Permission") + + +class PermissionCommand(BaseCommand): + """权限管理命令""" + + command_name = "permission" + command_description = "权限管理命令" + command_pattern = r"^/permission(\s[a-zA-Z0-9_]+)*\s*$)" + command_help = "/permission <子命令> [参数...]" + intercept_message = True + + def __init__(self): + # 注册权限节点 + permission_api.register_permission_node( + "plugin.permission.manage", + "权限管理:可以授权和撤销其他用户的权限", + "permission_manager", + False + ) + permission_api.register_permission_node( + "plugin.permission.view", + "权限查看:可以查看权限节点和用户权限信息", + "permission_manager", + True + ) + + def can_execute(self, message: Message, chat_stream: ChatStream) -> bool: + """检查命令是否可以执行""" + # 基本权限检查由权限系统处理 + return True + + async def execute(self, message: Message, chat_stream: ChatStream, args: List[str]) -> None: + """执行权限管理命令""" + if not args: + await self._show_help(chat_stream) + return + + subcommand = args[0].lower() + remaining_args = args[1:] + + # 检查基本查看权限 + can_view = permission_api.check_permission( + chat_stream.user_platform, + chat_stream.user_id, + "plugin.permission.view" + ) or permission_api.is_master(chat_stream.user_platform, chat_stream.user_id) + + # 检查管理权限 + can_manage = permission_api.check_permission( + chat_stream.user_platform, + chat_stream.user_id, + "plugin.permission.manage" + ) or permission_api.is_master(chat_stream.user_platform, chat_stream.user_id) + + if subcommand in ["grant", "授权", "give"]: + if not can_manage: + await self.send_text("❌ 你没有权限管理的权限") + return + await self._grant_permission(chat_stream, remaining_args) + + elif subcommand in ["revoke", "撤销", "remove"]: + if not can_manage: + await self.send_text("❌ 你没有权限管理的权限") + return + await self._revoke_permission(chat_stream, remaining_args) + + elif subcommand in ["list", "列表", "ls"]: + if not can_view: + await self.send_text("❌ 你没有查看权限的权限") + return + await self._list_permissions(chat_stream, remaining_args) + + elif subcommand in ["check", "检查"]: + if not can_view: + await self.send_text("❌ 你没有查看权限的权限") + return + await self._check_permission(chat_stream, remaining_args) + + elif subcommand in ["nodes", "节点"]: + if not can_view: + await self.send_text("❌ 你没有查看权限的权限") + return + await self._list_nodes(chat_stream, remaining_args) + + elif subcommand in ["help", "帮助"]: + await self._show_help(chat_stream) + + else: + await self.send_text(f"❌ 未知的子命令: {subcommand}\n使用 /permission help 查看帮助") + + async def _show_help(self, chat_stream: ChatStream): + """显示帮助信息""" + help_text = """📋 权限管理命令帮助 + +🔐 管理命令(需要管理权限): +• /permission grant <@用户|QQ号> <权限节点> - 授权用户权限 +• /permission revoke <@用户|QQ号> <权限节点> - 撤销用户权限 + +👀 查看命令(需要查看权限): +• /permission list [用户] - 查看用户权限列表 +• /permission check <@用户|QQ号> <权限节点> - 检查用户是否拥有权限 +• /permission nodes [插件名] - 查看权限节点列表 + +❓ 其他: +• /permission help - 显示此帮助 + +📝 示例: +• /permission grant @张三 plugin.example.command +• /permission list 123456789 +• /permission nodes example_plugin""" + + await self.send_text(help_text) + + def _parse_user_mention(self, mention: str) -> Optional[str]: + """解析用户提及,提取QQ号""" + # 匹配 @用户 格式,提取QQ号 + at_match = re.search(r'\[CQ:at,qq=(\d+)\]', mention) + if at_match: + return at_match.group(1) + + # 直接是数字 + if mention.isdigit(): + return mention + + return None + + async def _grant_permission(self, chat_stream: ChatStream, args: List[str]): + """授权用户权限""" + if len(args) < 2: + await self.send_text("❌ 用法: /permission grant <@用户|QQ号> <权限节点>") + return + + user_mention = args[0] + permission_node = args[1] + + # 解析用户ID + user_id = self._parse_user_mention(user_mention) + if not user_id: + await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号") + return + + # 执行授权 + success = permission_api.grant_permission(chat_stream.user_platform, user_id, permission_node) + + if success: + await self.send_text(f"✅ 已授权用户 {user_id} 权限节点 {permission_node}") + else: + await self.send_text("❌ 授权失败,请检查权限节点是否存在") + + async def _revoke_permission(self, chat_stream: ChatStream, args: List[str]): + """撤销用户权限""" + if len(args) < 2: + await self.send_text("❌ 用法: /permission revoke <@用户|QQ号> <权限节点>") + return + + user_mention = args[0] + permission_node = args[1] + + # 解析用户ID + user_id = self._parse_user_mention(user_mention) + if not user_id: + await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号") + return + + # 执行撤销 + success = permission_api.revoke_permission(chat_stream.user_platform, user_id, permission_node) + + if success: + await self.send_text(f"✅ 已撤销用户 {user_id} 权限节点 {permission_node}") + else: + await self.send_text("❌ 撤销失败,请检查权限节点是否存在") + + async def _list_permissions(self, chat_stream: ChatStream, args: List[str]): + """列出用户权限""" + target_user_id = None + + if args: + # 指定了用户 + user_mention = args[0] + target_user_id = self._parse_user_mention(user_mention) + if not target_user_id: + await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号") + return + else: + # 查看自己的权限 + target_user_id = chat_stream.user_id + + # 检查是否为Master用户 + is_master = permission_api.is_master(chat_stream.user_platform, target_user_id) + + # 获取用户权限 + permissions = permission_api.get_user_permissions(chat_stream.user_platform, target_user_id) + + if is_master: + response = f"👑 用户 {target_user_id} 是Master用户,拥有所有权限" + else: + if permissions: + perm_list = "\n".join([f"• {perm}" for perm in permissions]) + response = f"📋 用户 {target_user_id} 拥有的权限:\n{perm_list}" + else: + response = f"📋 用户 {target_user_id} 没有任何权限" + + await self.send_text(response) + + async def _check_permission(self, chat_stream: ChatStream, args: List[str]): + """检查用户权限""" + if len(args) < 2: + await self.send_text("❌ 用法: /permission check <@用户|QQ号> <权限节点>") + return + + user_mention = args[0] + permission_node = args[1] + + # 解析用户ID + user_id = self._parse_user_mention(user_mention) + if not user_id: + await self.send_text("❌ 无效的用户格式,请使用 @用户 或直接输入QQ号") + return + + # 检查权限 + has_permission = permission_api.check_permission(chat_stream.user_platform, user_id, permission_node) + is_master = permission_api.is_master(chat_stream.user_platform, user_id) + + if has_permission: + if is_master: + response = f"✅ 用户 {user_id} 拥有权限 {permission_node}(Master用户)" + else: + response = f"✅ 用户 {user_id} 拥有权限 {permission_node}" + else: + response = f"❌ 用户 {user_id} 没有权限 {permission_node}" + + await self.send_text(response) + + async def _list_nodes(self, chat_stream: ChatStream, args: List[str]): + """列出权限节点""" + plugin_name = args[0] if args else None + + if plugin_name: + # 获取指定插件的权限节点 + nodes = permission_api.get_plugin_permission_nodes(plugin_name) + title = f"📋 插件 {plugin_name} 的权限节点:" + else: + # 获取所有权限节点 + nodes = permission_api.get_all_permission_nodes() + title = "📋 所有权限节点:" + + if not nodes: + if plugin_name: + response = f"📋 插件 {plugin_name} 没有注册任何权限节点" + else: + response = "📋 系统中没有任何权限节点" + else: + node_list = [] + for node in nodes: + default_text = "(默认授权)" if node["default_granted"] else "(默认拒绝)" + node_list.append(f"• {node['node_name']} {default_text}") + node_list.append(f" 📄 {node['description']}") + if not plugin_name: + node_list.append(f" 🔌 插件: {node['plugin_name']}") + node_list.append("") # 空行分隔 + + response = title + "\n" + "\n".join(node_list) + + await self.send_text(response) + + +@register_plugin +class PermissionManagerPlugin(BasePlugin): + plugin_name: str = "permission_manager_plugin" + enable_plugin: bool = True + dependencies: list[str] = [] + python_dependencies: list[str] = [] + config_file_name: str = "config.toml" + config_schema: dict = { + "plugin": { + "enabled": ConfigField(bool, default=True, description="是否启用插件"), + "config_version": ConfigField(type=str, default="1.1.0", description="配置文件版本") + } + } + + def get_plugin_components(self) -> List[Tuple[CommandInfo, Type[BaseCommand]]]: + return [(PermissionCommand.get_command_info(), PermissionCommand)] \ No newline at end of file diff --git a/src/plugins/built_in/web_search_tool/engines/bing_engine.py b/src/plugins/built_in/web_search_tool/engines/bing_engine.py index cbd30a6f9..ac90956e0 100644 --- a/src/plugins/built_in/web_search_tool/engines/bing_engine.py +++ b/src/plugins/built_in/web_search_tool/engines/bing_engine.py @@ -183,7 +183,7 @@ class BingSearchEngine(BaseSearchEngine): results = root.select("ol#b_results li.b_algo") if results: - for rank, result in enumerate(results, 1): + for _rank, result in enumerate(results, 1): # 提取标题和链接 title_link = result.select_one("h2 a") if not title_link: diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index f5ec06493..a55f6ace6 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "6.4.3" +version = "6.4.4" #----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -40,6 +40,14 @@ mysql_sql_mode = "TRADITIONAL" # SQL模式 connection_pool_size = 10 # 连接池大小(仅MySQL有效) connection_timeout = 10 # 连接超时时间(秒) +[permission] # 权限系统配置 +# Master用户配置(拥有最高权限,无视所有权限节点) +# 格式:[[platform, user_id], ...] +# 示例:[["qq", "123456"], ["telegram", "user789"]] +master_users = [ + # ["qq", "123456789"], # 示例:QQ平台的Master用户 +] + [bot] platform = "qq" qq_account = 1145141919810 # 麦麦的QQ账号 diff --git a/tests/test_wakeup_system.py b/tests/test_wakeup_system.py deleted file mode 100644 index bc340adcf..000000000 --- a/tests/test_wakeup_system.py +++ /dev/null @@ -1,292 +0,0 @@ -import pytest -import time -from unittest.mock import Mock, patch -import sys -import os - -# 添加项目根目录到Python路径 -sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) - -from src.chat.chat_loop.wakeup_manager import WakeUpManager -from src.chat.chat_loop.hfc_context import HfcContext -from src.config.official_configs import WakeUpSystemConfig - - -class TestWakeUpManager: - """唤醒度管理器测试类""" - - @pytest.fixture - def mock_context(self): - """创建模拟的HFC上下文""" - context = Mock(spec=HfcContext) - context.stream_id = "test_chat_123" - context.log_prefix = "[TEST]" - context.running = True - return context - - @pytest.fixture - def wakeup_config(self): - """创建测试用的唤醒度配置""" - return WakeUpSystemConfig( - enable=True, - wakeup_threshold=15.0, - private_message_increment=3.0, - group_mention_increment=2.0, - decay_rate=0.2, - decay_interval=30.0, - angry_duration=300.0 # 5分钟 - ) - - @pytest.fixture - def wakeup_manager(self, mock_context, wakeup_config): - """创建唤醒度管理器实例""" - with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config: - mock_global_config.wakeup_system = wakeup_config - manager = WakeUpManager(mock_context) - return manager - - def test_initialization(self, wakeup_manager, wakeup_config): - """测试初始化""" - assert wakeup_manager.wakeup_value == 0.0 - assert wakeup_manager.is_angry == False - assert wakeup_manager.wakeup_threshold == wakeup_config.wakeup_threshold - assert wakeup_manager.private_message_increment == wakeup_config.private_message_increment - assert wakeup_manager.group_mention_increment == wakeup_config.group_mention_increment - assert wakeup_manager.decay_rate == wakeup_config.decay_rate - assert wakeup_manager.decay_interval == wakeup_config.decay_interval - assert wakeup_manager.angry_duration == wakeup_config.angry_duration - assert wakeup_manager.enabled == wakeup_config.enable - - @patch('src.manager.schedule_manager.schedule_manager') - @patch('src.mood.mood_manager.mood_manager') - def test_private_message_wakeup_accumulation(self, mock_mood_manager, mock_schedule_manager, wakeup_manager): - """测试私聊消息唤醒度累积""" - # 模拟休眠状态 - mock_schedule_manager.is_sleeping.return_value = True - - # 发送5条私聊消息 (5 * 3.0 = 15.0,达到阈值) - for i in range(4): - result = wakeup_manager.add_wakeup_value(is_private_chat=True) - assert result == False # 前4条消息不应该触发唤醒 - assert wakeup_manager.wakeup_value == (i + 1) * 3.0 - - # 第5条消息应该触发唤醒 - result = wakeup_manager.add_wakeup_value(is_private_chat=True) - assert result == True - assert wakeup_manager.is_angry == True - assert wakeup_manager.wakeup_value == 0.0 # 唤醒后重置 - - # 验证情绪管理器被调用 - mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_chat_123") - - @patch('src.manager.schedule_manager.schedule_manager') - @patch('src.mood.mood_manager.mood_manager') - def test_group_mention_wakeup_accumulation(self, mock_mood_manager, mock_schedule_manager, wakeup_manager): - """测试群聊艾特消息唤醒度累积""" - # 模拟休眠状态 - mock_schedule_manager.is_sleeping.return_value = True - - # 发送7条群聊艾特消息 (7 * 2.0 = 14.0,未达到阈值) - for i in range(7): - result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=True) - assert result == False - assert wakeup_manager.wakeup_value == (i + 1) * 2.0 - - # 第8条消息应该触发唤醒 (8 * 2.0 = 16.0,超过阈值15.0) - result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=True) - assert result == True - assert wakeup_manager.is_angry == True - assert wakeup_manager.wakeup_value == 0.0 - - # 验证情绪管理器被调用 - mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_chat_123") - - @patch('src.manager.schedule_manager.schedule_manager') - def test_group_message_without_mention(self, mock_schedule_manager, wakeup_manager): - """测试群聊未艾特消息不增加唤醒度""" - # 模拟休眠状态 - mock_schedule_manager.is_sleeping.return_value = True - - # 发送群聊消息但未被艾特 - result = wakeup_manager.add_wakeup_value(is_private_chat=False, is_mentioned=False) - assert result == False - assert wakeup_manager.wakeup_value == 0.0 # 不应该增加 - - @patch('src.manager.schedule_manager.schedule_manager') - def test_no_accumulation_when_not_sleeping(self, mock_schedule_manager, wakeup_manager): - """测试非休眠状态下不累积唤醒度""" - # 模拟非休眠状态 - mock_schedule_manager.is_sleeping.return_value = False - - # 发送私聊消息 - result = wakeup_manager.add_wakeup_value(is_private_chat=True) - assert result == False - assert wakeup_manager.wakeup_value == 0.0 # 不应该增加 - - def test_disabled_system(self, mock_context): - """测试系统禁用时的行为""" - disabled_config = WakeUpSystemConfig(enable=False) - - with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config: - mock_global_config.wakeup_system = disabled_config - manager = WakeUpManager(mock_context) - - with patch('src.manager.schedule_manager.schedule_manager') as mock_schedule_manager: - mock_schedule_manager.is_sleeping.return_value = True - - # 即使发送消息也不应该累积唤醒度 - result = manager.add_wakeup_value(is_private_chat=True) - assert result == False - assert manager.wakeup_value == 0.0 - - @patch('src.mood.mood_manager.mood_manager') - def test_angry_state_expiration(self, mock_mood_manager, wakeup_manager): - """测试愤怒状态过期""" - # 手动设置愤怒状态 - wakeup_manager.is_angry = True - wakeup_manager.angry_start_time = time.time() - 400 # 400秒前开始愤怒(超过300秒持续时间) - - # 检查愤怒状态应该已过期 - is_angry = wakeup_manager.is_in_angry_state() - assert is_angry == False - assert wakeup_manager.is_angry == False - - # 验证情绪管理器被调用清除愤怒状态 - mock_mood_manager.clear_angry_from_wakeup.assert_called_once_with("test_chat_123") - - def test_angry_prompt_addition(self, wakeup_manager): - """测试愤怒状态提示词""" - # 非愤怒状态 - prompt = wakeup_manager.get_angry_prompt_addition() - assert prompt == "" - - # 愤怒状态 - wakeup_manager.is_angry = True - wakeup_manager.angry_start_time = time.time() - prompt = wakeup_manager.get_angry_prompt_addition() - assert "吵醒" in prompt and "生气" in prompt - - def test_status_info(self, wakeup_manager): - """测试状态信息获取""" - # 设置一些状态 - wakeup_manager.wakeup_value = 10.5 - wakeup_manager.is_angry = True - wakeup_manager.angry_start_time = time.time() - - status = wakeup_manager.get_status_info() - - assert status["wakeup_value"] == 10.5 - assert status["wakeup_threshold"] == 15.0 - assert status["is_angry"] == True - assert status["angry_remaining_time"] > 0 - - @pytest.mark.asyncio - async def test_decay_loop(self, wakeup_manager): - """测试衰减循环""" - # 设置初始唤醒度 - wakeup_manager.wakeup_value = 5.0 - - # 模拟一次衰减 - with patch('asyncio.sleep') as mock_sleep: - # 创建一个会立即停止的衰减循环 - wakeup_manager.context.running = False - - # 手动调用衰减逻辑 - if wakeup_manager.wakeup_value > 0: - old_value = wakeup_manager.wakeup_value - wakeup_manager.wakeup_value = max(0, wakeup_manager.wakeup_value - wakeup_manager.decay_rate) - - assert wakeup_manager.wakeup_value == 4.8 # 5.0 - 0.2 = 4.8 - - @pytest.mark.asyncio - @patch('src.mood.mood_manager.mood_manager') - async def test_angry_state_expiration_in_decay_loop(self, mock_mood_manager, wakeup_manager): - """测试衰减循环中愤怒状态过期""" - # 设置过期的愤怒状态 - wakeup_manager.is_angry = True - wakeup_manager.angry_start_time = time.time() - 400 # 400秒前 - - # 手动调用衰减循环中的愤怒状态检查逻辑 - current_time = time.time() - if wakeup_manager.is_angry and current_time - wakeup_manager.angry_start_time >= wakeup_manager.angry_duration: - wakeup_manager.is_angry = False - mock_mood_manager.clear_angry_from_wakeup(wakeup_manager.context.stream_id) - - assert wakeup_manager.is_angry == False - mock_mood_manager.clear_angry_from_wakeup.assert_called_once_with("test_chat_123") - - @pytest.mark.asyncio - async def test_start_stop_lifecycle(self, wakeup_manager): - """测试启动和停止生命周期""" - # 测试启动 - await wakeup_manager.start() - assert wakeup_manager._decay_task is not None - assert not wakeup_manager._decay_task.done() - - # 测试停止 - await wakeup_manager.stop() - assert wakeup_manager._decay_task.cancelled() - - @pytest.mark.asyncio - async def test_disabled_system_start(self, mock_context): - """测试禁用系统的启动行为""" - disabled_config = WakeUpSystemConfig(enable=False) - - with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config: - mock_global_config.wakeup_system = disabled_config - manager = WakeUpManager(mock_context) - - await manager.start() - assert manager._decay_task is None # 禁用时不应该创建衰减任务 - - -class TestWakeUpSystemIntegration: - """唤醒度系统集成测试""" - - @patch('src.manager.schedule_manager.schedule_manager') - @patch('src.mood.mood_manager.mood_manager') - def test_mixed_message_types(self, mock_mood_manager, mock_schedule_manager): - """测试混合消息类型的唤醒度累积""" - mock_schedule_manager.is_sleeping.return_value = True - - # 创建配置和管理器 - config = WakeUpSystemConfig( - enable=True, - wakeup_threshold=10.0, # 降低阈值便于测试 - private_message_increment=3.0, - group_mention_increment=2.0, - decay_rate=0.2, - decay_interval=30.0, - angry_duration=300.0 - ) - - context = Mock(spec=HfcContext) - context.stream_id = "test_mixed" - context.log_prefix = "[MIXED]" - context.running = True - - with patch('src.chat.chat_loop.wakeup_manager.global_config') as mock_global_config: - mock_global_config.wakeup_system = config - manager = WakeUpManager(context) - - # 发送2条私聊消息 (2 * 3.0 = 6.0) - manager.add_wakeup_value(is_private_chat=True) - manager.add_wakeup_value(is_private_chat=True) - assert manager.wakeup_value == 6.0 - - # 发送2条群聊艾特消息 (2 * 2.0 = 4.0, 总计10.0) - manager.add_wakeup_value(is_private_chat=False, is_mentioned=True) - assert manager.wakeup_value == 8.0 - - # 最后一条消息触发唤醒 - result = manager.add_wakeup_value(is_private_chat=False, is_mentioned=True) - assert result == True - assert manager.is_angry == True - assert manager.wakeup_value == 0.0 - - mock_mood_manager.set_angry_from_wakeup.assert_called_once_with("test_mixed") - - -if __name__ == "__main__": - # 运行测试 - pytest.main([__file__, "-v"]) \ No newline at end of file