This commit is contained in:
Windpicker-owo
2025-11-16 21:18:30 +08:00
21 changed files with 296 additions and 53 deletions

View File

@@ -14,6 +14,7 @@ from fastapi import APIRouter, HTTPException, Query, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# 调整项目根目录的计算方式
project_root = Path(__file__).parent.parent.parent
data_dir = project_root / "data" / "memory_graph"

View File

@@ -1,16 +1,17 @@
import time
from typing import Literal
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.common.security import get_api_key
from src.config.config import global_config
from src.plugin_system.apis import message_api, person_api
logger = get_logger("HTTP消息API")
router = APIRouter()
router = APIRouter(dependencies=[Depends(get_api_key)])
@router.get("/messages/recent")
@@ -161,5 +162,3 @@ async def get_message_stats_by_chat(
# 统一异常处理
logger.error(f"获取消息统计时发生错误: {e}")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -1,16 +1,17 @@
from datetime import datetime, timedelta
from typing import Literal
from fastapi import APIRouter, HTTPException, Query
from fastapi import APIRouter, Depends, HTTPException, Query
from src.chat.utils.statistic import (
StatisticOutputTask,
)
from src.common.logger import get_logger
from src.common.security import get_api_key
logger = get_logger("LLM统计API")
router = APIRouter()
router = APIRouter(dependencies=[Depends(get_api_key)])
# 定义统计数据的键,以减少魔法字符串
TOTAL_REQ_CNT = "total_requests"

37
src/common/security.py Normal file
View File

@@ -0,0 +1,37 @@
from fastapi import Depends, HTTPException, Security
from fastapi.security.api_key import APIKeyHeader
from starlette.status import HTTP_401_UNAUTHORIZED, HTTP_403_FORBIDDEN
from src.common.logger import get_logger
from src.config.config import global_config as bot_config
logger = get_logger("security")
API_KEY_HEADER = "X-API-Key"
api_key_header_auth = APIKeyHeader(name=API_KEY_HEADER, auto_error=True)
async def get_api_key(api_key: str = Security(api_key_header_auth)) -> str:
"""
FastAPI 依赖项用于验证API密钥。
从请求头中提取 X-API-Key 并验证它是否存在于配置的有效密钥列表中。
"""
valid_keys = bot_config.plugin_http_system.plugin_api_valid_keys
if not valid_keys:
logger.warning("API密钥认证已启用但未配置任何有效的API密钥。所有请求都将被拒绝。")
raise HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="服务未正确配置API密钥",
)
if api_key not in valid_keys:
logger.warning(f"无效的API密钥: {api_key}")
raise HTTPException(
status_code=HTTP_403_FORBIDDEN,
detail="无效的API密钥",
)
return api_key
# 创建一个可重用的依赖项,供插件开发者在其需要验证的端点上使用
# 用法: @router.get("/protected_route", dependencies=[VerifiedDep])
# 或者: async def my_endpoint(_=VerifiedDep): ...
VerifiedDep = Depends(get_api_key)

View File

@@ -1,32 +1,60 @@
import os
import socket
from fastapi import APIRouter, FastAPI
from fastapi import APIRouter, FastAPI, Request, Response
from fastapi.middleware.cors import CORSMiddleware
from rich.traceback import install
from uvicorn import Config
from uvicorn import Server as UvicornServer
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from slowapi.util import get_remote_address
from src.common.logger import get_logger
from src.config.config import global_config as bot_config
install(extra_lines=3)
logger = get_logger("Server")
def rate_limit_exceeded_handler(request: Request, exc: Exception) -> Response:
"""自定义速率限制超出处理器以解决类型提示问题"""
# 由于此处理器专门用于 RateLimitExceeded我们可以安全地断言异常类型。
# 这满足了类型检查器的要求,并确保了运行时安全。
assert isinstance(exc, RateLimitExceeded)
return _rate_limit_exceeded_handler(request, exc)
class Server:
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MaiMCore"):
def __init__(self, host: str | None = None, port: int | None = None, app_name: str = "MoFox-Bot"):
# 根据配置初始化速率限制器
limiter = Limiter(
key_func=get_remote_address,
default_limits=[bot_config.plugin_http_system.plugin_api_rate_limit_default],
)
self.app = FastAPI(title=app_name)
self.host: str = "127.0.0.1"
self.port: int = 8080
self._server: UvicornServer | None = None
self.set_address(host, port)
# 设置速率限制
self.app.state.limiter = limiter
self.app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
# 根据配置决定是否添加中间件
if bot_config.plugin_http_system.plugin_api_rate_limit_enable:
logger.info(f"已为插件API启用全局速率限制: {bot_config.plugin_http_system.plugin_api_rate_limit_default}")
self.app.add_middleware(SlowAPIMiddleware)
# 配置 CORS
origins = [
"http://localhost:3000", # 允许的前端源
"http://127.0.0.1:3000",
"http://127.0.0.1:3000",
# 在生产环境中,您应该添加实际的前端域名
]

View File

@@ -76,8 +76,6 @@ class ModelInfo(ValidatedConfigBase):
default="light", description="扰动强度light/medium/heavy"
)
enable_semantic_variants: bool = Field(default=False, description="是否启用语义变体作为扰动策略")
prepend_noise_instruction: bool = Field(default=False, description="是否在提示词前部添加抗审查指令")
@classmethod
def validate_prices(cls, v):
"""验证价格必须为非负数"""

View File

@@ -34,6 +34,7 @@ from src.config.official_configs import (
PermissionConfig,
PersonalityConfig,
PlanningSystemConfig,
PluginHttpSystemConfig,
ProactiveThinkingConfig,
ReactionConfig,
ResponsePostProcessConfig,
@@ -414,6 +415,9 @@ class Config(ValidatedConfigBase):
proactive_thinking: ProactiveThinkingConfig = Field(
default_factory=lambda: ProactiveThinkingConfig(), description="主动思考配置"
)
plugin_http_system: PluginHttpSystemConfig = Field(
default_factory=lambda: PluginHttpSystemConfig(), description="插件HTTP端点系统配置"
)
class APIAdapterConfig(ValidatedConfigBase):

View File

@@ -736,6 +736,23 @@ class CommandConfig(ValidatedConfigBase):
command_prefixes: list[str] = Field(default_factory=lambda: ["/", "!", ".", "#"], description="支持的命令前缀列表")
class PluginHttpSystemConfig(ValidatedConfigBase):
"""插件http系统相关配置"""
enable_plugin_http_endpoints: bool = Field(
default=True, description="总开关是否允许插件创建HTTP端点"
)
plugin_api_rate_limit_enable: bool = Field(
default=True, description="是否为插件API启用全局速率限制"
)
plugin_api_rate_limit_default: str = Field(
default="100/minute", description="插件API的默认速率限制策略"
)
plugin_api_valid_keys: list[str] = Field(
default_factory=list, description="有效的API密钥列表用于插件认证"
)
class MasterPromptConfig(ValidatedConfigBase):
"""主人身份提示词配置"""

View File

@@ -500,8 +500,8 @@ class _PromptProcessor:
final_prompt_parts = []
user_prompt = prompt
# 步骤 A: (可选) 添加抗审查指令
if getattr(model_info, "prepend_noise_instruction", False):
# 步骤 A: 添加抗审查指令
if model_info.enable_prompt_perturbation:
final_prompt_parts.append(self.noise_instruction)
# 步骤 B: (可选) 应用统一的提示词扰动
@@ -516,7 +516,7 @@ class _PromptProcessor:
final_prompt_parts.append(user_prompt)
# 步骤 C: (可选) 添加反截断指令
if getattr(model_info, "use_anti_truncation", False):
if model_info.anti_truncation:
final_prompt_parts.append(self.anti_truncation_instruction)
logger.info(f"模型 '{model_info.name}' (任务: '{task_name}') 已启用反截断功能。")
@@ -882,7 +882,7 @@ class _RequestStrategy:
# --- 响应内容处理和空回复/截断检查 ---
content = response.content or ""
use_anti_truncation = getattr(model_info, "use_anti_truncation", False)
use_anti_truncation = model_info.anti_truncation
processed_content, reasoning, is_truncated = await self.prompt_processor.process_response(
content, use_anti_truncation
)

View File

@@ -44,6 +44,7 @@ from .base import (
PluginInfo,
# 新增的增强命令系统
PlusCommand,
BaseRouterComponent,
PythonDependency,
ToolInfo,
ToolParamType,
@@ -56,7 +57,7 @@ from .utils.dependency_manager import configure_dependency_manager, get_dependen
__version__ = "2.0.0"
__all__ = [
__all__ = [ # noqa: RUF022
"ActionActivationType",
"ActionInfo",
"BaseAction",
@@ -82,6 +83,7 @@ __all__ = [
"PluginInfo",
# 增强命令系统
"PlusCommand",
"BaseRouterComponent"
"PythonDependency",
"ToolInfo",
"ToolParamType",
@@ -114,4 +116,4 @@ __all__ = [
# "ManifestGenerator",
# "validate_plugin_manifest",
# "generate_plugin_manifest",
]
] # type: ignore

View File

@@ -7,6 +7,7 @@
from .base_action import BaseAction
from .base_command import BaseCommand
from .base_events_handler import BaseEventHandler
from .base_http_component import BaseRouterComponent
from .base_plugin import BasePlugin
from .base_prompt import BasePrompt
from .base_tool import BaseTool
@@ -55,7 +56,7 @@ __all__ = [
"PluginMetadata",
# 增强命令系统
"PlusCommand",
"PlusCommandAdapter",
"BaseRouterComponent"
"PlusCommandInfo",
"PythonDependency",
"ToolInfo",

View File

@@ -0,0 +1,40 @@
from abc import ABC, abstractmethod
from fastapi import APIRouter
from .component_types import ComponentType, RouterInfo
class BaseRouterComponent(ABC):
"""
用于暴露HTTP端点的组件基类。
插件开发者应继承此类,并实现 register_endpoints 方法来定义API路由。
"""
# 组件元数据,由插件管理器读取
component_name: str
component_description: str
component_version: str = "1.0.0"
# 每个组件实例都会管理自己的APIRouter
router: APIRouter
def __init__(self):
self.router = APIRouter()
self.register_endpoints()
@abstractmethod
def register_endpoints(self) -> None:
"""
【开发者必须实现】
在此方法中定义所有HTTP端点。
"""
pass
@classmethod
def get_router_info(cls) -> "RouterInfo":
"""从类属性生成RouterInfo"""
return RouterInfo(
name=cls.component_name,
description=getattr(cls, "component_description", "路由组件"),
component_type=ComponentType.ROUTER,
)

View File

@@ -53,6 +53,7 @@ class ComponentType(Enum):
CHATTER = "chatter" # 聊天处理器组件
INTEREST_CALCULATOR = "interest_calculator" # 兴趣度计算组件
PROMPT = "prompt" # Prompt组件
ROUTER = "router" # 路由组件
def __str__(self) -> str:
return self.value
@@ -146,6 +147,7 @@ class PermissionNodeField:
node_name: str # 节点名称 (例如 "manage" 或 "view")
description: str # 权限描述
@dataclass
class ComponentInfo:
"""组件信息"""
@@ -442,3 +444,11 @@ class MaiMessages:
def __post_init__(self):
if self.message_segments is None:
self.message_segments = []
@dataclass
class RouterInfo(ComponentInfo):
"""路由组件信息"""
def __post_init__(self):
super().__post_init__()
self.component_type = ComponentType.ROUTER

View File

@@ -5,11 +5,15 @@ from pathlib import Path
from re import Pattern
from typing import Any, cast
from fastapi import Depends
from src.common.logger import get_logger
from src.config.config import global_config as bot_config
from src.plugin_system.base.base_action import BaseAction
from src.plugin_system.base.base_chatter import BaseChatter
from src.plugin_system.base.base_command import BaseCommand
from src.plugin_system.base.base_events_handler import BaseEventHandler
from src.plugin_system.base.base_http_component import BaseRouterComponent
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.base_prompt import BasePrompt
from src.plugin_system.base.base_tool import BaseTool
@@ -24,6 +28,7 @@ from src.plugin_system.base.component_types import (
PluginInfo,
PlusCommandInfo,
PromptInfo,
RouterInfo,
ToolInfo,
)
from src.plugin_system.base.plus_command import PlusCommand, create_legacy_command_adapter
@@ -40,6 +45,7 @@ ComponentClassType = (
| type[BaseChatter]
| type[BaseInterestCalculator]
| type[BasePrompt]
| type[BaseRouterComponent]
)
@@ -194,6 +200,10 @@ class ComponentRegistry:
assert isinstance(component_info, PromptInfo)
assert issubclass(component_class, BasePrompt)
ret = self._register_prompt_component(component_info, component_class)
case ComponentType.ROUTER:
assert isinstance(component_info, RouterInfo)
assert issubclass(component_class, BaseRouterComponent)
ret = self._register_router_component(component_info, component_class)
case _:
logger.warning(f"未知组件类型: {component_type}")
ret = False
@@ -373,6 +383,43 @@ class ComponentRegistry:
logger.debug(f"已注册Prompt组件: {prompt_name}")
return True
def _register_router_component(self, router_info: RouterInfo, router_class: type[BaseRouterComponent]) -> bool:
"""注册Router组件并将其端点挂载到主服务器"""
# 1. 检查总开关是否开启
if not bot_config.plugin_http_system.enable_plugin_http_endpoints:
logger.info("插件HTTP端点功能已禁用跳过路由注册")
return True
try:
from src.common.server import get_global_server
router_name = router_info.name
plugin_name = router_info.plugin_name
# 2. 实例化组件以触发其 __init__ 和 register_endpoints
component_instance = router_class()
# 3. 获取配置好的 APIRouter
plugin_router = component_instance.router
# 4. 获取全局服务器实例
server = get_global_server()
# 5. 生成唯一的URL前缀
prefix = f"/plugins/{plugin_name}"
# 6. 注册路由并使用插件名作为API文档的分组标签
# 移除了dependencies参数因为现在由每个端点自行决定是否需要验证
server.app.include_router(
plugin_router, prefix=prefix, tags=[plugin_name]
)
logger.debug(f"成功将插件 '{plugin_name}' 的路由组件 '{router_name}' 挂载到: {prefix}")
return True
except Exception as e:
logger.error(f"注册路由组件 '{router_info.name}' 时出错: {e}", exc_info=True)
return False
# === 组件移除相关 ===
async def remove_component(self, component_name: str, component_type: ComponentType, plugin_name: str) -> bool:
@@ -616,6 +663,7 @@ class ComponentRegistry:
| BaseChatter
| BaseInterestCalculator
| BasePrompt
| BaseRouterComponent
]
| None
):
@@ -643,6 +691,8 @@ class ComponentRegistry:
| type[PlusCommand]
| type[BaseChatter]
| type[BaseInterestCalculator]
| type[BasePrompt]
| type[BaseRouterComponent]
| None,
self._components_classes.get(namespaced_name),
)
@@ -825,6 +875,7 @@ class ComponentRegistry:
def get_plugin_components(self, plugin_name: str) -> list["ComponentInfo"]:
"""获取插件的所有组件"""
plugin_info = self.get_plugin_info(plugin_name)
logger.info(plugin_info.components)
return plugin_info.components if plugin_info else []
def get_plugin_config(self, plugin_name: str) -> dict:
@@ -867,6 +918,7 @@ class ComponentRegistry:
plus_command_components: int = 0
chatter_components: int = 0
prompt_components: int = 0
router_components: int = 0
for component in self._components.values():
if component.component_type == ComponentType.ACTION:
action_components += 1
@@ -882,6 +934,8 @@ class ComponentRegistry:
chatter_components += 1
elif component.component_type == ComponentType.PROMPT:
prompt_components += 1
elif component.component_type == ComponentType.ROUTER:
router_components += 1
return {
"action_components": action_components,
"command_components": command_components,
@@ -891,6 +945,7 @@ class ComponentRegistry:
"plus_command_components": plus_command_components,
"chatter_components": chatter_components,
"prompt_components": prompt_components,
"router_components": router_components,
"total_components": len(self._components),
"total_plugins": len(self._plugins),
"components_by_type": {

View File

@@ -405,13 +405,14 @@ class PluginManager:
plus_command_count = stats.get("plus_command_components", 0)
chatter_count = stats.get("chatter_components", 0)
prompt_count = stats.get("prompt_components", 0)
router_count = stats.get("router_components", 0)
total_components = stats.get("total_components", 0)
# 📋 显示插件加载总览
if total_registered > 0:
logger.info("🎉 插件系统加载完成!")
logger.info(
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count})"
f"📊 总览: {total_registered}个插件, {total_components}个组件 (Action: {action_count}, Command: {command_count}, Tool: {tool_count}, PlusCommand: {plus_command_count}, EventHandler: {event_handler_count}, Chatter: {chatter_count}, Prompt: {prompt_count}, Router: {router_count})"
)
# 显示详细的插件列表
@@ -452,6 +453,9 @@ class PluginManager:
prompt_components = [
c for c in plugin_info.components if c.component_type == ComponentType.PROMPT
]
router_components = [
c for c in plugin_info.components if c.component_type == ComponentType.ROUTER
]
if action_components:
action_details = [format_component(c) for c in action_components]
@@ -478,6 +482,9 @@ class PluginManager:
if prompt_components:
prompt_details = [format_component(c) for c in prompt_components]
logger.info(f" 📝 Prompt组件: {', '.join(prompt_details)}")
if router_components:
router_details = [format_component(c) for c in router_components]
logger.info(f" 🌐 Router组件: {', '.join(router_details)}")
# 权限节点信息
if plugin_instance := self.loaded_plugins.get(plugin_name):

View File

@@ -80,37 +80,39 @@ class ReplyTrackerService:
if old_data_file.exists():
logger.info(f"检测到旧的数据文件 '{old_data_file}',开始执行一次性迁移...")
try:
# 读取旧文件内容
# 步骤1: 读取旧文件内容并立即关闭文件
with open(old_data_file, "rb") as f:
file_content = f.read()
# 如果文件为空,直接删除,无需迁移
if not file_content.strip():
logger.warning("旧数据文件为空,无需迁移。")
os.remove(old_data_file)
logger.info(f"空的旧数据文件 '{old_data_file}' 已被删除。")
return
# 解析JSON数据
old_data = orjson.loads(file_content)
# 步骤2: 处理文件内容
# 如果文件为空,直接删除,无需迁移
if not file_content.strip():
logger.warning("旧数据文件为空,无需迁移。")
os.remove(old_data_file)
logger.info(f"空的旧数据文件 '{old_data_file}' 已被删除。")
return
# 验证数据格式是否正确
if self._validate_data(old_data):
# 验证通过将数据写入新的存储API
self.storage.set("data", old_data)
# 立即强制保存,确保迁移数据落盘
self.storage._save_data()
logger.info("旧数据已成功迁移到新的存储API。")
# 解析JSON数据
old_data = orjson.loads(file_content)
# 将旧文件重命名为备份文件,而不是直接删除,以防万一
backup_file = old_data_file.with_suffix(f".json.bak.migrated.{int(time.time())}")
old_data_file.rename(backup_file)
logger.info(f"旧数据文件已成功迁移并备份为: {backup_file}")
else:
# 如果数据格式无效,迁移中止,并备份损坏的文件
logger.error("旧数据文件格式无效,迁移中止")
backup_file = old_data_file.with_suffix(f".json.bak.invalid.{int(time.time())}")
old_data_file.rename(backup_file)
logger.warning(f"已将无效的旧数据文件备份为: {backup_file}")
# 步骤3: 验证数据并执行迁移/备份
if self._validate_data(old_data):
# 验证通过将数据写入新的存储API
self.storage.set("data", old_data)
# 立即强制保存,确保迁移数据落盘
self.storage._save_data()
logger.info("旧数据已成功迁移到新的存储API")
# 将旧文件重命名为备份文件
backup_file = old_data_file.with_suffix(f".json.bak.migrated.{int(time.time())}")
old_data_file.rename(backup_file)
logger.info(f"旧数据文件已成功迁移并备份为: {backup_file}")
else:
# 如果数据格式无效,迁移中止,并备份损坏的文件
logger.error("旧数据文件格式无效,迁移中止。")
backup_file = old_data_file.with_suffix(f".json.bak.invalid.{int(time.time())}")
old_data_file.rename(backup_file)
logger.warning(f"已将无效的旧数据文件备份为: {backup_file}")
except Exception as e:
# 捕获迁移过程中可能出现的任何异常