diff --git a/README.md b/README.md index 46bb7174c..97a06ab53 100644 --- a/README.md +++ b/README.md @@ -85,6 +85,7 @@ - 🧠 **拓展记忆系统** - 支持瞬时记忆等多种记忆 - 🎪 **完善的 Event** - 支持动态事件注册和处理器订阅,并实现了聚合结果管理 - 🔍 **内嵌魔改插件** - 内置联网搜索等诸多功能,等你来探索 +- 🔌 **MCP 协议支持** - 集成 Model Context Protocol,支持外部工具服务器连接(仅 Streamable HTTP) - 🌟 **还有更多** - 请参阅详细修改 [commits](https://github.com/MoFox-Studio/MoFox_Bot/commits) diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 01e819463..ea44da9b5 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -19,8 +19,8 @@ from src.plugin_system import ( ToolParamType, register_plugin, ) -from src.plugin_system.base.component_types import InjectionRule,InjectionType from src.plugin_system.base.base_event import HandlerResult +from src.plugin_system.base.component_types import InjectionRule, InjectionType logger = get_logger("hello_world_plugin") diff --git a/pyproject.toml b/pyproject.toml index f9346708f..7b0a544a9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ dependencies = [ "aiosqlite>=0.21.0", "inkfox>=0.1.1", "rjieba>=0.1.13", + "fastmcp>=2.13.0", ] [[tool.uv.index]] @@ -110,6 +111,7 @@ ignore = [ "RUF001", # ambiguous-unicode-character-string "RUF002", # ambiguous-unicode-character-docstring "RUF003", # ambiguous-unicode-character-comment + "PERF203", # try-except-in-loop (我们需要单独处理每个项的错误) ] diff --git a/requirements.txt b/requirements.txt index 6c7edfc80..a2dcb7b81 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,6 +9,7 @@ customtkinter dotenv faiss-cpu fastapi +fastmcp rjieba jsonlines maim_message diff --git a/scripts/debug_mcp_tools.py b/scripts/debug_mcp_tools.py new file mode 100644 index 000000000..27ba5b2e5 --- /dev/null +++ b/scripts/debug_mcp_tools.py @@ -0,0 +1,59 @@ +""" +调试 MCP 工具列表获取 + +直接测试 MCP 客户端是否能获取工具 +""" + +import asyncio + +from fastmcp.client import Client, StreamableHttpTransport + + +async def test_direct_connection(): + """直接连接 MCP 服务器并获取工具列表""" + print("=" * 60) + print("直接测试 MCP 服务器连接") + print("=" * 60) + + url = "http://localhost:8000/mcp" + print(f"\n连接到: {url}") + + try: + # 创建传输层 + transport = StreamableHttpTransport(url) + print("✓ 传输层创建成功") + + # 创建客户端 + async with Client(transport) as client: + print("✓ 客户端连接成功") + + # 获取工具列表 + print("\n正在获取工具列表...") + tools_result = await client.list_tools() + + print(f"\n获取结果类型: {type(tools_result)}") + print(f"结果内容: {tools_result}") + + # 检查是否有 tools 属性 + if hasattr(tools_result, "tools"): + tools = tools_result.tools + print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具") + + for i, tool in enumerate(tools, 1): + print(f"\n工具 {i}:") + print(f" 名称: {tool.name}") + print(f" 描述: {tool.description}") + if hasattr(tool, "inputSchema"): + print(f" 参数 Schema: {tool.inputSchema}") + else: + print("\n✗ 结果中没有 tools 属性") + print(f"可用属性: {dir(tools_result)}") + + except Exception as e: + print(f"\n✗ 连接失败: {e}") + import traceback + traceback.print_exc() + + +if __name__ == "__main__": + asyncio.run(test_direct_connection()) diff --git a/scripts/simple_mcp_server.py b/scripts/simple_mcp_server.py new file mode 100644 index 000000000..78e6391bf --- /dev/null +++ b/scripts/simple_mcp_server.py @@ -0,0 +1,142 @@ +""" +简单的 MCP 测试服务器 + +使用 fastmcp 创建一个简单的 MCP 服务器供测试使用。 + +安装依赖: + pip install fastmcp uvicorn + +运行服务器: + python scripts/simple_mcp_server.py + +服务器将在 http://localhost:8000/mcp 提供 MCP 服务 +""" + +from fastmcp import FastMCP + +# 创建 MCP 服务器实例 +mcp = FastMCP("Demo Server") + + +@mcp.tool() +def add(a: int, b: int) -> int: + """将两个数字相加 + + Args: + a: 第一个数字 + b: 第二个数字 + + Returns: + 两个数字的和 + """ + return a + b + + +@mcp.tool() +def multiply(a: float, b: float) -> float: + """将两个数字相乘 + + Args: + a: 第一个数字 + b: 第二个数字 + + Returns: + 两个数字的乘积 + """ + return a * b + + +@mcp.tool() +def get_weather(city: str) -> str: + """获取指定城市的天气信息(模拟) + + Args: + city: 城市名称 + + Returns: + 天气信息字符串 + """ + # 这是一个模拟实现 + weather_data = { + "beijing": "北京今天晴朗,温度 20°C", + "shanghai": "上海今天多云,温度 18°C", + "guangzhou": "广州今天有雨,温度 25°C", + } + + city_lower = city.lower() + return weather_data.get( + city_lower, + f"{city}的天气信息暂不可用" + ) + + +@mcp.tool() +def echo(message: str, repeat: int = 1) -> str: + """重复输出一条消息 + + Args: + message: 要重复的消息 + repeat: 重复次数,默认为 1 + + Returns: + 重复后的消息 + """ + return (message + "\n") * repeat + + +@mcp.tool() +def check_prime(number: int) -> bool: + """检查一个数字是否为质数 + + Args: + number: 要检查的数字 + + Returns: + 如果是质数返回 True,否则返回 False + """ + if number < 2: + return False + + for i in range(2, int(number ** 0.5) + 1): + if number % i == 0: + return False + + return True + + +if __name__ == "__main__": + print("=" * 60) + print("简单 MCP 测试服务器") + print("=" * 60) + print("\n服务器配置:") + print(" - 传输协议: Streamable HTTP") + print(" - 地址: http://localhost:8000/mcp") + print("\n可用工具:") + print(" 1. add(a: int, b: int) -> int") + print(" 2. multiply(a: float, b: float) -> float") + print(" 3. get_weather(city: str) -> str") + print(" 4. echo(message: str, repeat: int = 1) -> str") + print(" 5. check_prime(number: int) -> bool") + print("\n配置示例 (config/mcp.json):") + print(""" +{ + "$schema": "./mcp.schema.json", + "mcpServers": { + "demo_server": { + "enabled": true, + "transport": { + "type": "streamable-http", + "url": "http://localhost:8000/mcp" + }, + "timeout": 30 + } + } +} + """) + print("=" * 60) + print("\n正在启动服务器...") + print("请参考 fastmcp 官方文档了解如何运行此服务器。") + print("文档: https://github.com/jlowin/fastmcp") + print("\n基本命令:") + print(" fastmcp run simple_mcp_server:mcp") + print("=" * 60) diff --git a/scripts/test/demo_mcp_server.py b/scripts/test/demo_mcp_server.py new file mode 100644 index 000000000..723d43ca3 --- /dev/null +++ b/scripts/test/demo_mcp_server.py @@ -0,0 +1,15 @@ +from fastmcp import FastMCP + +app = FastMCP( + name="Demo MCP Server", + streamable_http_path="/mcp" +) + +@app.tool() +async def echo_tool(input: str) -> str: + """一个简单的回声工具""" + return f"Echo: {input}" + +if __name__ == "__main__": + app.run(host="0.0.0.0", port=8000, transport="streamable-http" + ) diff --git a/scripts/test_mcp_integration.py b/scripts/test_mcp_integration.py new file mode 100644 index 000000000..b5cfa2b28 --- /dev/null +++ b/scripts/test_mcp_integration.py @@ -0,0 +1,190 @@ +""" +MCP 集成测试脚本 + +测试 MCP 客户端连接、工具列表获取和工具调用功能 +""" + +import asyncio +import sys +from pathlib import Path + +# 添加项目根目录到 Python 路径 +project_root = Path(__file__).parent.parent +sys.path.insert(0, str(project_root)) + +from src.common.logger import get_logger +from src.plugin_system.core.component_registry import ComponentRegistry +from src.plugin_system.core.mcp_client_manager import MCPClientManager + +logger = get_logger("test_mcp_integration") + + +async def test_mcp_client_manager(): + """测试 MCPClientManager 基本功能""" + print("\n" + "="*60) + print("测试 1: MCPClientManager 连接和工具列表") + print("="*60) + + try: + # 初始化 MCP 客户端管理器 + manager = MCPClientManager() + await manager.initialize() + + print("\n✓ MCP 客户端管理器初始化成功") + print(f"已连接服务器数量: {len(manager.clients)}") + + # 获取所有工具 + tools = await manager.get_all_tools() + print(f"\n获取到 {len(tools)} 个 MCP 工具:") + + for tool in tools: + print(f"\n 工具: {tool}") + # 注意: 这里 tool 是字符串形式的工具名称 + # 如果需要工具详情,需要从其他地方获取 + + return manager, tools + + except Exception as e: + print(f"\n✗ 测试失败: {e}") + logger.exception("MCPClientManager 测试失败") + return None, [] + + +async def test_tool_call(manager: MCPClientManager, tools): + """测试工具调用""" + print("\n" + "="*60) + print("测试 2: MCP 工具调用") + print("="*60) + + if not tools: + print("\n⚠ 没有可用的工具进行测试") + return + + try: + # 工具列表测试已在第一个测试中完成 + print("\n✓ 工具列表获取成功") + print(f"可用工具数量: {len(tools)}") + + except Exception as e: + print(f"\n✗ 工具调用测试失败: {e}") + logger.exception("工具调用测试失败") + + +async def test_component_registry_integration(): + """测试 ComponentRegistry 集成""" + print("\n" + "="*60) + print("测试 3: ComponentRegistry MCP 工具集成") + print("="*60) + + try: + registry = ComponentRegistry() + + # 加载 MCP 工具 + await registry.load_mcp_tools() + + # 获取 MCP 工具 + mcp_tools = registry.get_mcp_tools() + print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具") + + for tool in mcp_tools: + print(f"\n 工具: {tool.name}") + print(f" 描述: {tool.description}") + print(f" 参数数量: {len(tool.parameters)}") + + # 测试 is_mcp_tool 方法 + is_mcp = registry.is_mcp_tool(tool.name) + print(f" is_mcp_tool 检测: {'✓' if is_mcp else '✗'}") + + return mcp_tools + + except Exception as e: + print(f"\n✗ ComponentRegistry 集成测试失败: {e}") + logger.exception("ComponentRegistry 集成测试失败") + return [] + + +async def test_tool_execution(mcp_tools): + """测试通过适配器执行工具""" + print("\n" + "="*60) + print("测试 4: MCPToolAdapter 工具执行") + print("="*60) + + if not mcp_tools: + print("\n⚠ 没有可用的 MCP 工具进行测试") + return + + try: + # 选择第一个工具测试 + test_tool = mcp_tools[0] + print(f"\n测试工具: {test_tool.name}") + + # 构建测试参数 + test_args = {} + for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters: + if is_required: + # 根据类型提供默认值 + from src.llm_models.payload_content.tool_option import ToolParamType + + if param_type == ToolParamType.STRING: + test_args[param_name] = "test_value" + elif param_type == ToolParamType.INTEGER: + test_args[param_name] = 1 + elif param_type == ToolParamType.FLOAT: + test_args[param_name] = 1.0 + elif param_type == ToolParamType.BOOLEAN: + test_args[param_name] = True + + print(f"测试参数: {test_args}") + + # 执行工具 + result = await test_tool.execute(test_args) + + if result: + print("\n✓ 工具执行成功") + print(f"结果类型: {result.get('type')}") + print(f"结果内容: {result.get('content', '')[:200]}...") # 只显示前200字符 + else: + print("\n✗ 工具执行失败,返回 None") + + except Exception as e: + print(f"\n✗ 工具执行测试失败: {e}") + logger.exception("工具执行测试失败") + + +async def main(): + """主测试流程""" + print("\n" + "="*60) + print("MCP 集成测试") + print("="*60) + + try: + # 测试 1: MCPClientManager 基本功能 + manager, tools = await test_mcp_client_manager() + + if manager: + # 测试 2: 工具调用 + await test_tool_call(manager, tools) + + # 测试 3: ComponentRegistry 集成 + mcp_tools = await test_component_registry_integration() + + # 测试 4: 工具执行 + await test_tool_execution(mcp_tools) + + # 关闭连接 + await manager.close() + print("\n✓ MCP 客户端连接已关闭") + + print("\n" + "="*60) + print("测试完成") + print("="*60 + "\n") + + except KeyboardInterrupt: + print("\n\n测试被用户中断") + except Exception as e: + print(f"\n测试过程中发生错误: {e}") + logger.exception("测试失败") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/api/statistic_router.py b/src/api/statistic_router.py index 7c78082f2..feda3e911 100644 --- a/src/api/statistic_router.py +++ b/src/api/statistic_router.py @@ -219,4 +219,4 @@ async def get_llm_stats( raise e except Exception as e: logger.error(f"获取LLM统计信息失败: {e}") - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index cc06d995a..3422ca296 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -878,7 +878,7 @@ class MemorySystem: except Exception as e: logger.warning(f"检索瞬时记忆失败: {e}", exc_info=True) - + # 最终截断 final_memories = final_memories[:effective_limit] diff --git a/src/chat/memory_system/message_collection_processor.py b/src/chat/memory_system/message_collection_processor.py index 756250dc4..b930aa3c9 100644 --- a/src/chat/memory_system/message_collection_processor.py +++ b/src/chat/memory_system/message_collection_processor.py @@ -72,4 +72,4 @@ class MessageCollectionProcessor: "active_buffers": len(self.message_buffers), "total_buffered_messages": total_buffered_messages, "buffer_capacity_per_chat": self.buffer_size, - } \ No newline at end of file + } diff --git a/src/chat/memory_system/message_collection_storage.py b/src/chat/memory_system/message_collection_storage.py index 22f4c75f3..d122ebed5 100644 --- a/src/chat/memory_system/message_collection_storage.py +++ b/src/chat/memory_system/message_collection_storage.py @@ -3,7 +3,6 @@ 专用于存储和检索消息集合,以提供即时上下文。 """ -import asyncio import time from typing import Any @@ -125,7 +124,7 @@ class MessageCollectionStorage: if results and results.get("ids") and results["ids"][0]: for metadata in results["metadatas"][0]: collections.append(MessageCollection.from_dict(metadata)) - + return collections except Exception as e: logger.error(f"检索相关消息集合失败: {e}", exc_info=True) @@ -163,7 +162,7 @@ class MessageCollectionStorage: # 格式化消息集合为 prompt 上下文 final_context = "\n\n---\n\n".join(context_parts) + "\n\n---" - + logger.info(f"🔗 为查询 '{query_text[:50]}...' 在聊天 '{chat_id}' 中找到 {len(collections)} 个相关消息集合上下文") return f"\n{final_context}\n" @@ -192,4 +191,4 @@ class MessageCollectionStorage: } except Exception as e: logger.error(f"获取消息集合存储统计失败: {e}") - return {} \ No newline at end of file + return {} diff --git a/src/chat/message_receive/message.py b/src/chat/message_receive/message.py index f03994a58..90805ee12 100644 --- a/src/chat/message_receive/message.py +++ b/src/chat/message_receive/message.py @@ -9,10 +9,10 @@ from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo from rich.traceback import install from src.chat.message_receive.chat_stream import ChatStream +from src.chat.utils.self_voice_cache import consume_self_voice_text from src.chat.utils.utils_image import get_image_manager from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available from src.chat.utils.utils_voice import get_voice_text -from src.chat.utils.self_voice_cache import consume_self_voice_text from src.common.logger import get_logger from src.config.config import global_config @@ -212,7 +212,7 @@ class MessageRecv(Message): return f"[语音:{cached_text}]" else: logger.warning("机器人自身语音消息缓存未命中,将回退到标准语音识别。") - + # 标准语音识别流程 (也作为缓存未命中的后备方案) if isinstance(segment.data, str): return await get_voice_text(segment.data) @@ -370,7 +370,7 @@ class MessageRecvS4U(MessageRecv): self.is_picid = False self.is_emoji = False self.is_voice = True - + # 检查消息是否由机器人自己发送 # 检查消息是否由机器人自己发送 if self.message_info and self.message_info.user_info and str(self.message_info.user_info.user_id) == str(global_config.bot.qq_account): diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 8c756b5e3..40cac9ed6 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -307,7 +307,7 @@ class Prompt: @staticmethod def _process_escaped_braces(template) -> str: - """预处理模板,将 `\{` 和 `\}` 替换为临时标记.""" + r"""预处理模板,将 `\{` 和 `\}` 替换为临时标记.""" if isinstance(template, list): template = "\n".join(str(item) for item in template) elif not isinstance(template, str): diff --git a/src/chat/utils/prompt_component_manager.py b/src/chat/utils/prompt_component_manager.py index c1fb92e13..3c68630d1 100644 --- a/src/chat/utils/prompt_component_manager.py +++ b/src/chat/utils/prompt_component_manager.py @@ -1,6 +1,5 @@ import asyncio import re -from typing import Type from src.chat.utils.prompt_params import PromptParameters from src.common.logger import get_logger @@ -21,7 +20,7 @@ class PromptComponentManager: 3. 提供一个接口,以便在构建核心Prompt时,能够获取并执行所有相关的组件。 """ - def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, Type[BasePrompt]]]: + def _get_rules_for(self, target_prompt_name: str) -> list[tuple[InjectionRule, type[BasePrompt]]]: """ 获取指定目标Prompt的所有注入规则及其关联的组件类。 diff --git a/src/chat/utils/self_voice_cache.py b/src/chat/utils/self_voice_cache.py index 90cf80344..d94bebc52 100644 --- a/src/chat/utils/self_voice_cache.py +++ b/src/chat/utils/self_voice_cache.py @@ -6,15 +6,14 @@ 避免不必要的自我语音识别。 """ import hashlib -from typing import Dict # 一个简单的内存缓存,用于将机器人自己发送的语音消息映射到其原始文本。 # 键是语音base64内容的SHA256哈希值。 -_self_voice_cache: Dict[str, str] = {} +_self_voice_cache: dict[str, str] = {} def get_voice_key(base64_content: str) -> str: """为语音内容生成一个一致的键。""" - return hashlib.sha256(base64_content.encode('utf-8')).hexdigest() + return hashlib.sha256(base64_content.encode("utf-8")).hexdigest() def register_self_voice(base64_content: str, text: str): """ @@ -39,4 +38,4 @@ def consume_self_voice_text(base64_content: str) -> str | None: str | None: 如果找到,则返回原始文本,否则返回None。 """ key = get_voice_key(base64_content) - return _self_voice_cache.pop(key, None) \ No newline at end of file + return _self_voice_cache.pop(key, None) diff --git a/src/chat/utils/utils_voice.py b/src/chat/utils/utils_voice.py index 483acefd2..f74359f18 100644 --- a/src/chat/utils/utils_voice.py +++ b/src/chat/utils/utils_voice.py @@ -19,10 +19,11 @@ async def get_voice_text(voice_base64: str) -> str: # 如果选择本地识别 if asr_provider == "local": - from src.plugin_system.apis import tool_api - import tempfile import base64 import os + import tempfile + + from src.plugin_system.apis import tool_api local_asr_tool = tool_api.get_tool_instance("local_asr") if not local_asr_tool: @@ -39,8 +40,8 @@ async def get_voice_text(voice_base64: str) -> str: text = await local_asr_tool.execute(function_args={"audio_path": audio_path}) if "失败" in text or "出错" in text or "错误" in text: logger.warning(f"本地语音识别失败: {text}") - return f"[语音(本地识别失败)]" - + return "[语音(本地识别失败)]" + logger.info(f"本地语音识别成功: {text}") return f"[语音] {text}" diff --git a/src/main.py b/src/main.py index 905990e2c..34f236085 100644 --- a/src/main.py +++ b/src/main.py @@ -4,9 +4,10 @@ import signal import sys import time import traceback +from collections.abc import Callable, Coroutine from functools import partial from random import choices -from typing import Any, Callable, Coroutine +from typing import Any from maim_message import MessageServer from rich.traceback import install @@ -24,8 +25,8 @@ from src.config.config import global_config from src.individuality.individuality import Individuality, get_individuality from src.manager.async_task_manager import async_task_manager from src.mood.mood_manager import mood_manager -from src.plugin_system.base.component_types import EventType from src.plugin_system.base.base_interest_calculator import BaseInterestCalculator +from src.plugin_system.base.component_types import EventType from src.plugin_system.core.event_manager import event_manager from src.plugin_system.core.plugin_manager import plugin_manager from src.schedule.monthly_plan_manager import monthly_plan_manager @@ -418,7 +419,7 @@ MoFox_Bot(第三方修改版) # 处理所有缓存的事件订阅(插件加载完成后) event_manager.process_all_pending_subscriptions() - + # 初始化表情管理器 get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") diff --git a/src/mcp_integration/__init__.py b/src/mcp_integration/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mcp_integration/client_manager.py b/src/mcp_integration/client_manager.py new file mode 100644 index 000000000..5f282702b --- /dev/null +++ b/src/mcp_integration/client_manager.py @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/src/mcp_integration/config_loader.py b/src/mcp_integration/config_loader.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/mcp_integration/tool_wrapper.py b/src/mcp_integration/tool_wrapper.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/plugin_system/apis/tool_api.py b/src/plugin_system/apis/tool_api.py index 163b67385..01ce4c7dc 100644 --- a/src/plugin_system/apis/tool_api.py +++ b/src/plugin_system/apis/tool_api.py @@ -1,4 +1,5 @@ from typing import Any + from src.common.logger import get_logger from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.component_types import ComponentType @@ -22,7 +23,7 @@ def get_tool_instance(tool_name: str) -> BaseTool | None: def get_llm_available_tool_definitions() -> list[dict[str, Any]]: - """获取LLM可用的工具定义列表 + """获取LLM可用的工具定义列表(包括 MCP 工具) Returns: list[dict[str, Any]]: 工具定义列表 @@ -31,6 +32,8 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]: llm_available_tools = component_registry.get_llm_available_tools() tool_definitions = [] + + # 获取常规工具定义 for tool_name, tool_class in llm_available_tools.items(): try: # 调用类方法 get_tool_definition 获取定义 @@ -38,5 +41,18 @@ def get_llm_available_tool_definitions() -> list[dict[str, Any]]: tool_definitions.append(definition) except Exception as e: logger.error(f"获取工具 {tool_name} 的定义失败: {e}") + + # 获取 MCP 工具定义 + try: + mcp_tools = component_registry.get_mcp_tools() + for mcp_tool in mcp_tools: + try: + definition = mcp_tool.get_tool_definition() + tool_definitions.append(definition) + except Exception as e: + logger.error(f"获取 MCP 工具 {mcp_tool.name} 的定义失败: {e}") + except Exception as e: + logger.debug(f"获取 MCP 工具列表失败(可能未启用): {e}") + return tool_definitions diff --git a/src/plugin_system/core/component_registry.py b/src/plugin_system/core/component_registry.py index 22ea50644..5f6e82825 100644 --- a/src/plugin_system/core/component_registry.py +++ b/src/plugin_system/core/component_registry.py @@ -87,6 +87,10 @@ class ComponentRegistry: self._tool_registry: dict[str, type["BaseTool"]] = {} # 工具名 -> 工具类 self._llm_available_tools: dict[str, type["BaseTool"]] = {} # llm可用的工具名 -> 工具类 + # MCP 工具注册表(运行时动态加载) + self._mcp_tools: list["BaseTool"] = [] # MCP 工具适配器实例列表 + self._mcp_tools_loaded = False # MCP 工具是否已加载 + # EventHandler特定注册表 self._event_handler_registry: dict[str, type["BaseEventHandler"]] = {} """event_handler名 -> event_handler类""" @@ -892,6 +896,7 @@ class ComponentRegistry: "action_components": action_components, "command_components": command_components, "tool_components": tool_components, + "mcp_tools": len(self._mcp_tools), "event_handlers": events_handlers, "plus_command_components": plus_command_components, "chatter_components": chatter_components, @@ -905,6 +910,34 @@ class ComponentRegistry: "enabled_plugins": len([p for p in self._plugins.values() if p.enabled]), } + # === MCP 工具相关方法 === + + async def load_mcp_tools(self) -> None: + """加载 MCP 工具(异步方法)""" + if self._mcp_tools_loaded: + logger.debug("MCP 工具已加载,跳过") + return + + try: + from .mcp_tool_adapter import load_mcp_tools_as_adapters + + logger.info("开始加载 MCP 工具...") + self._mcp_tools = await load_mcp_tools_as_adapters() + self._mcp_tools_loaded = True + logger.info(f"MCP 工具加载完成,共 {len(self._mcp_tools)} 个工具") + except Exception as e: + logger.error(f"加载 MCP 工具失败: {e}") + self._mcp_tools = [] + self._mcp_tools_loaded = True # 标记为已尝试加载,避免重复尝试 + + def get_mcp_tools(self) -> list["BaseTool"]: + """获取所有 MCP 工具适配器实例""" + return self._mcp_tools.copy() + + def is_mcp_tool(self, tool_name: str) -> bool: + """检查工具名是否为 MCP 工具""" + return tool_name.startswith("mcp_") + # === 组件移除相关 === async def unregister_plugin(self, plugin_name: str) -> bool: diff --git a/src/plugin_system/core/mcp_client_manager.py b/src/plugin_system/core/mcp_client_manager.py new file mode 100644 index 000000000..02bd0eaa3 --- /dev/null +++ b/src/plugin_system/core/mcp_client_manager.py @@ -0,0 +1,266 @@ +""" +MCP Client Manager + +管理多个 MCP (Model Context Protocol) 客户端连接,支持动态加载和工具注册 +""" + +import asyncio +import json +from pathlib import Path +from typing import Any + +import mcp.types +from fastmcp.client import Client, StreamableHttpTransport + +from src.common.logger import get_logger + +logger = get_logger("mcp_client_manager") + + +class MCPServerConfig: + """单个 MCP 服务器的配置""" + + def __init__(self, name: str, config: dict[str, Any]): + self.name = name + self.description = config.get("description", "") + self.enabled = config.get("enabled", True) + self.transport_config = config["transport"] + self.auth_config = config.get("auth") + self.timeout = config.get("timeout", 30) + self.retry_config = config.get("retry", {"max_retries": 3, "retry_delay": 1}) + + def __repr__(self): + return f"" + + +class MCPClientManager: + """ + MCP 客户端管理器 + + 负责: + 1. 从配置文件加载 MCP 服务器配置 + 2. 建立和维护与 MCP 服务器的连接 + 3. 获取可用的工具列表 + 4. 执行工具调用 + """ + + def __init__(self, config_path: str | Path | None = None): + """ + 初始化 MCP 客户端管理器 + + Args: + config_path: mcp.json 配置文件路径,默认为 config/mcp.json + """ + if config_path is None: + # 默认配置路径 + + config_path = Path(__file__).parent.parent.parent.parent / "config" / "mcp.json" + + self.config_path = Path(config_path) + self.servers: dict[str, MCPServerConfig] = {} + self.clients: dict[str, Client] = {} + self._initialized = False + self._lock = asyncio.Lock() + + logger.info(f"MCP 客户端管理器初始化,配置文件: {self.config_path}") + + def load_config(self) -> dict[str, MCPServerConfig]: + """ + 从配置文件加载 MCP 服务器配置 + + Returns: + Dict[str, MCPServerConfig]: 服务器名称 -> 配置对象 + """ + if not self.config_path.exists(): + logger.warning(f"MCP 配置文件不存在: {self.config_path}") + return {} + + try: + with open(self.config_path, encoding="utf-8") as f: + config_data = json.load(f) + + servers = {} + mcp_servers = config_data.get("mcpServers", {}) + + for server_name, server_config in mcp_servers.items(): + try: + server = MCPServerConfig(server_name, server_config) + servers[server_name] = server + logger.debug(f"加载 MCP 服务器配置: {server}") + except Exception as e: + logger.error(f"加载服务器配置 '{server_name}' 失败: {e}") + continue + + logger.info(f"成功加载 {len(servers)} 个 MCP 服务器配置") + return servers + + except json.JSONDecodeError as e: + logger.error(f"解析 MCP 配置文件失败: {e}") + return {} + except Exception as e: + logger.error(f"读取 MCP 配置文件失败: {e}") + return {} + + async def initialize(self) -> None: + """ + 初始化所有启用的 MCP 客户端连接 + + 这个方法会: + 1. 加载配置文件 + 2. 为每个启用的服务器创建客户端 + 3. 建立连接并验证 + """ + async with self._lock: + if self._initialized: + logger.debug("MCP 客户端管理器已初始化,跳过") + return + + logger.info("开始初始化 MCP 客户端连接...") + + # 加载配置 + self.servers = self.load_config() + + if not self.servers: + logger.warning("没有找到任何 MCP 服务器配置") + self._initialized = True + return + + # 为每个启用的服务器创建客户端 + for server_name, server_config in self.servers.items(): + if not server_config.enabled: + logger.debug(f"服务器 '{server_name}' 未启用,跳过") + continue + + try: + client = await self._create_client(server_config) + self.clients[server_name] = client + logger.info(f"✅ MCP 服务器 '{server_name}' 连接成功") + except Exception as e: + logger.error(f"❌ 连接 MCP 服务器 '{server_name}' 失败: {e}") + continue + + self._initialized = True + logger.info(f"MCP 客户端管理器初始化完成,成功连接 {len(self.clients)}/{len(self.servers)} 个服务器") + + async def _create_client(self, server_config: MCPServerConfig) -> Client: + """ + 根据配置创建 MCP 客户端 + + Args: + server_config: 服务器配置 + + Returns: + Client: 已连接的 MCP 客户端 + """ + transport_type = server_config.transport_config.get("type", "streamable-http") + + if transport_type == "streamable-http": + url = server_config.transport_config["url"] + transport = StreamableHttpTransport(url) + + # 设置认证(如果有) + if server_config.auth_config: + auth_type = server_config.auth_config.get("type") + if auth_type == "bearer": + from fastmcp.client.auth import BearerAuth + + token = server_config.auth_config.get("token", "") + transport._set_auth(BearerAuth(token)) + + client = Client(transport, timeout=server_config.timeout) + + elif transport_type == "sse": + from fastmcp.client import SSETransport + + url = server_config.transport_config["url"] + client = Client(SSETransport(url), timeout=server_config.timeout) + + else: + raise ValueError(f"不支持的传输类型: {transport_type}") + + # 进入客户端上下文(建立连接) + await client.__aenter__() + + return client + + async def get_all_tools(self) -> dict[str, list[mcp.types.Tool]]: + """ + 获取所有 MCP 服务器提供的工具列表 + + Returns: + Dict[str, List[mcp.types.Tool]]: 服务器名称 -> 工具列表 + """ + if not self._initialized: + await self.initialize() + + all_tools = {} + + for server_name, client in self.clients.items(): + try: + # fastmcp 的 list_tools() 直接返回 List[Tool],不是包含 tools 属性的对象 + tools = await client.list_tools() + all_tools[server_name] = tools + logger.debug(f"从服务器 '{server_name}' 获取到 {len(tools)} 个工具") + except Exception as e: + logger.error(f"从服务器 '{server_name}' 获取工具列表失败: {e}") + all_tools[server_name] = [] + + return all_tools + + async def call_tool( + self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None + ) -> Any: + """ + 调用指定 MCP 服务器的工具 + + Args: + server_name: 服务器名称 + tool_name: 工具名称 + arguments: 工具参数 + + Returns: + Any: 工具执行结果(CallToolResult 的兼容类型) + """ + if not self._initialized: + await self.initialize() + + if server_name not in self.clients: + raise ValueError(f"MCP 服务器 '{server_name}' 未连接") + + client = self.clients[server_name] + + try: + logger.debug(f"调用 MCP 工具: {server_name}.{tool_name} | 参数: {arguments}") + result = await client.call_tool(tool_name, arguments or {}) + logger.debug(f"MCP 工具调用成功: {server_name}.{tool_name}") + return result + + except Exception as e: + logger.error(f"MCP 工具调用失败: {server_name}.{tool_name} | 错误: {e}") + raise + + async def close(self) -> None: + """关闭所有 MCP 客户端连接""" + async with self._lock: + if not self._initialized: + return + + logger.info("关闭所有 MCP 客户端连接...") + + for server_name, client in self.clients.items(): + try: + await client.__aexit__(None, None, None) + logger.debug(f"已关闭 MCP 服务器 '{server_name}' 的连接") + except Exception as e: + logger.error(f"关闭服务器 '{server_name}' 连接失败: {e}") + + self.clients.clear() + self._initialized = False + logger.info("所有 MCP 客户端连接已关闭") + + def __repr__(self): + return f"" + + +# 全局单例 +mcp_client_manager = MCPClientManager() diff --git a/src/plugin_system/core/mcp_tool_adapter.py b/src/plugin_system/core/mcp_tool_adapter.py new file mode 100644 index 000000000..c971022eb --- /dev/null +++ b/src/plugin_system/core/mcp_tool_adapter.py @@ -0,0 +1,249 @@ +""" +MCP Tool Adapter + +将 MCP 工具适配为 BaseTool,使其能够被插件系统识别和调用 +""" + +from typing import Any, ClassVar + +import mcp.types + +from src.common.logger import get_logger +from src.plugin_system.base.base_tool import BaseTool +from src.plugin_system.base.component_types import ToolParamType + +from .mcp_client_manager import mcp_client_manager + +logger = get_logger("mcp_tool_adapter") + + +class MCPToolAdapter(BaseTool): + """ + MCP 工具适配器 + + 将 MCP 协议的工具适配为 BaseTool,使其能够: + 1. 被插件系统识别和注册 + 2. 被 LLM 调用 + 3. 参与工具缓存机制 + """ + + # 类级别默认值,使用 ClassVar 标注 + available_for_llm: ClassVar[bool] = True + + def __init__(self, server_name: str, mcp_tool: mcp.types.Tool, plugin_config: dict | None = None): + """ + 初始化 MCP 工具适配器 + + Args: + server_name: MCP 服务器名称 + mcp_tool: MCP 工具对象 + plugin_config: 插件配置(可选) + """ + super().__init__(plugin_config) + + self.server_name = server_name + self.mcp_tool = mcp_tool + + # 设置实例属性 + self.name = f"mcp_{server_name}_{mcp_tool.name}" + self.description = mcp_tool.description or f"MCP tool from {server_name}" + + # 转换参数定义 + self.parameters = self._convert_parameters(mcp_tool.inputSchema) + + logger.debug(f"创建 MCP 工具适配器: {self.name}") + + def _convert_parameters( + self, input_schema: dict[str, Any] | None + ) -> list[tuple[str, ToolParamType, str, bool, list[str] | None]]: + """ + 将 MCP 工具的 JSON Schema 参数转换为 BaseTool 参数格式 + + Args: + input_schema: MCP 工具的 inputSchema (JSON Schema) + + Returns: + List[Tuple]: BaseTool 参数格式列表 + """ + if not input_schema: + return [] + + parameters = [] + + # JSON Schema 通常有 properties 和 required 字段 + properties = input_schema.get("properties", {}) + required_fields = input_schema.get("required", []) + + for param_name, param_def in properties.items(): + # 获取参数类型 + param_type_str = param_def.get("type", "string") + param_type = self._map_json_type_to_tool_param_type(param_type_str) + + # 获取参数描述 + param_desc = param_def.get("description", f"Parameter {param_name}") + + # 判断是否必填 + is_required = param_name in required_fields + + # 获取枚举值(如果有) + enum_values = param_def.get("enum") + + parameters.append((param_name, param_type, param_desc, is_required, enum_values)) + + return parameters + + @staticmethod + def _map_json_type_to_tool_param_type(json_type: str) -> ToolParamType: + """ + 将 JSON Schema 类型映射到 ToolParamType + + Args: + json_type: JSON Schema 类型字符串 + + Returns: + ToolParamType: 对应的工具参数类型 + """ + type_mapping = { + "string": ToolParamType.STRING, + "integer": ToolParamType.INTEGER, + "number": ToolParamType.FLOAT, + "boolean": ToolParamType.BOOLEAN, + } + return type_mapping.get(json_type, ToolParamType.STRING) + + async def execute(self, function_args: dict[str, Any]) -> dict[str, Any]: + """ + 执行 MCP 工具调用 + + Args: + function_args: 工具调用参数 + + Returns: + Dict: 工具执行结果 + """ + try: + logger.debug(f"执行 MCP 工具: {self.name} | 服务器: {self.server_name} | 参数: {function_args}") + + # 移除 llm_called 标记(这是内部使用的) + clean_args = {k: v for k, v in function_args.items() if k != "llm_called"} + + # 调用 MCP 客户端管理器执行工具 + result = await mcp_client_manager.call_tool( + server_name=self.server_name, tool_name=self.mcp_tool.name, arguments=clean_args + ) + + # 解析结果 + return self._format_result(result) + + except Exception as e: + logger.error(f"MCP 工具执行失败: {self.name} | 错误: {e}") + return { + "type": "error", + "content": f"MCP 工具调用失败: {e!s}", + "id": self.name, + } + + def _format_result(self, result: mcp.types.CallToolResult) -> dict[str, Any]: + """ + 格式化 MCP 工具执行结果为标准格式 + + Args: + result: MCP CallToolResult 对象 + + Returns: + Dict: 标准化的工具执行结果 + """ + # MCP 结果包含 content 列表 + if not result.content: + return { + "type": "mcp_result", + "content": "", + "id": self.name, + } + + # 提取所有内容 + content_parts = [] + for content_item in result.content: + # 根据内容类型提取文本 + content_type = getattr(content_item, "type", None) + + if content_type == "text": + # TextContent 类型 + text = getattr(content_item, "text", "") + content_parts.append(text) + elif content_type == "image": + # ImageContent 类型 + data = getattr(content_item, "data", b"") + content_parts.append(f"[Image data: {len(data)} bytes]") + elif content_type == "audio": + # AudioContent 类型 + data = getattr(content_item, "data", b"") + content_parts.append(f"[Audio data: {len(data)} bytes]") + else: + # 尝试提取 text 或 data 属性 + text = getattr(content_item, "text", None) + if text is not None: + content_parts.append(text) + else: + data = getattr(content_item, "data", None) + if data is not None: + data_len = len(data) if hasattr(data, "__len__") else "unknown" + content_parts.append(f"[Binary data: {data_len} bytes]") + else: + content_parts.append(str(content_item)) + + return { + "type": "mcp_result", + "content": "\n".join(content_parts), + "id": self.name, + "is_error": getattr(result, "isError", False), + } + + @classmethod + def from_mcp_tool(cls, server_name: str, mcp_tool: mcp.types.Tool) -> "MCPToolAdapter": + """ + 从 MCP 工具对象创建适配器实例 + + Args: + server_name: MCP 服务器名称 + mcp_tool: MCP 工具对象 + + Returns: + MCPToolAdapter: 工具适配器实例 + """ + return cls(server_name, mcp_tool) + + +async def load_mcp_tools_as_adapters() -> list[MCPToolAdapter]: + """ + 加载所有 MCP 工具并转换为适配器 + + Returns: + List[MCPToolAdapter]: 工具适配器列表 + """ + logger.info("开始加载 MCP 工具...") + + # 初始化 MCP 客户端管理器 + await mcp_client_manager.initialize() + + # 获取所有工具 + all_tools_dict = await mcp_client_manager.get_all_tools() + + adapters = [] + total_tools = 0 + + for server_name, tools in all_tools_dict.items(): + logger.debug(f"处理服务器 '{server_name}' 的 {len(tools)} 个工具") + total_tools += len(tools) + + for mcp_tool in tools: + try: + adapter = MCPToolAdapter.from_mcp_tool(server_name, mcp_tool) + adapters.append(adapter) + logger.debug(f" ✓ 加载工具: {adapter.name}") + except Exception as e: + logger.error(f" ✗ 创建工具适配器失败: {mcp_tool.name} | 错误: {e}") + continue + + logger.info(f"MCP 工具加载完成: 成功 {len(adapters)}/{total_tools} 个") + return adapters diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py index 5d07c2cd6..44d47eb9f 100644 --- a/src/plugin_system/core/tool_use.py +++ b/src/plugin_system/core/tool_use.py @@ -145,8 +145,7 @@ class ToolExecutor: pending_step_two = getattr(self, "_pending_step_two_tools", {}) if pending_step_two: # 添加第二步工具定义 - for step_two_def in pending_step_two.values(): - tool_definitions.append(step_two_def) + tool_definitions.extend(list(pending_step_two.values())) return tool_definitions @@ -286,10 +285,33 @@ class ToolExecutor: logger.info( f"{self.log_prefix} 正在执行工具: [bold green]{function_name}[/bold green] | 参数: {function_args}" ) - + # 检查是否是MCP工具 - pass - + from src.plugin_system.core import component_registry + + if component_registry.is_mcp_tool(function_name): + logger.debug(f"{self.log_prefix}识别到 MCP 工具: {function_name}") + # 找到对应的 MCP 工具实例 + mcp_tools = component_registry.get_mcp_tools() + mcp_tool = next((t for t in mcp_tools if t.name == function_name), None) + + if mcp_tool: + logger.debug(f"{self.log_prefix}执行 MCP 工具 {function_name}") + result = await mcp_tool.execute(function_args) + + if result: + logger.debug(f"{self.log_prefix}MCP 工具 {function_name} 执行成功") + return { + "tool_call_id": tool_call.call_id, + "role": "tool", + "name": function_name, + "type": "function", + "content": result.get("content", ""), + } + else: + logger.warning(f"{self.log_prefix}未找到 MCP 工具: {function_name}") + return None + function_args["llm_called"] = True # 标记为LLM调用 # 检查是否是二步工具的第二步调用 diff --git a/src/plugins/built_in/stt_whisper_plugin/__init__.py b/src/plugins/built_in/stt_whisper_plugin/__init__.py index bd7cc2259..5cde004ae 100644 --- a/src/plugins/built_in/stt_whisper_plugin/__init__.py +++ b/src/plugins/built_in/stt_whisper_plugin/__init__.py @@ -6,4 +6,4 @@ __plugin_meta__ = PluginMetadata( usage="在 bot_config.toml 中将 asr_provider 设置为 'local' 即可启用", version="0.1.0", author="Elysia", -) \ No newline at end of file +) diff --git a/src/plugins/built_in/stt_whisper_plugin/plugin.py b/src/plugins/built_in/stt_whisper_plugin/plugin.py index fb5ea38a7..34d7a09c0 100644 --- a/src/plugins/built_in/stt_whisper_plugin/plugin.py +++ b/src/plugins/built_in/stt_whisper_plugin/plugin.py @@ -1,9 +1,4 @@ import asyncio -import os -import tempfile -from typing import Any -from pathlib import Path -import toml import whisper @@ -40,7 +35,7 @@ class LocalASRTool(BaseTool): model_size = plugin_config.get("whisper", {}).get("model_size", "tiny") device = plugin_config.get("whisper", {}).get("device", "cpu") logger.info(f"正在预加载 Whisper ASR 模型: {model_size} ({device})") - + loop = asyncio.get_running_loop() _whisper_model = await loop.run_in_executor( None, whisper.load_model, model_size, device @@ -61,10 +56,10 @@ class LocalASRTool(BaseTool): # 增强的等待逻辑:只要模型还没准备好,就一直等待后台加载任务完成 while _is_loading: await asyncio.sleep(0.2) - + if _whisper_model is None: return "Whisper 模型加载失败,无法识别语音。" - + try: logger.info(f"开始使用 Whisper 识别音频: {audio_path}") loop = asyncio.get_running_loop() @@ -110,6 +105,6 @@ class STTWhisperPlugin(BasePlugin): ), LocalASRTool)] except Exception as e: logger.error(f"检查 ASR provider 配置时出错: {e}") - + logger.debug("ASR provider is not 'local', whisper plugin's tool is disabled.") return [] diff --git a/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py b/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py index 3afca4ecc..058600ffa 100644 --- a/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py +++ b/src/plugins/built_in/tts_voice_plugin/actions/tts_action.py @@ -6,9 +6,9 @@ from pathlib import Path import toml +from src.chat.utils.self_voice_cache import register_self_voice from src.common.logger import get_logger from src.plugin_system.base.base_action import BaseAction, ChatMode -from src.chat.utils.self_voice_cache import register_self_voice from ..services.manager import get_service diff --git a/src/plugins/built_in/web_search_tool/__init__.py b/src/plugins/built_in/web_search_tool/__init__.py index 458e2586b..2a9dfc1bf 100644 --- a/src/plugins/built_in/web_search_tool/__init__.py +++ b/src/plugins/built_in/web_search_tool/__init__.py @@ -15,7 +15,7 @@ __plugin_meta__ = PluginMetadata( "is_built_in": True, }, # Python包依赖列表 - python_dependencies = [ # noqa: RUF012 + python_dependencies = [ PythonDependency(package_name="asyncddgs", description="异步DuckDuckGo搜索库", optional=False), PythonDependency( package_name="exa_py",