feat(plugin): 集成 MCP 协议支持并优化代码风格

- 新增 fastmcp 依赖,支持通过 Streamable HTTP 连接外部工具服务器
- 在 component_registry 与 tool_api 中实现 MCP 工具加载、注册及调用链路
- 补充 README 中的 MCP 特性说明
- 统一修复多处 import 顺序、空行、引号及类型注解,提升代码整洁度
- 在 pyproject.toml 中忽略 PERF203 规则,允许循环内异常处理
- 优化语音缓存与本地 ASR 调用逻辑,减少冗余代码
This commit is contained in:
明天好像没什么
2025-10-26 13:10:31 +08:00
parent 5e6857c8f7
commit 7b80d7c0b3
31 changed files with 1034 additions and 43 deletions

View File

@@ -85,6 +85,7 @@
- 🧠 **拓展记忆系统** - 支持瞬时记忆等多种记忆
- 🎪 **完善的 Event** - 支持动态事件注册和处理器订阅,并实现了聚合结果管理
- 🔍 **内嵌魔改插件** - 内置联网搜索等诸多功能,等你来探索
- 🔌 **MCP 协议支持** - 集成 Model Context Protocol支持外部工具服务器连接仅 Streamable HTTP
- 🌟 **还有更多** - 请参阅详细修改 [commits](https://github.com/MoFox-Studio/MoFox_Bot/commits)
</td>

View File

@@ -19,8 +19,8 @@ from src.plugin_system import (
ToolParamType,
register_plugin,
)
from src.plugin_system.base.component_types import InjectionRule,InjectionType
from src.plugin_system.base.base_event import HandlerResult
from src.plugin_system.base.component_types import InjectionRule, InjectionType
logger = get_logger("hello_world_plugin")

View File

@@ -74,6 +74,7 @@ dependencies = [
"aiosqlite>=0.21.0",
"inkfox>=0.1.1",
"rjieba>=0.1.13",
"fastmcp>=2.13.0",
]
[[tool.uv.index]]
@@ -110,6 +111,7 @@ ignore = [
"RUF001", # ambiguous-unicode-character-string
"RUF002", # ambiguous-unicode-character-docstring
"RUF003", # ambiguous-unicode-character-comment
"PERF203", # try-except-in-loop (我们需要单独处理每个项的错误)
]

View File

@@ -9,6 +9,7 @@ customtkinter
dotenv
faiss-cpu
fastapi
fastmcp
rjieba
jsonlines
maim_message

View File

@@ -0,0 +1,59 @@
"""
调试 MCP 工具列表获取
直接测试 MCP 客户端是否能获取工具
"""
import asyncio
from fastmcp.client import Client, StreamableHttpTransport
async def test_direct_connection():
"""直接连接 MCP 服务器并获取工具列表"""
print("=" * 60)
print("直接测试 MCP 服务器连接")
print("=" * 60)
url = "http://localhost:8000/mcp"
print(f"\n连接到: {url}")
try:
# 创建传输层
transport = StreamableHttpTransport(url)
print("✓ 传输层创建成功")
# 创建客户端
async with Client(transport) as client:
print("✓ 客户端连接成功")
# 获取工具列表
print("\n正在获取工具列表...")
tools_result = await client.list_tools()
print(f"\n获取结果类型: {type(tools_result)}")
print(f"结果内容: {tools_result}")
# 检查是否有 tools 属性
if hasattr(tools_result, "tools"):
tools = tools_result.tools
print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具")
for i, tool in enumerate(tools, 1):
print(f"\n工具 {i}:")
print(f" 名称: {tool.name}")
print(f" 描述: {tool.description}")
if hasattr(tool, "inputSchema"):
print(f" 参数 Schema: {tool.inputSchema}")
else:
print("\n✗ 结果中没有 tools 属性")
print(f"可用属性: {dir(tools_result)}")
except Exception as e:
print(f"\n✗ 连接失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(test_direct_connection())

View File

@@ -0,0 +1,142 @@
"""
简单的 MCP 测试服务器
使用 fastmcp 创建一个简单的 MCP 服务器供测试使用。
安装依赖:
pip install fastmcp uvicorn
运行服务器:
python scripts/simple_mcp_server.py
服务器将在 http://localhost:8000/mcp 提供 MCP 服务
"""
from fastmcp import FastMCP
# 创建 MCP 服务器实例
mcp = FastMCP("Demo Server")
@mcp.tool()
def add(a: int, b: int) -> int:
"""将两个数字相加
Args:
a: 第一个数字
b: 第二个数字
Returns:
两个数字的和
"""
return a + b
@mcp.tool()
def multiply(a: float, b: float) -> float:
"""将两个数字相乘
Args:
a: 第一个数字
b: 第二个数字
Returns:
两个数字的乘积
"""
return a * b
@mcp.tool()
def get_weather(city: str) -> str:
"""获取指定城市的天气信息(模拟)
Args:
city: 城市名称
Returns:
天气信息字符串
"""
# 这是一个模拟实现
weather_data = {
"beijing": "北京今天晴朗,温度 20°C",
"shanghai": "上海今天多云,温度 18°C",
"guangzhou": "广州今天有雨,温度 25°C",
}
city_lower = city.lower()
return weather_data.get(
city_lower,
f"{city}的天气信息暂不可用"
)
@mcp.tool()
def echo(message: str, repeat: int = 1) -> str:
"""重复输出一条消息
Args:
message: 要重复的消息
repeat: 重复次数,默认为 1
Returns:
重复后的消息
"""
return (message + "\n") * repeat
@mcp.tool()
def check_prime(number: int) -> bool:
"""检查一个数字是否为质数
Args:
number: 要检查的数字
Returns:
如果是质数返回 True否则返回 False
"""
if number < 2:
return False
for i in range(2, int(number ** 0.5) + 1):
if number % i == 0:
return False
return True
if __name__ == "__main__":
print("=" * 60)
print("简单 MCP 测试服务器")
print("=" * 60)
print("\n服务器配置:")
print(" - 传输协议: Streamable HTTP")
print(" - 地址: http://localhost:8000/mcp")
print("\n可用工具:")
print(" 1. add(a: int, b: int) -> int")
print(" 2. multiply(a: float, b: float) -> float")
print(" 3. get_weather(city: str) -> str")
print(" 4. echo(message: str, repeat: int = 1) -> str")
print(" 5. check_prime(number: int) -> bool")
print("\n配置示例 (config/mcp.json):")
print("""
{
"$schema": "./mcp.schema.json",
"mcpServers": {
"demo_server": {
"enabled": true,
"transport": {
"type": "streamable-http",
"url": "http://localhost:8000/mcp"
},
"timeout": 30
}
}
}
""")
print("=" * 60)
print("\n正在启动服务器...")
print("请参考 fastmcp 官方文档了解如何运行此服务器。")
print("文档: https://github.com/jlowin/fastmcp")
print("\n基本命令:")
print(" fastmcp run simple_mcp_server:mcp")
print("=" * 60)

View File

@@ -0,0 +1,15 @@
from fastmcp import FastMCP
app = FastMCP(
name="Demo MCP Server",
streamable_http_path="/mcp"
)
@app.tool()
async def echo_tool(input: str) -> str:
"""一个简单的回声工具"""
return f"Echo: {input}"
if __name__ == "__main__":
app.run(host="0.0.0.0", port=8000, transport="streamable-http"
)

View File

@@ -0,0 +1,190 @@
"""
MCP 集成测试脚本
测试 MCP 客户端连接、工具列表获取和工具调用功能
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import ComponentRegistry
from src.plugin_system.core.mcp_client_manager import MCPClientManager
logger = get_logger("test_mcp_integration")
async def test_mcp_client_manager():
"""测试 MCPClientManager 基本功能"""
print("\n" + "="*60)
print("测试 1: MCPClientManager 连接和工具列表")
print("="*60)
try:
# 初始化 MCP 客户端管理器
manager = MCPClientManager()
await manager.initialize()
print("\n✓ MCP 客户端管理器初始化成功")
print(f"已连接服务器数量: {len(manager.clients)}")
# 获取所有工具
tools = await manager.get_all_tools()
print(f"\n获取到 {len(tools)} 个 MCP 工具:")
for tool in tools:
print(f"\n 工具: {tool}")
# 注意: 这里 tool 是字符串形式的工具名称
# 如果需要工具详情,需要从其他地方获取
return manager, tools
except Exception as e:
print(f"\n✗ 测试失败: {e}")
logger.exception("MCPClientManager 测试失败")
return None, []
async def test_tool_call(manager: MCPClientManager, tools):
"""测试工具调用"""
print("\n" + "="*60)
print("测试 2: MCP 工具调用")
print("="*60)
if not tools:
print("\n⚠ 没有可用的工具进行测试")
return
try:
# 工具列表测试已在第一个测试中完成
print("\n✓ 工具列表获取成功")
print(f"可用工具数量: {len(tools)}")
except Exception as e:
print(f"\n✗ 工具调用测试失败: {e}")
logger.exception("工具调用测试失败")
async def test_component_registry_integration():
"""测试 ComponentRegistry 集成"""
print("\n" + "="*60)
print("测试 3: ComponentRegistry MCP 工具集成")
print("="*60)
try:
registry = ComponentRegistry()
# 加载 MCP 工具
await registry.load_mcp_tools()
# 获取 MCP 工具
mcp_tools = registry.get_mcp_tools()
print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具")
for tool in mcp_tools:
print(f"\n 工具: {tool.name}")
print(f" 描述: {tool.description}")
print(f" 参数数量: {len(tool.parameters)}")
# 测试 is_mcp_tool 方法
is_mcp = registry.is_mcp_tool(tool.name)
print(f" is_mcp_tool 检测: {'' if is_mcp else ''}")
return mcp_tools
except Exception as e:
print(f"\n✗ ComponentRegistry 集成测试失败: {e}")
logger.exception("ComponentRegistry 集成测试失败")
return []
async def test_tool_execution(mcp_tools):
"""测试通过适配器执行工具"""
print("\n" + "="*60)
print("测试 4: MCPToolAdapter 工具执行")
print("="*60)
if not mcp_tools:
print("\n⚠ 没有可用的 MCP 工具进行测试")
return
try:
# 选择第一个工具测试
test_tool = mcp_tools[0]
print(f"\n测试工具: {test_tool.name}")
# 构建测试参数
test_args = {}
for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters:
if is_required:
# 根据类型提供默认值
from src.llm_models.payload_content.tool_option import ToolParamType
if param_type == ToolParamType.STRING:
test_args[param_name] = "test_value"
elif param_type == ToolParamType.INTEGER:
test_args[param_name] = 1
elif param_type == ToolParamType.FLOAT:
test_args[param_name] = 1.0
elif param_type == ToolParamType.BOOLEAN:
test_args[param_name] = True
print(f"测试参数: {test_args}")
# 执行工具
result = await test_tool.execute(test_args)
if result:
print("\n✓ 工具执行成功")
print(f"结果类型: {result.get('type')}")
print(f"结果内容: {result.get('content', '')[:200]}...") # 只显示前200字符
else:
print("\n✗ 工具执行失败,返回 None")
except Exception as e:
print(f"\n✗ 工具执行测试失败: {e}")
logger.exception("工具执行测试失败")
async def main():
"""主测试流程"""
print("\n" + "="*60)
print("MCP 集成测试")
print("="*60)
try:
# 测试 1: MCPClientManager 基本功能
manager, tools = await test_mcp_client_manager()
if manager:
# 测试 2: 工具调用
await test_tool_call(manager, tools)
# 测试 3: ComponentRegistry 集成
mcp_tools = await test_component_registry_integration()
# 测试 4: 工具执行
await test_tool_execution(mcp_tools)
# 关闭连接
await manager.close()
print("\n✓ MCP 客户端连接已关闭")
print("\n" + "="*60)
print("测试完成")
print("="*60 + "\n")
except KeyboardInterrupt:
print("\n\n测试被用户中断")
except Exception as e:
print(f"\n测试过程中发生错误: {e}")
logger.exception("测试失败")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -219,4 +219,4 @@ async def get_llm_stats(
raise e
except Exception as e:
logger.error(f"获取LLM统计信息失败: {e}")
raise HTTPException(status_code=500, detail=str(e))
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -878,7 +878,7 @@ class MemorySystem:
except Exception as e:
logger.warning(f"检索瞬时记忆失败: {e}", exc_info=True)
# 最终截断
final_memories = final_memories[:effective_limit]

View File

@@ -72,4 +72,4 @@ class MessageCollectionProcessor:
"active_buffers": len(self.message_buffers),
"total_buffered_messages": total_buffered_messages,
"buffer_capacity_per_chat": self.buffer_size,
}
}

View File

@@ -3,7 +3,6 @@
专用于存储和检索消息集合,以提供即时上下文。
"""
import asyncio
import time
from typing import Any
@@ -125,7 +124,7 @@ class MessageCollectionStorage:
if results and results.get("ids") and results["ids"][0]:
for metadata in results["metadatas"][0]:
collections.append(MessageCollection.from_dict(metadata))
return collections
except Exception as e:
logger.error(f"检索相关消息集合失败: {e}", exc_info=True)
@@ -163,7 +162,7 @@ class MessageCollectionStorage:
# 格式化消息集合为 prompt 上下文
final_context = "\n\n---\n\n".join(context_parts) + "\n\n---"
logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文")
return f"\n{final_context}\n"
@@ -192,4 +191,4 @@ class MessageCollectionStorage:
}
except Exception as e:
logger.error(f"获取消息集合存储统计失败: {e}")
return {}
return {}

View File

@@ -9,10 +9,10 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.chat.utils.self_voice_cache import consume_self_voice_text
from src.common.logger import get_logger
from src.config.config import global_config
@@ -212,7 +212,7 @@ class MessageRecv(Message):
return f"[语音:{cached_text}]"
else:
logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。")
# 标准语音识别流程 (也作为缓存未命中的后备方案)
if isinstance(segment.data, str):
return await get_voice_text(segment.data)
@@ -364,7 +364,7 @@ class MessageRecvS4U(MessageRecv):
self.is_picid = False
self.is_emoji = False
self.is_voice = True
# 检查消息是否由机器人自己发送
# 检查消息是否由机器人自己发送
if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account):

View File

@@ -303,7 +303,7 @@ class Prompt:
@staticmethod
def _process_escaped_braces(template) -> str:
"""预处理模板,将 `\{` 和 `\}` 替换为临时标记."""
r"""预处理模板,将 `\{` 和 `\}` 替换为临时标记."""
if isinstance(template, list):
template = "\n".join(str(item) for item in template)
elif not isinstance(template, str):

View File

@@ -1,6 +1,5 @@
import asyncio
import re
from typing import Type
from src.chat.utils.prompt_params import PromptParameters
from src.common.logger import get_logger
@@ -21,7 +20,7 @@ class PromptComponentManager:
3. 提供一个接口以便在构建核心Prompt时能够获取并执行所有相关的组件。
"""
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, Type[BasePrompt]]]:
def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]:
"""
获取指定目标Prompt的所有注入规则及其关联的组件类。

View File

@@ -6,15 +6,14 @@
避免不必要的自我语音识别。
"""
import hashlib
from typing import Dict
# 一个简单的内存缓存,用于将机器人自己发送的语音消息映射到其原始文本。
# 键是语音base64内容的SHA256哈希值。
_self_voice_cache: Dict[str, str] = {}
_self_voice_cache: dict[str, str] = {}
def get_voice_key(base64_content: str) -> str:
"""为语音内容生成一个一致的键。"""
return hashlib.sha256(base64_content.encode('utf-8')).hexdigest()
return hashlib.sha256(base64_content.encode("utf-8")).hexdigest()
def register_self_voice(base64_content: str, text: str):
"""
@@ -39,4 +38,4 @@ def consume_self_voice_text(base64_content: str) -> str | None:
str | None: 如果找到则返回原始文本否则返回None。
"""
key = get_voice_key(base64_content)
return _self_voice_cache.pop(key, None)
return _self_voice_cache.pop(key, None)

View File

@@ -19,10 +19,11 @@ async def get_voice_text(voice_base64: str) -> str:
# 如果选择本地识别
if asr_provider == "local":
from src.plugin_system.apis import tool_api
import tempfile
import base64
import os
import tempfile
from src.plugin_system.apis import tool_api
local_asr_tool = tool_api.get_tool_instance("local_asr")
if not local_asr_tool:
@@ -39,8 +40,8 @@ async def get_voice_text(voice_base64: str) -> str:
text = await local_asr_tool.execute(function_args={"audio_path": audio_path})
if "失败" in text or "出错" in text or "错误" in text:
logger.warning(f"本地语音识别失败: {text}")
return f"[语音(本地识别失败)]"
return "[语音(本地识别失败)]"
logger.info(f"本地语音识别成功: {text}")
return f"[语音] {text}"

View File

@@ -4,9 +4,10 @@ import signal
import sys
import time
import traceback
from collections.abc import Callable, Coroutine
from functools import partial
from random import choices
from typing import Any, Callable, Coroutine
from typing import Any
from maim_message import MessageServer
from rich.traceback import install
@@ -24,8 +25,8 @@ from src.config.config import global_config
from src.individuality.individuality import Individuality, get_individuality
from src.manager.async_task_manager import async_task_manager
from src.mood.mood_manager import mood_manager
from src.plugin_system.base.component_types import EventType
from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator
from src.plugin_system.base.component_types import EventType
from src.plugin_system.core.event_manager import event_manager
from src.plugin_system.core.plugin_manager import plugin_manager
from src.schedule.monthly_plan_manager import monthly_plan_manager
@@ -418,7 +419,7 @@ MoFox_Bot(第三方修改版)
# 处理所有缓存的事件订阅(插件加载完成后)
event_manager.process_all_pending_subscriptions()
# 初始化表情管理器
get_emoji_manager().initialize()
logger.info("表情包管理器初始化成功")

View File

View File

@@ -0,0 +1 @@


View File

View File

View File

@@ -1,4 +1,5 @@
from typing import Any
from src.common.logger import get_logger
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ComponentType
@@ -22,7 +23,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None:
def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
"""获取LLM可用的工具定义列表
"""获取LLM可用的工具定义列表(包括 MCP 工具)
Returns:
list[dict[str, Any]]: 工具定义列表
@@ -31,6 +32,8 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
llm_available_tools = component_registry.get_llm_available_tools()
tool_definitions = []
# 获取常规工具定义
for tool_name, tool_class in llm_available_tools.items():
try:
# 调用类方法 get_tool_definition 获取定义
@@ -38,5 +41,18 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]:
tool_definitions.append(definition)
except Exception as e:
logger.error(f"获取工具 {tool_name} 的定义失败: {e}")
# 获取 MCP 工具定义
try:
mcp_tools = component_registry.get_mcp_tools()
for mcp_tool in mcp_tools:
try:
definition = mcp_tool.get_tool_definition()
tool_definitions.append(definition)
except Exception as e:
logger.error(f"获取 MCP 工具 {mcp_tool.name} 的定义失败: {e}")
except Exception as e:
logger.debug(f"获取 MCP 工具列表失败(可能未启用): {e}")
return tool_definitions

View File

@@ -87,6 +87,10 @@ class ComponentRegistry:
self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类
self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类
# MCP 工具注册表(运行时动态加载)
self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表
self._mcp_tools_loaded = False # MCP 工具是否已加载
# EventHandler特定注册表
self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {}
"""event_handler名 -> event_handler类"""
@@ -891,6 +895,7 @@ class ComponentRegistry:
"action_components": action_components,
"command_components": command_components,
"tool_components": tool_components,
"mcp_tools": len(self._mcp_tools),
"event_handlers": events_handlers,
"plus_command_components": plus_command_components,
"chatter_components": chatter_components,
@@ -904,6 +909,34 @@ class ComponentRegistry:
"enabled_plugins": len([p for p in self._plugins.values() if p.enabled]),
}
# === MCP 工具相关方法 ===
async def load_mcp_tools(self) -> None:
"""加载 MCP 工具(异步方法)"""
if self._mcp_tools_loaded:
logger.debug("MCP 工具已加载,跳过")
return
try:
from .mcp_tool_adapter import load_mcp_tools_as_adapters
logger.info("开始加载 MCP 工具...")
self._mcp_tools = await load_mcp_tools_as_adapters()
self._mcp_tools_loaded = True
logger.info(f"MCP 工具加载完成,共 {len(self._mcp_tools)} 个工具")
except Exception as e:
logger.error(f"加载 MCP 工具失败: {e}")
self._mcp_tools = []
self._mcp_tools_loaded = True # 标记为已尝试加载,避免重复尝试
def get_mcp_tools(self) -> list["BaseTool"]:
"""获取所有 MCP 工具适配器实例"""
return self._mcp_tools.copy()
def is_mcp_tool(self, tool_name: str) -> bool:
"""检查工具名是否为 MCP 工具"""
return tool_name.startswith("mcp_")
# === 组件移除相关 ===
async def unregister_plugin(self, plugin_name: str) -> bool:

View File

@@ -0,0 +1,266 @@
"""
MCP Client Manager
管理多个 MCP (Model Context Protocol) 客户端连接,支持动态加载和工具注册
"""
import asyncio
import json
from pathlib import Path
from typing import Any
import mcp.types
from fastmcp.client import Client, StreamableHttpTransport
from src.common.logger import get_logger
logger = get_logger("mcp_client_manager")
class MCPServerConfig:
"""单个 MCP 服务器的配置"""
def __init__(self, name: str, config: dict[str, Any]):
self.name = name
self.description = config.get("description", "")
self.enabled = config.get("enabled", True)
self.transport_config = config["transport"]
self.auth_config = config.get("auth")
self.timeout = config.get("timeout", 30)
self.retry_config = config.get("retry", {"max_retries": 3, "retry_delay": 1})
def __repr__(self):
return f"<MCPServerConfig {self.name} (enabled={self.enabled})>"
class MCPClientManager:
"""
MCP 客户端管理器
负责:
1. 从配置文件加载 MCP 服务器配置
2. 建立和维护与 MCP 服务器的连接
3. 获取可用的工具列表
4. 执行工具调用
"""
def __init__(self, config_path: str | Path | None = None):
"""
初始化 MCP 客户端管理器
Args:
config_path: mcp.json 配置文件路径,默认为 config/mcp.json
"""
if config_path is None:
# 默认配置路径
config_path = Path(__file__).parent.parent.parent.parent / "config" / "mcp.json"
self.config_path = Path(config_path)
self.servers: dict[str, MCPServerConfig] = {}
self.clients: dict[str, Client] = {}
self._initialized = False
self._lock = asyncio.Lock()
logger.info(f"MCP 客户端管理器初始化,配置文件: {self.config_path}")
def load_config(self) -> dict[str, MCPServerConfig]:
"""
从配置文件加载 MCP 服务器配置
Returns:
Dict[str, MCPServerConfig]: 服务器名称 -> 配置对象
"""
if not self.config_path.exists():
logger.warning(f"MCP 配置文件不存在: {self.config_path}")
return {}
try:
with open(self.config_path, encoding="utf-8") as f:
config_data = json.load(f)
servers = {}
mcp_servers = config_data.get("mcpServers", {})
for server_name, server_config in mcp_servers.items():
try:
server = MCPServerConfig(server_name, server_config)
servers[server_name] = server
logger.debug(f"加载 MCP 服务器配置: {server}")
except Exception as e:
logger.error(f"加载服务器配置 '{server_name}' 失败: {e}")
continue
logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置")
return servers
except json.JSONDecodeError as e:
logger.error(f"解析 MCP 配置文件失败: {e}")
return {}
except Exception as e:
logger.error(f"读取 MCP 配置文件失败: {e}")
return {}
async def initialize(self) -> None:
"""
初始化所有启用的 MCP 客户端连接
这个方法会:
1. 加载配置文件
2. 为每个启用的服务器创建客户端
3. 建立连接并验证
"""
async with self._lock:
if self._initialized:
logger.debug("MCP 客户端管理器已初始化,跳过")
return
logger.info("开始初始化 MCP 客户端连接...")
# 加载配置
self.servers = self.load_config()
if not self.servers:
logger.warning("没有找到任何 MCP 服务器配置")
self._initialized = True
return
# 为每个启用的服务器创建客户端
for server_name, server_config in self.servers.items():
if not server_config.enabled:
logger.debug(f"服务器 '{server_name}' 未启用,跳过")
continue
try:
client = await self._create_client(server_config)
self.clients[server_name] = client
logger.info(f"✅ MCP 服务器 '{server_name}' 连接成功")
except Exception as e:
logger.error(f"❌ 连接 MCP 服务器 '{server_name}' 失败: {e}")
continue
self._initialized = True
logger.info(f"MCP 客户端管理器初始化完成,成功连接 {len(self.clients)}/{len(self.servers)} 个服务器")
async def _create_client(self, server_config: MCPServerConfig) -> Client:
"""
根据配置创建 MCP 客户端
Args:
server_config: 服务器配置
Returns:
Client: 已连接的 MCP 客户端
"""
transport_type = server_config.transport_config.get("type", "streamable-http")
if transport_type == "streamable-http":
url = server_config.transport_config["url"]
transport = StreamableHttpTransport(url)
# 设置认证(如果有)
if server_config.auth_config:
auth_type = server_config.auth_config.get("type")
if auth_type == "bearer":
from fastmcp.client.auth import BearerAuth
token = server_config.auth_config.get("token", "")
transport._set_auth(BearerAuth(token))
client = Client(transport, timeout=server_config.timeout)
elif transport_type == "sse":
from fastmcp.client import SSETransport
url = server_config.transport_config["url"]
client = Client(SSETransport(url), timeout=server_config.timeout)
else:
raise ValueError(f"不支持的传输类型: {transport_type}")
# 进入客户端上下文(建立连接)
await client.__aenter__()
return client
async def get_all_tools(self) -> dict[str, list[mcp.types.Tool]]:
"""
获取所有 MCP 服务器提供的工具列表
Returns:
Dict[str, List[mcp.types.Tool]]: 服务器名称 -> 工具列表
"""
if not self._initialized:
await self.initialize()
all_tools = {}
for server_name, client in self.clients.items():
try:
# fastmcp 的 list_tools() 直接返回 List[Tool],不是包含 tools 属性的对象
tools = await client.list_tools()
all_tools[server_name] = tools
logger.debug(f"从服务器 '{server_name}' 获取到 {len(tools)} 个工具")
except Exception as e:
logger.error(f"从服务器 '{server_name}' 获取工具列表失败: {e}")
all_tools[server_name] = []
return all_tools
async def call_tool(
self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None
) -> Any:
"""
调用指定 MCP 服务器的工具
Args:
server_name: 服务器名称
tool_name: 工具名称
arguments: 工具参数
Returns:
Any: 工具执行结果CallToolResult 的兼容类型)
"""
if not self._initialized:
await self.initialize()
if server_name not in self.clients:
raise ValueError(f"MCP 服务器 '{server_name}' 未连接")
client = self.clients[server_name]
try:
logger.debug(f"调用 MCP 工具: {server_name}.{tool_name} | 参数: {arguments}")
result = await client.call_tool(tool_name, arguments or {})
logger.debug(f"MCP 工具调用成功: {server_name}.{tool_name}")
return result
except Exception as e:
logger.error(f"MCP 工具调用失败: {server_name}.{tool_name} | 错误: {e}")
raise
async def close(self) -> None:
"""关闭所有 MCP 客户端连接"""
async with self._lock:
if not self._initialized:
return
logger.info("关闭所有 MCP 客户端连接...")
for server_name, client in self.clients.items():
try:
await client.__aexit__(None, None, None)
logger.debug(f"已关闭 MCP 服务器 '{server_name}' 的连接")
except Exception as e:
logger.error(f"关闭服务器 '{server_name}' 连接失败: {e}")
self.clients.clear()
self._initialized = False
logger.info("所有 MCP 客户端连接已关闭")
def __repr__(self):
return f"<MCPClientManager servers={len(self.servers)} clients={len(self.clients)}>"
# 全局单例
mcp_client_manager = MCPClientManager()

View File

@@ -0,0 +1,249 @@
"""
MCP Tool Adapter
将 MCP 工具适配为 BaseTool使其能够被插件系统识别和调用
"""
from typing import Any, ClassVar
import mcp.types
from src.common.logger import get_logger
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.base.component_types import ToolParamType
from .mcp_client_manager import mcp_client_manager
logger = get_logger("mcp_tool_adapter")
class MCPToolAdapter(BaseTool):
"""
MCP 工具适配器
将 MCP 协议的工具适配为 BaseTool使其能够
1. 被插件系统识别和注册
2. 被 LLM 调用
3. 参与工具缓存机制
"""
# 类级别默认值,使用 ClassVar 标注
available_for_llm: ClassVar[bool] = True
def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None):
"""
初始化 MCP 工具适配器
Args:
server_name: MCP 服务器名称
mcp_tool: MCP 工具对象
plugin_config: 插件配置(可选)
"""
super().__init__(plugin_config)
self.server_name = server_name
self.mcp_tool = mcp_tool
# 设置实例属性
self.name = f"mcp_{server_name}_{mcp_tool.name}"
self.description = mcp_tool.description or f"MCP tool from {server_name}"
# 转换参数定义
self.parameters = self._convert_parameters(mcp_tool.inputSchema)
logger.debug(f"创建 MCP 工具适配器: {self.name}")
def _convert_parameters(
self, input_schema: dict[str, Any] | None
) -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]:
"""
将 MCP 工具的 JSON Schema 参数转换为 BaseTool 参数格式
Args:
input_schema: MCP 工具的 inputSchema (JSON Schema)
Returns:
List[Tuple]: BaseTool 参数格式列表
"""
if not input_schema:
return []
parameters = []
# JSON Schema 通常有 properties 和 required 字段
properties = input_schema.get("properties", {})
required_fields = input_schema.get("required", [])
for param_name, param_def in properties.items():
# 获取参数类型
param_type_str = param_def.get("type", "string")
param_type = self._map_json_type_to_tool_param_type(param_type_str)
# 获取参数描述
param_desc = param_def.get("description", f"Parameter {param_name}")
# 判断是否必填
is_required = param_name in required_fields
# 获取枚举值(如果有)
enum_values = param_def.get("enum")
parameters.append((param_name, param_type, param_desc, is_required, enum_values))
return parameters
@staticmethod
def _map_json_type_to_tool_param_type(json_type: str) -> ToolParamType:
"""
将 JSON Schema 类型映射到 ToolParamType
Args:
json_type: JSON Schema 类型字符串
Returns:
ToolParamType: 对应的工具参数类型
"""
type_mapping = {
"string": ToolParamType.STRING,
"integer": ToolParamType.INTEGER,
"number": ToolParamType.FLOAT,
"boolean": ToolParamType.BOOLEAN,
}
return type_mapping.get(json_type, ToolParamType.STRING)
async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]:
"""
执行 MCP 工具调用
Args:
function_args: 工具调用参数
Returns:
Dict: 工具执行结果
"""
try:
logger.debug(f"执行 MCP 工具: {self.name} | 服务器: {self.server_name} | 参数: {function_args}")
# 移除 llm_called 标记(这是内部使用的)
clean_args = {k: v for k, v in function_args.items() if k != "llm_called"}
# 调用 MCP 客户端管理器执行工具
result = await mcp_client_manager.call_tool(
server_name=self.server_name, tool_name=self.mcp_tool.name, arguments=clean_args
)
# 解析结果
return self._format_result(result)
except Exception as e:
logger.error(f"MCP 工具执行失败: {self.name} | 错误: {e}")
return {
"type": "error",
"content": f"MCP 工具调用失败: {e!s}",
"id": self.name,
}
def _format_result(self, result: mcp.types.CallToolResult) -> dict[str, Any]:
"""
格式化 MCP 工具执行结果为标准格式
Args:
result: MCP CallToolResult 对象
Returns:
Dict: 标准化的工具执行结果
"""
# MCP 结果包含 content 列表
if not result.content:
return {
"type": "mcp_result",
"content": "",
"id": self.name,
}
# 提取所有内容
content_parts = []
for content_item in result.content:
# 根据内容类型提取文本
content_type = getattr(content_item, "type", None)
if content_type == "text":
# TextContent 类型
text = getattr(content_item, "text", "")
content_parts.append(text)
elif content_type == "image":
# ImageContent 类型
data = getattr(content_item, "data", b"")
content_parts.append(f"[Image data: {len(data)} bytes]")
elif content_type == "audio":
# AudioContent 类型
data = getattr(content_item, "data", b"")
content_parts.append(f"[Audio data: {len(data)} bytes]")
else:
# 尝试提取 text 或 data 属性
text = getattr(content_item, "text", None)
if text is not None:
content_parts.append(text)
else:
data = getattr(content_item, "data", None)
if data is not None:
data_len = len(data) if hasattr(data, "__len__") else "unknown"
content_parts.append(f"[Binary data: {data_len} bytes]")
else:
content_parts.append(str(content_item))
return {
"type": "mcp_result",
"content": "\n".join(content_parts),
"id": self.name,
"is_error": getattr(result, "isError", False),
}
@classmethod
def from_mcp_tool(cls, server_name: str, mcp_tool: mcp.types.Tool) -> "MCPToolAdapter":
"""
从 MCP 工具对象创建适配器实例
Args:
server_name: MCP 服务器名称
mcp_tool: MCP 工具对象
Returns:
MCPToolAdapter: 工具适配器实例
"""
return cls(server_name, mcp_tool)
async def load_mcp_tools_as_adapters() -> list[MCPToolAdapter]:
"""
加载所有 MCP 工具并转换为适配器
Returns:
List[MCPToolAdapter]: 工具适配器列表
"""
logger.info("开始加载 MCP 工具...")
# 初始化 MCP 客户端管理器
await mcp_client_manager.initialize()
# 获取所有工具
all_tools_dict = await mcp_client_manager.get_all_tools()
adapters = []
total_tools = 0
for server_name, tools in all_tools_dict.items():
logger.debug(f"处理服务器 '{server_name}'{len(tools)} 个工具")
total_tools += len(tools)
for mcp_tool in tools:
try:
adapter = MCPToolAdapter.from_mcp_tool(server_name, mcp_tool)
adapters.append(adapter)
logger.debug(f" ✓ 加载工具: {adapter.name}")
except Exception as e:
logger.error(f" ✗ 创建工具适配器失败: {mcp_tool.name} | 错误: {e}")
continue
logger.info(f"MCP 工具加载完成: 成功 {len(adapters)}/{total_tools}")
return adapters

View File

@@ -145,8 +145,7 @@ class ToolExecutor:
pending_step_two = getattr(self, "_pending_step_two_tools", {})
if pending_step_two:
# 添加第二步工具定义
for step_two_def in pending_step_two.values():
tool_definitions.append(step_two_def)
tool_definitions.extend(list(pending_step_two.values()))
return tool_definitions
@@ -286,10 +285,33 @@ class ToolExecutor:
logger.info(
f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}"
)
# 检查是否是MCP工具
pass
from src.plugin_system.core import component_registry
if component_registry.is_mcp_tool(function_name):
logger.debug(f"{self.log_prefix}识别到 MCP 工具: {function_name}")
# 找到对应的 MCP 工具实例
mcp_tools = component_registry.get_mcp_tools()
mcp_tool = next((t for t in mcp_tools if t.name == function_name), None)
if mcp_tool:
logger.debug(f"{self.log_prefix}执行 MCP 工具 {function_name}")
result = await mcp_tool.execute(function_args)
if result:
logger.debug(f"{self.log_prefix}MCP 工具 {function_name} 执行成功")
return {
"tool_call_id": tool_call.call_id,
"role": "tool",
"name": function_name,
"type": "function",
"content": result.get("content", ""),
}
else:
logger.warning(f"{self.log_prefix}未找到 MCP 工具: {function_name}")
return None
function_args["llm_called"] = True # 标记为LLM调用
# 检查是否是二步工具的第二步调用

View File

@@ -6,4 +6,4 @@ __plugin_meta__ = PluginMetadata(
usage="在 bot_config.toml 中将 asr_provider 设置为 'local' 即可启用",
version="0.1.0",
author="Elysia",
)
)

View File

@@ -1,9 +1,4 @@
import asyncio
import os
import tempfile
from typing import Any
from pathlib import Path
import toml
import whisper
@@ -40,7 +35,7 @@ class LocalASRTool(BaseTool):
model_size = plugin_config.get("whisper", {}).get("model_size", "tiny")
device = plugin_config.get("whisper", {}).get("device", "cpu")
logger.info(f"正在预加载 Whisper ASR 模型: {model_size} ({device})")
loop = asyncio.get_running_loop()
_whisper_model = await loop.run_in_executor(
None, whisper.load_model, model_size, device
@@ -61,10 +56,10 @@ class LocalASRTool(BaseTool):
# 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成
while _is_loading:
await asyncio.sleep(0.2)
if _whisper_model is None:
return "Whisper 模型加载失败,无法识别语音。"
try:
logger.info(f"开始使用 Whisper 识别音频: {audio_path}")
loop = asyncio.get_running_loop()
@@ -110,6 +105,6 @@ class STTWhisperPlugin(BasePlugin):
), LocalASRTool)]
except Exception as e:
logger.error(f"检查 ASR provider 配置时出错: {e}")
logger.debug("ASR provider is not 'local', whisper plugin's tool is disabled.")
return []

View File

@@ -6,9 +6,9 @@ from pathlib import Path
import toml
from src.chat.utils.self_voice_cache import register_self_voice
from src.common.logger import get_logger
from src.plugin_system.base.base_action import BaseAction, ChatMode
from src.chat.utils.self_voice_cache import register_self_voice
from ..services.manager import get_service

View File

@@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata(
"is_built_in": True,
},
# Python包依赖列表
python_dependencies = [ # noqa: RUF012
python_dependencies = [
PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False),
PythonDependency(
package_name="exa_py",