diff --git a/TODO.md b/TODO.md
index c55725b62..afdc43047 100644
--- a/TODO.md
+++ b/TODO.md
@@ -18,7 +18,7 @@
 - [x] Add sticker emotion analysis
 - [x] Add proactive thinking configuration
 - [x] Add schedule management
-- [ ] Add MCP SSE support
+- [x] Add MCP SSE support
 - [ ] Add multi-emotion speech synthesis based on GPT-Sovits (as a plugin)
 - [ ] Add speech synthesis based on Open Voice (as a plugin)
 - [x] Add a videoid to videos in chat messages (like imageid)
diff --git a/docs/MCP_SSE_USAGE.md b/docs/MCP_SSE_USAGE.md
new file mode 100644
index 000000000..70fc9906b
--- /dev/null
+++ b/docs/MCP_SSE_USAGE.md
@@ -0,0 +1,274 @@
+# MCP SSE Client Usage Guide
+
+## Introduction
+
+The MCP (Model Context Protocol) SSE (Server-Sent Events) client communicates with MCP-compatible servers over the SSE protocol. It is integrated into MoFox Bot's LLM model client system.
+
+## Features
+
+- ✅ SSE streaming responses
+- ✅ Multi-turn conversations
+- ✅ Tool calls (function calling)
+- ✅ Multimodal content (text + images)
+- ✅ Automatic handling of interrupt signals
+- ✅ Full token usage statistics
+
+## Configuration
+
+### 1. Install dependencies
+
+The dependencies are already declared in the project:
+```bash
+pip install mcp>=0.9.0 sse-starlette>=2.2.1
+```
+
+Or with uv:
+```bash
+uv sync
+```
+
+### 2. Configure an API provider
+
+Add an MCP SSE provider to the configuration file:
+
+```python
+# Add to the configuration
+api_providers = [
+    {
+        "name": "mcp_provider",
+        "client_type": "mcp_sse",  # Use the MCP SSE client
+        "base_url": "https://your-mcp-server.com",
+        "api_key": "your-api-key",
+        "timeout": 60
+    }
+]
+```
+
+### 3. Configure a model
+
+```python
+models = [
+    {
+        "name": "mcp_model",
+        "api_provider": "mcp_provider",
+        "model_identifier": "your-model-name",
+        "force_stream_mode": True  # MCP SSE always streams
+    }
+]
+```
+
+## Usage Examples
+
+### Basic conversation
+
+```python
+from src.llm_models.model_client.base_client import client_registry
+from src.llm_models.payload_content.message import Message, MessageBuilder, RoleType
+from src.config.api_ada_configs import APIProvider, ModelInfo
+
+# Get a client
+api_provider = APIProvider(
+    name="mcp_provider",
+    client_type="mcp_sse",
+    base_url="https://your-mcp-server.com",
+    api_key="your-api-key"
+)
+
+client = client_registry.get_client_class_instance(api_provider)
+
+# Build the messages
+messages = [
+    MessageBuilder()
+    .set_role(RoleType.User)
+    .add_text_content("Hello, please introduce yourself")
+    .build()
+]
+
+# Get a response
+model_info = ModelInfo(
+    name="mcp_model",
+    api_provider="mcp_provider",
+    model_identifier="your-model-name"
+)
+
+response = await client.get_response(
+    model_info=model_info,
+    message_list=messages,
+    max_tokens=1024,
+    temperature=0.7
+)
+
+print(response.content)
+```
+
+### Tool calls
+
+```python
+from src.llm_models.payload_content.tool_option import (
+    ToolOptionBuilder,
+    ToolParamType
+)
+
+# Define a tool
+tools = [
+    ToolOptionBuilder()
+    .set_name("get_weather")
+    .set_description("Get the weather for a given city")
+    .add_param(
+        name="city",
+        param_type=ToolParamType.STRING,
+        description="City name",
+        required=True
+    )
+    .build()
+]
+
+# Send the request
+response = await client.get_response(
+    model_info=model_info,
+    message_list=messages,
+    tool_options=tools,
+    max_tokens=1024,
+    temperature=0.7
+)
+
+# Check for tool calls
+if response.tool_calls:
+    for tool_call in response.tool_calls:
+        print(f"Tool called: {tool_call.func_name}")
+        print(f"Arguments: {tool_call.args}")
+```
+
+### Multimodal conversation
+
+```python
+import base64
+
+# Read and encode an image
+with open("image.jpg", "rb") as f:
+    image_data = base64.b64encode(f.read()).decode("utf-8")
+
+# Build a multimodal message
+messages = [
+    MessageBuilder()
+    .set_role(RoleType.User)
+    .add_text_content("What is in this image?")
+    .add_image_content("jpg", image_data)
+    .build()
+]
+
+response = await client.get_response(
+    model_info=model_info,
+    message_list=messages
+)
+```
+
+### Interrupt handling
+
+```python
+import asyncio
+
+from src.llm_models.exceptions import ReqAbortException
+
+# Create an interrupt event
+interrupt_flag = asyncio.Event()
+
+# Set the interrupt from another coroutine
+async def interrupt_after_delay():
+    await asyncio.sleep(5)
+    interrupt_flag.set()
+
+asyncio.create_task(interrupt_after_delay())
+
+try:
+    response = await client.get_response(
+        model_info=model_info,
+        message_list=messages,
+        interrupt_flag=interrupt_flag
+    )
+except ReqAbortException:
+    print("Request was interrupted")
+```
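+
+### Releasing the client session
+
+The client keeps an underlying `aiohttp` session open between requests. A minimal cleanup sketch, assuming the `client` instance from the examples above:
+
+```python
+# Close the client's aiohttp session once you are done with it
+await client.close()
+```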
+
+## MCP Protocol Specification
+
+The MCP SSE client follows the protocol conventions below.
+
+### Request format
+
+```json
+{
+  "model": "model-name",
+  "messages": [
+    {
+      "role": "user",
+      "content": "message content"
+    }
+  ],
+  "max_tokens": 1024,
+  "temperature": 0.7,
+  "stream": true,
+  "tools": [
+    {
+      "name": "tool_name",
+      "description": "tool description",
+      "input_schema": {
+        "type": "object",
+        "properties": {...},
+        "required": [...]
+      }
+    }
+  ]
+}
+```
+
+### SSE event types
+
+The client handles the following SSE events:
+
+1. **content_block_start** - a content block starts
+2. **content_block_delta** - incremental content for the current block
+3. **content_block_stop** - the content block ends
+4. **message_delta** - message metadata update
+5. **message_stop** - the message ends
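+
+As a rough illustration (the field values below are invented; the exact payloads depend on the server), the stream for a short text reply might look like:
+
+```text
+data: {"type": "content_block_start", "content_block": {"type": "text"}}
+
+data: {"type": "content_block_delta", "delta": {"type": "text_delta", "text": "Hello"}}
+
+data: {"type": "content_block_delta", "delta": {"type": "text_delta", "text": ", world"}}
+
+data: {"type": "content_block_stop"}
+
+data: {"type": "message_delta", "delta": {"stop_reason": "end_turn"}, "usage": {"input_tokens": 12, "output_tokens": 5}}
+
+data: {"type": "message_stop"}
+```
+
+Tool calls are streamed the same way: a `content_block_start` event with `"type": "tool_use"` carries the tool call `id` and `name`, and the arguments arrive as `input_json_delta` fragments in `partial_json`.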
+
+## Limitations
+
+Current limitations of the MCP SSE client:
+
+- ❌ Embeddings are not supported
+- ❌ Audio transcription is not supported
+- ✅ Streaming responses only (inherent to SSE)
+
+## Troubleshooting
+
+### Connection failures
+
+Check:
+1. Whether base_url is correct
+2. Whether the API key is valid
+3. Whether the network connection is working
+4. Whether the server supports the SSE protocol
+
+### Parsing errors
+
+Check:
+1. Whether the SSE data returned by the server follows the MCP conventions
+2. The detailed error messages in the logs
+
+### Tool call failures
+
+Check:
+1. Whether the tool definition schema is correct
+2. Whether the server supports tool calls
+
+## Related Documentation
+
+- [MCP protocol specification](https://github.com/anthropics/mcp)
+- [SSE specification](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events)
+- [MoFox Bot documentation](../README.md)
+
+## Changelog
+
+### v0.8.1
+- ✅ Added MCP SSE client support
+- ✅ Streaming responses and tool calls
+- ✅ Multimodal content
diff --git a/pyproject.toml b/pyproject.toml
index 17b361c5f..2ad3c5433 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,8 @@ dependencies = [
     "aiosqlite>=0.21.0",
     "inkfox>=0.1.0",
     "rrjieba>=0.1.13",
+    "mcp>=0.9.0",
+    "sse-starlette>=2.2.1",
 ]
 
 [[tool.uv.index]]
diff --git a/requirements.txt b/requirements.txt
index 86cd7d666..aa05f5b15 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -69,4 +69,6 @@ lunar_python
 fuzzywuzzy
 python-multipart
 aiofiles
-inkfox
\ No newline at end of file
+inkfox
+mcp
+sse-starlette
\ No newline at end of file
diff --git a/src/llm_models/model_client/__init__.py b/src/llm_models/model_client/__init__.py
index 6c4151c41..e89e65a80 100644
--- a/src/llm_models/model_client/__init__.py
+++ b/src/llm_models/model_client/__init__.py
@@ -6,3 +6,5 @@ if "openai" in used_client_types:
     from . import openai_client  # noqa: F401
 if "aiohttp_gemini" in used_client_types:
     from . import aiohttp_gemini_client  # noqa: F401
+if "mcp_sse" in used_client_types:
+    from . import mcp_sse_client  # noqa: F401
diff --git a/src/llm_models/model_client/mcp_sse_client.py b/src/llm_models/model_client/mcp_sse_client.py
new file mode 100644
index 000000000..ec4502dbb
--- /dev/null
+++ b/src/llm_models/model_client/mcp_sse_client.py
@@ -0,0 +1,410 @@
+"""
+MCP (Model Context Protocol) SSE (Server-Sent Events) client implementation.
+Supports communicating with MCP servers over the SSE protocol.
+"""
+
+import asyncio
+import io
+import json
+from collections.abc import Callable
+from typing import Any
+
+import aiohttp
+import orjson
+from json_repair import repair_json
+
+from src.common.logger import get_logger
+from src.config.api_ada_configs import APIProvider, ModelInfo
+
+from ..exceptions import (
+    NetworkConnectionError,
+    ReqAbortException,
+    RespNotOkException,
+    RespParseException,
+)
+from ..payload_content.message import Message, RoleType
+from ..payload_content.resp_format import RespFormat
+from ..payload_content.tool_option import ToolCall, ToolOption
+from .base_client import APIResponse, BaseClient, UsageRecord, client_registry
+
+logger = get_logger("MCP-SSE client")
+
+
+def _convert_messages_to_mcp(messages: list[Message]) -> list[dict[str, Any]]:
+    """
+    Convert a message list to the MCP protocol format.
+    :param messages: list of messages
+    :return: list of MCP-format messages
+    """
+    mcp_messages = []
+
+    for message in messages:
+        mcp_msg: dict[str, Any] = {
+            "role": message.role.value,
+        }
+
+        # Convert the content
+        if isinstance(message.content, str):
+            mcp_msg["content"] = message.content
+        elif isinstance(message.content, list):
+            # Multimodal content
+            content_parts = []
+            for item in message.content:
+                if isinstance(item, tuple):
+                    # Image content: (format, base64 data)
+                    content_parts.append({
+                        "type": "image",
+                        "source": {
+                            "type": "base64",
+                            "media_type": f"image/{item[0].lower()}",
+                            "data": item[1],
+                        },
+                    })
+                elif isinstance(item, str):
+                    # Text content
+                    content_parts.append({"type": "text", "text": item})
+            mcp_msg["content"] = content_parts
+
+        # Attach the tool call ID for tool messages
+        if message.role == RoleType.Tool and message.tool_call_id:
+            mcp_msg["tool_call_id"] = message.tool_call_id
+
+        mcp_messages.append(mcp_msg)
+
+    return mcp_messages
+
+
+def _convert_tools_to_mcp(tool_options: list[ToolOption]) -> list[dict[str, Any]]:
+    """
+    Convert tool options to the MCP protocol format.
+    :param tool_options: list of tool options
+    :return: list of MCP-format tools
+    """
+    mcp_tools = []
+
+    for tool in tool_options:
+        mcp_tool = {
+            "name": tool.name,
+            "description": tool.description,
+        }
+
+        if tool.params:
+            properties = {}
+            required = []
+
+            for param in tool.params:
+                properties[param.name] = {
+                    "type": param.param_type.value,
+                    "description": param.description,
+                }
+
+                if param.enum_values:
+                    properties[param.name]["enum"] = param.enum_values
+
+                if param.required:
+                    required.append(param.name)
+
+            mcp_tool["input_schema"] = {
+                "type": "object",
+                "properties": properties,
+                "required": required,
+            }
+
+        mcp_tools.append(mcp_tool)
+
+    return mcp_tools
response.status, f"MCP SSE请求失败: {error_text}" + ) + + # 解析SSE流 + async for line in response.content: + if interrupt_flag and interrupt_flag.is_set(): + raise ReqAbortException("请求被外部信号中断") + + decoded_line = line.decode("utf-8").strip() + + # 跳过空行和注释 + if not decoded_line or decoded_line.startswith(":"): + continue + + # 解析SSE事件 + if decoded_line.startswith("data: "): + data_str = decoded_line[6:] # 移除"data: "前缀 + + # 跳过[DONE]标记 + if data_str == "[DONE]": + break + + try: + event_data = orjson.loads(data_str) + except orjson.JSONDecodeError: + logger.warning(f"无法解析SSE数据: {data_str}") + continue + + # 处理不同类型的事件 + event_type = event_data.get("type") + + if event_type == "content_block_start": + # 内容块开始 + block = event_data.get("content_block", {}) + if block.get("type") == "text": + pass # 准备接收文本内容 + elif block.get("type") == "tool_use": + # 工具调用开始 + tool_calls_buffer.append( + ( + block.get("id", ""), + block.get("name", ""), + {}, + ) + ) + + elif event_type == "content_block_delta": + # 内容块增量 + delta = event_data.get("delta", {}) + delta_type = delta.get("type") + + if delta_type == "text_delta": + # 文本增量 + text = delta.get("text", "") + content_buffer.write(text) + + elif delta_type == "input_json_delta": + # 工具调用参数增量 + if tool_calls_buffer: + partial_json = delta.get("partial_json", "") + # 累积JSON片段 + current_args = tool_calls_buffer[-1][2] + if "_json_buffer" not in current_args: + current_args["_json_buffer"] = "" + current_args["_json_buffer"] += partial_json + + elif event_type == "content_block_stop": + # 内容块结束 + if tool_calls_buffer: + # 解析完整的工具调用参数 + last_call = tool_calls_buffer[-1] + if "_json_buffer" in last_call[2]: + json_str = last_call[2].pop("_json_buffer") + try: + parsed_args = orjson.loads(repair_json(json_str)) + tool_calls_buffer[-1] = ( + last_call[0], + last_call[1], + parsed_args if isinstance(parsed_args, dict) else {}, + ) + except orjson.JSONDecodeError as e: + logger.error(f"解析工具调用参数失败: {e}") + + elif event_type == "message_delta": + # 消息元数据更新 + delta = event_data.get("delta", {}) + stop_reason = delta.get("stop_reason") + if stop_reason: + logger.debug(f"消息结束原因: {stop_reason}") + + # 提取使用统计 + usage = event_data.get("usage", {}) + if usage: + usage_record = ( + usage.get("input_tokens", 0), + usage.get("output_tokens", 0), + usage.get("input_tokens", 0) + usage.get("output_tokens", 0), + ) + + elif event_type == "message_stop": + # 消息结束 + break + + except aiohttp.ClientError as e: + raise NetworkConnectionError() from e + except Exception as e: + logger.error(f"解析SSE流时发生错误: {e}") + raise + + # 构建响应 + response = APIResponse() + + if content_buffer.tell() > 0: + response.content = content_buffer.getvalue() + + if reasoning_buffer.tell() > 0: + response.reasoning_content = reasoning_buffer.getvalue() + + if tool_calls_buffer: + response.tool_calls = [ + ToolCall(call_id, func_name, args) + for call_id, func_name, args in tool_calls_buffer + ] + + # 关闭缓冲区 + content_buffer.close() + reasoning_buffer.close() + + return response, usage_record + + +@client_registry.register_client_class("mcp_sse") +class MCPSSEClient(BaseClient): + """ + MCP SSE客户端实现 + 支持通过Server-Sent Events协议与MCP服务器通信 + """ + + def __init__(self, api_provider: APIProvider): + super().__init__(api_provider) + self._session: aiohttp.ClientSession | None = None + + async def _get_session(self) -> aiohttp.ClientSession: + """获取或创建aiohttp会话""" + if self._session is None or self._session.closed: + timeout = aiohttp.ClientTimeout(total=self.api_provider.timeout) + self._session = 
+
+    async def get_response(
+        self,
+        model_info: ModelInfo,
+        message_list: list[Message],
+        tool_options: list[ToolOption] | None = None,
+        max_tokens: int = 1024,
+        temperature: float = 0.7,
+        response_format: RespFormat | None = None,
+        stream_response_handler: Callable[[Any, asyncio.Event | None], tuple[APIResponse, tuple[int, int, int]]]
+        | None = None,
+        async_response_parser: Callable[[Any], tuple[APIResponse, tuple[int, int, int]]] | None = None,
+        interrupt_flag: asyncio.Event | None = None,
+        extra_params: dict[str, Any] | None = None,
+    ) -> APIResponse:
+        """
+        Get a chat response.
+        :param model_info: model information
+        :param message_list: list of conversation messages
+        :param tool_options: tool options
+        :param max_tokens: maximum number of tokens
+        :param temperature: temperature parameter
+        :param response_format: response format
+        :param stream_response_handler: streaming response handler
+        :param async_response_parser: asynchronous response parser
+        :param interrupt_flag: interrupt flag
+        :param extra_params: extra parameters
+        :return: API response
+        """
+        session = await self._get_session()
+
+        # Build the request payload
+        payload: dict[str, Any] = {
+            "model": model_info.model_identifier,
+            "messages": _convert_messages_to_mcp(message_list),
+            "max_tokens": max_tokens,
+            "temperature": temperature,
+            "stream": True,  # MCP SSE always streams
+        }
+
+        # Attach tools
+        if tool_options:
+            payload["tools"] = _convert_tools_to_mcp(tool_options)
+
+        # Attach extra parameters
+        if extra_params:
+            payload.update(extra_params)
+
+        # Build the request headers
+        headers = {
+            "Content-Type": "application/json",
+            "Accept": "text/event-stream",
+            "Authorization": f"Bearer {self.api_provider.get_api_key()}",
+        }
+
+        # Send the request and parse the streamed response
+        url = f"{self.api_provider.base_url}/v1/messages"
+
+        try:
+            response, usage_record = await _parse_sse_stream(
+                session, url, payload, headers, interrupt_flag
+            )
+        except Exception as e:
+            logger.error(f"MCP SSE request failed: {e}")
+            raise
+
+        # Attach the usage record
+        if usage_record:
+            response.usage = UsageRecord(
+                model_name=model_info.name,
+                provider_name=model_info.api_provider,
+                prompt_tokens=usage_record[0],
+                completion_tokens=usage_record[1],
+                total_tokens=usage_record[2],
+            )
+
+        return response
+
+    async def get_embedding(
+        self,
+        model_info: ModelInfo,
+        embedding_input: str,
+        extra_params: dict[str, Any] | None = None,
+    ) -> APIResponse:
+        """
+        Get text embeddings.
+        The MCP protocol does not support embeddings yet.
+        :param model_info: model information
+        :param embedding_input: input text to embed
+        :return: embedding response
+        """
+        raise NotImplementedError("The MCP SSE client does not support embeddings yet")
+
+    async def get_audio_transcriptions(
+        self,
+        model_info: ModelInfo,
+        audio_base64: str,
+        extra_params: dict[str, Any] | None = None,
+    ) -> APIResponse:
+        """
+        Get an audio transcription.
+        The MCP protocol does not support audio transcription yet.
+        :param model_info: model information
+        :param audio_base64: base64-encoded audio data
+        :return: transcription response
+        """
+        raise NotImplementedError("The MCP SSE client does not support audio transcription yet")
+
+    def get_support_image_formats(self) -> list[str]:
+        """
+        Get the supported image formats.
+        :return: list of supported image formats
+        """
+        return ["jpg", "jpeg", "png", "webp", "gif"]