添加了MCP SEE支持

能不能用我不知道,先加进来。主要我没有服务,无法测试
This commit is contained in:
雅诺狐
2025-08-14 17:22:07 +08:00
parent c38be26bbb
commit d5777c2980
12 changed files with 1629 additions and 4 deletions

View File

@@ -448,6 +448,11 @@ MODULE_COLORS = {
"web_surfing_tool": "\033[38;5;130m", # 棕色
"tts": "\033[38;5;136m", # 浅棕色
# MCP SSE
"mcp_sse_manager": "\033[38;5;202m", # 橙红色
"mcp_event_handler": "\033[38;5;208m", # 橙色
"mcp_sse_client": "\033[38;5;214m", # 橙黄色
# mais4u系统扩展
"s4u_config": "\033[38;5;18m", # 深蓝色
"action": "\033[38;5;52m", # 深红色mais4u的action
@@ -530,6 +535,10 @@ MODULE_ALIASES = {
"web_surfing_tool": "网络搜索",
"tts": "语音合成",
# MCP SSE
"mcp_sse_manager": "MCP管理器",
"mcp_event_handler": "MCP事件处",
"mcp_sse_client": "MCP客户端",
# mais4u系统扩展
"s4u_config": "直播配置",
"action": "直播动作",

View File

@@ -42,6 +42,7 @@ from src.config.official_configs import (
ExaConfig,
WebSearchConfig,
TavilyConfig,
MCPSSEConfig,
)
from .api_ada_configs import (
@@ -362,6 +363,7 @@ class Config(ConfigBase):
exa: ExaConfig = field(default_factory=lambda: ExaConfig())
web_search: WebSearchConfig = field(default_factory=lambda: WebSearchConfig())
tavily: TavilyConfig = field(default_factory=lambda: TavilyConfig())
mcp_sse: MCPSSEConfig = field(default_factory=lambda: MCPSSEConfig())
@dataclass

View File

@@ -1,7 +1,7 @@
import re
from dataclasses import dataclass, field
from typing import Literal, Optional
from typing import Literal, Optional, Dict
from src.config.config_base import ConfigBase
@@ -1002,4 +1002,112 @@ class WebSearchConfig(ConfigBase):
"""启用的搜索引擎列表,可选: 'exa', 'tavily', 'ddg'"""
search_strategy: str = "single"
"""搜索策略: 'single'(使用第一个可用引擎), 'parallel'(并行使用所有启用的引擎), 'fallback'(按顺序尝试,失败则尝试下一个)"""
"""搜索策略: 'single'(使用第一个可用引擎), 'parallel'(并行使用所有启用的引擎), 'fallback'(按顺序尝试,失败则尝试下一个)"""
@dataclass
class MCPSSEConfig(ConfigBase):
"""MCP Server-Sent Events 客户端配置类"""
enable: bool = False
"""是否启用 MCP SSE 客户端"""
server_url: str = ""
"""MCP 服务器 SSE 端点 URL例如: http://localhost:8080/events"""
auth_key: str = ""
"""MCP 服务器认证密钥"""
# 连接配置
connection_timeout: int = 30
"""连接超时时间(秒)"""
read_timeout: int = 60
"""读取超时时间(秒)"""
# 重连配置
enable_reconnect: bool = True
"""是否启用自动重连"""
max_reconnect_attempts: int = 10
"""最大重连尝试次数,-1 表示无限重连"""
initial_reconnect_delay: float = 1.0
"""初始重连延迟时间(秒)"""
max_reconnect_delay: float = 60.0
"""最大重连延迟时间(秒)"""
reconnect_backoff_factor: float = 2.0
"""重连延迟指数退避因子"""
# 事件处理配置
event_buffer_size: int = 1000
"""事件缓冲区大小"""
enable_event_logging: bool = True
"""是否启用事件日志记录"""
# 订阅配置
subscribed_events: list[str] = field(default_factory=lambda: [])
"""订阅的事件类型列表,空列表表示订阅所有事件"""
# 高级配置
custom_headers: Dict[str, str] = field(default_factory=dict)
"""自定义 HTTP 头部"""
user_agent: str = "MaiBot-MCP-SSE-Client/1.0"
"""用户代理字符串"""
# SSL 配置
verify_ssl: bool = True
"""是否验证 SSL 证书"""
ssl_cert_path: Optional[str] = None
"""SSL 客户端证书路径"""
ssl_key_path: Optional[str] = None
"""SSL 客户端密钥路径"""
def __post_init__(self):
"""配置验证"""
# 只有在启用时才验证必需的配置
if self.enable:
if not self.server_url:
raise ValueError("启用 MCP SSE 客户端时必须提供 server_url")
# 这些参数无论是否启用都需要验证(因为有默认值)
if self.connection_timeout <= 0:
raise ValueError("connection_timeout 必须大于 0")
if self.read_timeout <= 0:
raise ValueError("read_timeout 必须大于 0")
if self.max_reconnect_attempts < -1:
raise ValueError("max_reconnect_attempts 必须大于等于 -1")
if self.initial_reconnect_delay <= 0:
raise ValueError("initial_reconnect_delay 必须大于 0")
if self.max_reconnect_delay <= 0:
raise ValueError("max_reconnect_delay 必须大于 0")
if self.reconnect_backoff_factor <= 1.0:
raise ValueError("reconnect_backoff_factor 必须大于 1.0")
if self.event_buffer_size <= 0:
raise ValueError("event_buffer_size 必须大于 0")
def get_headers(self) -> Dict[str, str]:
"""获取完整的 HTTP 头部"""
headers = {
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
"User-Agent": self.user_agent,
}
if self.auth_key:
headers["Authorization"] = f"Bearer {self.auth_key}"
headers.update(self.custom_headers)
return headers

View File

@@ -31,6 +31,10 @@ from src.common.message import get_global_api
if global_config.memory.enable_memory:
from src.chat.memory_system.Hippocampus import hippocampus_manager
# 条件导入 MCP SSE 系统
if global_config.mcp_sse.enable:
from src.mcp import initialize_mcp_sse_manager, start_mcp_sse_manager, stop_mcp_sse_manager
# 插件系统现在使用统一的插件加载器
install(extra_lines=3)
@@ -48,6 +52,12 @@ class MainSystem:
else:
self.hippocampus_manager = None
# 根据配置条件性地初始化 MCP SSE 系统
if global_config.mcp_sse.enable:
self.mcp_sse_manager = initialize_mcp_sse_manager(global_config.mcp_sse)
else:
self.mcp_sse_manager = None
self.individuality: Individuality = get_individuality()
# 使用消息API替代直接的FastAPI实例
@@ -76,6 +86,14 @@ class MainSystem:
except Exception as e:
logger.error(f"停止热重载系统时出错: {e}")
# 停止 MCP SSE 系统
if global_config.mcp_sse.enable and self.mcp_sse_manager:
try:
asyncio.create_task(stop_mcp_sse_manager())
logger.info("🛑 MCP SSE 系统已停止")
except Exception as e:
logger.error(f"停止 MCP SSE 系统时出错: {e}")
async def initialize(self):
"""初始化系统组件"""
logger.info(f"正在唤醒{global_config.bot.nickname}......")
@@ -161,6 +179,19 @@ MaiMbot-Pro-Max(第三方改版)
await schedule_manager.load_or_generate_today_schedule()
logger.info("日程表管理器初始化成功。")
# 根据配置条件性地启动 MCP SSE 系统
if global_config.mcp_sse.enable:
if self.mcp_sse_manager:
try:
await start_mcp_sse_manager()
logger.info("MCP SSE 系统初始化成功")
except Exception as e:
logger.error(f"MCP SSE 系统初始化失败: {e}")
else:
logger.warning("MCP SSE 系统已启用但管理器未初始化")
else:
logger.info("MCP SSE 系统已禁用,跳过初始化")
try:
init_time = int(1000 * (time.time() - init_start_time))
logger.info(f"初始化完成,神经元放电{init_time}")

32
src/mcp/__init__.py Normal file
View File

@@ -0,0 +1,32 @@
"""
MCP (Model Context Protocol) 模块
提供 MCP 服务器的 Server-Sent Events (SSE) 客户端功能,
支持实时事件订阅、断线重连和事件处理。
"""
from .sse_client import MCPSSEClient
from .event_handler import MCPEventHandler, MCPEvent
from .exceptions import MCPConnectionError, MCPEventError
from .manager import (
MCPSSEManager,
get_mcp_sse_manager,
initialize_mcp_sse_manager,
start_mcp_sse_manager,
stop_mcp_sse_manager,
)
from .config import MCPSSEConfig
__all__ = [
"MCPSSEClient",
"MCPEventHandler",
"MCPEvent",
"MCPConnectionError",
"MCPEventError",
"MCPSSEManager",
"MCPSSEConfig",
"get_mcp_sse_manager",
"initialize_mcp_sse_manager",
"start_mcp_sse_manager",
"stop_mcp_sse_manager",
]

112
src/mcp/config.py Normal file
View File

@@ -0,0 +1,112 @@
"""
MCP SSE 客户端配置类
"""
from dataclasses import dataclass, field
from typing import Optional, Dict, Any
from src.config.config_base import ConfigBase
@dataclass
class MCPSSEConfig(ConfigBase):
"""MCP Server-Sent Events 客户端配置类"""
enable: bool = False
"""是否启用 MCP SSE 客户端"""
server_url: str = ""
"""MCP 服务器 SSE 端点 URL例如: http://localhost:8080/events"""
auth_key: str = ""
"""MCP 服务器认证密钥"""
# 连接配置
connection_timeout: int = 30
"""连接超时时间(秒)"""
read_timeout: int = 60
"""读取超时时间(秒)"""
# 重连配置
enable_reconnect: bool = True
"""是否启用自动重连"""
max_reconnect_attempts: int = 10
"""最大重连尝试次数,-1 表示无限重连"""
initial_reconnect_delay: float = 1.0
"""初始重连延迟时间(秒)"""
max_reconnect_delay: float = 60.0
"""最大重连延迟时间(秒)"""
reconnect_backoff_factor: float = 2.0
"""重连延迟指数退避因子"""
# 事件处理配置
event_buffer_size: int = 1000
"""事件缓冲区大小"""
enable_event_logging: bool = True
"""是否启用事件日志记录"""
# 订阅配置
subscribed_events: list[str] = field(default_factory=lambda: [])
"""订阅的事件类型列表,空列表表示订阅所有事件"""
# 高级配置
custom_headers: Dict[str, str] = field(default_factory=dict)
"""自定义 HTTP 头部"""
user_agent: str = "MaiBot-MCP-SSE-Client/1.0"
"""用户代理字符串"""
# SSL 配置
verify_ssl: bool = True
"""是否验证 SSL 证书"""
ssl_cert_path: Optional[str] = None
"""SSL 客户端证书路径"""
ssl_key_path: Optional[str] = None
"""SSL 客户端密钥路径"""
def __post_init__(self):
"""配置验证"""
if self.enable and not self.server_url:
raise ValueError("启用 MCP SSE 客户端时必须提供 server_url")
if self.connection_timeout <= 0:
raise ValueError("connection_timeout 必须大于 0")
if self.read_timeout <= 0:
raise ValueError("read_timeout 必须大于 0")
if self.max_reconnect_attempts < -1:
raise ValueError("max_reconnect_attempts 必须大于等于 -1")
if self.initial_reconnect_delay <= 0:
raise ValueError("initial_reconnect_delay 必须大于 0")
if self.max_reconnect_delay <= 0:
raise ValueError("max_reconnect_delay 必须大于 0")
if self.reconnect_backoff_factor <= 1.0:
raise ValueError("reconnect_backoff_factor 必须大于 1.0")
if self.event_buffer_size <= 0:
raise ValueError("event_buffer_size 必须大于 0")
def get_headers(self) -> Dict[str, str]:
"""获取完整的 HTTP 头部"""
headers = {
"Accept": "text/event-stream",
"Cache-Control": "no-cache",
"User-Agent": self.user_agent,
}
if self.auth_key:
headers["Authorization"] = f"Bearer {self.auth_key}"
headers.update(self.custom_headers)
return headers

256
src/mcp/event_handler.py Normal file
View File

@@ -0,0 +1,256 @@
"""
MCP 事件处理器
"""
import json
import asyncio
from typing import Dict, Any, Callable, Optional, List
from dataclasses import dataclass
from datetime import datetime
from src.common.logger import get_logger
logger = get_logger("mcp_event_handler")
@dataclass
class MCPEvent:
"""MCP 事件数据类"""
event_type: str
"""事件类型"""
data: Dict[str, Any]
"""事件数据"""
timestamp: datetime
"""事件时间戳"""
event_id: Optional[str] = None
"""事件 ID"""
retry: Optional[int] = None
"""重试间隔(毫秒)"""
class MCPEventHandler:
"""MCP 事件处理器"""
def __init__(self):
self._event_handlers: Dict[str, List[Callable]] = {}
self._global_handlers: List[Callable] = []
self._event_buffer: List[MCPEvent] = []
self._buffer_size = 1000
self._running = False
def register_handler(self, event_type: str, handler: Callable[[MCPEvent], None]):
"""
注册事件处理器
Args:
event_type: 事件类型
handler: 事件处理函数
"""
if event_type not in self._event_handlers:
self._event_handlers[event_type] = []
self._event_handlers[event_type].append(handler)
logger.debug(f"注册事件处理器: {event_type}")
def register_global_handler(self, handler: Callable[[MCPEvent], None]):
"""
注册全局事件处理器(处理所有事件)
Args:
handler: 事件处理函数
"""
self._global_handlers.append(handler)
logger.debug("注册全局事件处理器")
def unregister_handler(self, event_type: str, handler: Callable[[MCPEvent], None]):
"""
取消注册事件处理器
Args:
event_type: 事件类型
handler: 事件处理函数
"""
if event_type in self._event_handlers:
try:
self._event_handlers[event_type].remove(handler)
logger.debug(f"取消注册事件处理器: {event_type}")
except ValueError:
logger.warning(f"尝试取消注册不存在的事件处理器: {event_type}")
def unregister_global_handler(self, handler: Callable[[MCPEvent], None]):
"""
取消注册全局事件处理器
Args:
handler: 事件处理函数
"""
try:
self._global_handlers.remove(handler)
logger.debug("取消注册全局事件处理器")
except ValueError:
logger.warning("尝试取消注册不存在的全局事件处理器")
async def handle_event(self, event: MCPEvent):
"""
处理单个事件
Args:
event: MCP 事件
"""
logger.debug(f"处理事件: {event.event_type}")
# 添加到事件缓冲区
self._add_to_buffer(event)
# 处理特定类型的事件处理器
if event.event_type in self._event_handlers:
for handler in self._event_handlers[event.event_type]:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
logger.error(f"事件处理器执行失败: {e}", exc_info=True)
# 处理全局事件处理器
for handler in self._global_handlers:
try:
if asyncio.iscoroutinefunction(handler):
await handler(event)
else:
handler(event)
except Exception as e:
logger.error(f"全局事件处理器执行失败: {e}", exc_info=True)
def _add_to_buffer(self, event: MCPEvent):
"""
添加事件到缓冲区
Args:
event: MCP 事件
"""
self._event_buffer.append(event)
# 如果缓冲区超过限制,移除最旧的事件
if len(self._event_buffer) > self._buffer_size:
self._event_buffer.pop(0)
def get_recent_events(self, count: int = 10) -> List[MCPEvent]:
"""
获取最近的事件
Args:
count: 获取的事件数量
Returns:
最近的事件列表
"""
return self._event_buffer[-count:]
def get_events_by_type(self, event_type: str, count: int = 10) -> List[MCPEvent]:
"""
根据类型获取事件
Args:
event_type: 事件类型
count: 获取的事件数量
Returns:
指定类型的事件列表
"""
filtered_events = [e for e in self._event_buffer if e.event_type == event_type]
return filtered_events[-count:]
def clear_buffer(self):
"""清空事件缓冲区"""
self._event_buffer.clear()
logger.debug("清空事件缓冲区")
def set_buffer_size(self, size: int):
"""
设置缓冲区大小
Args:
size: 缓冲区大小
"""
if size <= 0:
raise ValueError("缓冲区大小必须大于 0")
self._buffer_size = size
# 如果当前缓冲区超过新大小,截断
if len(self._event_buffer) > size:
self._event_buffer = self._event_buffer[-size:]
logger.debug(f"设置事件缓冲区大小: {size}")
def get_handler_count(self) -> Dict[str, int]:
"""
获取各类型事件处理器数量
Returns:
事件类型到处理器数量的映射
"""
counts = {}
for event_type, handlers in self._event_handlers.items():
counts[event_type] = len(handlers)
counts["global"] = len(self._global_handlers)
return counts
def parse_sse_event(raw_data: str) -> Optional[MCPEvent]:
"""
解析 SSE 事件数据
Args:
raw_data: 原始 SSE 数据
Returns:
解析后的 MCP 事件,如果解析失败返回 None
"""
try:
lines = raw_data.strip().split('\n')
event_type = None
event_data = None
event_id = None
retry = None
for line in lines:
line = line.strip()
if line.startswith('event:'):
event_type = line[6:].strip()
elif line.startswith('data:'):
data_str = line[5:].strip()
if data_str:
try:
event_data = json.loads(data_str)
except json.JSONDecodeError:
# 如果不是 JSON直接使用字符串
event_data = {"message": data_str}
elif line.startswith('id:'):
event_id = line[3:].strip()
elif line.startswith('retry:'):
try:
retry = int(line[6:].strip())
except ValueError:
pass
if event_type and event_data is not None:
return MCPEvent(
event_type=event_type,
data=event_data,
timestamp=datetime.now(),
event_id=event_id,
retry=retry
)
return None
except Exception as e:
logger.error(f"解析 SSE 事件失败: {e}")
return None

67
src/mcp/exceptions.py Normal file
View File

@@ -0,0 +1,67 @@
"""
MCP SSE 客户端异常类
"""
class MCPError(Exception):
"""MCP 基础异常类"""
pass
class MCPConnectionError(MCPError):
"""MCP 连接异常"""
def __init__(self, message: str, url: str = None, status_code: int = None):
super().__init__(message)
self.url = url
self.status_code = status_code
def __str__(self):
base_msg = super().__str__()
if self.url:
base_msg += f" (URL: {self.url})"
if self.status_code:
base_msg += f" (Status: {self.status_code})"
return base_msg
class MCPEventError(MCPError):
"""MCP 事件处理异常"""
def __init__(self, message: str, event_type: str = None, event_data: str = None):
super().__init__(message)
self.event_type = event_type
self.event_data = event_data
def __str__(self):
base_msg = super().__str__()
if self.event_type:
base_msg += f" (Event Type: {self.event_type})"
return base_msg
class MCPAuthenticationError(MCPConnectionError):
"""MCP 认证异常"""
pass
class MCPTimeoutError(MCPConnectionError):
"""MCP 超时异常"""
pass
class MCPReconnectError(MCPConnectionError):
"""MCP 重连异常"""
def __init__(self, message: str, attempts: int = 0, max_attempts: int = 0):
super().__init__(message)
self.attempts = attempts
self.max_attempts = max_attempts
def __str__(self):
base_msg = super().__str__()
if self.max_attempts > 0:
base_msg += f" (Attempts: {self.attempts}/{self.max_attempts})"
else:
base_msg += f" (Attempts: {self.attempts})"
return base_msg

260
src/mcp/manager.py Normal file
View File

@@ -0,0 +1,260 @@
"""
MCP SSE 管理器
负责管理 MCP SSE 客户端的生命周期,集成到 MaiBot 主系统中。
"""
import asyncio
from typing import Optional, Dict, Any, Callable
from datetime import datetime
from .sse_client import MCPSSEClient
from .config import MCPSSEConfig
from .event_handler import MCPEvent
from .exceptions import MCPConnectionError, MCPReconnectError
from src.common.logger import get_logger
logger = get_logger("mcp_sse_manager")
class MCPSSEManager:
"""MCP SSE 管理器"""
def __init__(self, config: MCPSSEConfig):
"""
初始化 MCP SSE 管理器
Args:
config: MCP SSE 配置
"""
self.config = config
self.client: Optional[MCPSSEClient] = None
self._task: Optional[asyncio.Task] = None
self._running = False
logger.info("初始化 MCP SSE 管理器")
async def start(self):
"""启动 MCP SSE 客户端"""
if not self.config.enable:
logger.info("MCP SSE 客户端未启用,跳过启动")
return
if self._running:
logger.warning("MCP SSE 客户端已在运行")
return
try:
# 创建客户端
self.client = MCPSSEClient(self.config)
# 注册默认事件处理器
self._register_default_handlers()
# 启动监听任务
self._task = asyncio.create_task(self._run_client())
self._running = True
logger.info("MCP SSE 客户端启动成功")
except Exception as e:
logger.error(f"启动 MCP SSE 客户端失败: {e}", exc_info=True)
await self.stop()
raise
async def stop(self):
"""停止 MCP SSE 客户端"""
if not self._running:
return
logger.info("停止 MCP SSE 客户端")
self._running = False
# 停止客户端
if self.client:
self.client.stop()
# 取消任务
if self._task and not self._task.done():
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
# 断开连接
if self.client:
await self.client.disconnect()
self.client = None
self._task = None
logger.info("MCP SSE 客户端已停止")
async def _run_client(self):
"""运行客户端监听循环"""
if not self.client:
return
try:
await self.client.start_listening()
except MCPReconnectError as e:
logger.error(f"MCP SSE 客户端重连失败: {e}")
except Exception as e:
logger.error(f"MCP SSE 客户端运行异常: {e}", exc_info=True)
finally:
self._running = False
def _register_default_handlers(self):
"""注册默认事件处理器"""
if not self.client:
return
# 注册全局事件处理器用于日志记录
self.client.register_global_event_handler(self._log_event_handler)
# 注册一些常见事件的处理器
self.client.register_event_handler("system.status", self._handle_system_status)
self.client.register_event_handler("chat.message", self._handle_chat_message)
self.client.register_event_handler("user.action", self._handle_user_action)
logger.debug("注册默认 MCP 事件处理器")
def _log_event_handler(self, event: MCPEvent):
"""全局事件日志处理器"""
if self.config.enable_event_logging:
logger.debug(f"MCP 事件: {event.event_type} - {event.data}")
def _handle_system_status(self, event: MCPEvent):
"""处理系统状态事件"""
logger.info(f"收到系统状态事件: {event.data}")
# 这里可以添加具体的系统状态处理逻辑
def _handle_chat_message(self, event: MCPEvent):
"""处理聊天消息事件"""
logger.info(f"收到聊天消息事件: {event.data}")
# 这里可以添加具体的聊天消息处理逻辑
# 例如:触发 MaiBot 的回复逻辑
def _handle_user_action(self, event: MCPEvent):
"""处理用户行为事件"""
logger.info(f"收到用户行为事件: {event.data}")
# 这里可以添加具体的用户行为处理逻辑
def register_event_handler(self, event_type: str, handler: Callable[[MCPEvent], None]):
"""
注册自定义事件处理器
Args:
event_type: 事件类型
handler: 事件处理函数
"""
if self.client:
self.client.register_event_handler(event_type, handler)
logger.debug(f"注册自定义事件处理器: {event_type}")
else:
logger.warning("客户端未初始化,无法注册事件处理器")
def register_global_event_handler(self, handler: Callable[[MCPEvent], None]):
"""
注册全局事件处理器
Args:
handler: 事件处理函数
"""
if self.client:
self.client.register_global_event_handler(handler)
logger.debug("注册全局事件处理器")
else:
logger.warning("客户端未初始化,无法注册全局事件处理器")
def is_running(self) -> bool:
"""检查是否正在运行"""
return self._running
def is_connected(self) -> bool:
"""检查是否已连接"""
return self.client.is_connected() if self.client else False
def get_stats(self) -> Dict[str, Any]:
"""
获取统计信息
Returns:
统计信息字典
"""
if not self.client:
return {
"enabled": self.config.enable,
"running": False,
"connected": False,
"client_initialized": False
}
stats = self.client.get_stats()
stats.update({
"enabled": self.config.enable,
"client_initialized": True,
"server_url": self.config.server_url,
"subscribed_events": self.config.subscribed_events,
})
return stats
def get_recent_events(self, count: int = 10):
"""
获取最近的事件
Args:
count: 获取的事件数量
Returns:
最近的事件列表
"""
if self.client:
return self.client.get_recent_events(count)
return []
# 全局 MCP SSE 管理器实例
_mcp_sse_manager: Optional[MCPSSEManager] = None
def get_mcp_sse_manager() -> Optional[MCPSSEManager]:
"""获取全局 MCP SSE 管理器实例"""
return _mcp_sse_manager
def initialize_mcp_sse_manager(config: MCPSSEConfig) -> MCPSSEManager:
"""
初始化全局 MCP SSE 管理器
Args:
config: MCP SSE 配置
Returns:
MCP SSE 管理器实例
"""
global _mcp_sse_manager
if _mcp_sse_manager:
logger.warning("MCP SSE 管理器已初始化")
return _mcp_sse_manager
_mcp_sse_manager = MCPSSEManager(config)
logger.info("全局 MCP SSE 管理器初始化完成")
return _mcp_sse_manager
async def start_mcp_sse_manager():
"""启动全局 MCP SSE 管理器"""
if _mcp_sse_manager:
await _mcp_sse_manager.start()
else:
logger.warning("MCP SSE 管理器未初始化")
async def stop_mcp_sse_manager():
"""停止全局 MCP SSE 管理器"""
if _mcp_sse_manager:
await _mcp_sse_manager.stop()

379
src/mcp/sse_client.py Normal file
View File

@@ -0,0 +1,379 @@
"""
MCP Server-Sent Events 客户端
"""
import asyncio
import aiohttp
import ssl
from typing import Optional, Dict, Any, Callable
from datetime import datetime
import time
import json
from .config import MCPSSEConfig
from .event_handler import MCPEventHandler, MCPEvent, parse_sse_event
from .exceptions import (
MCPConnectionError,
MCPAuthenticationError,
MCPTimeoutError,
MCPReconnectError,
MCPEventError
)
from src.common.logger import get_logger
logger = get_logger("mcp_sse_client")
class MCPSSEClient:
"""MCP Server-Sent Events 客户端"""
def __init__(self, config: MCPSSEConfig):
"""
初始化 MCP SSE 客户端
Args:
config: MCP SSE 配置
"""
self.config = config
self.event_handler = MCPEventHandler()
# 连接状态
self._session: Optional[aiohttp.ClientSession] = None
self._response: Optional[aiohttp.ClientResponse] = None
self._connected = False
self._running = False
# 重连状态
self._reconnect_attempts = 0
self._last_event_id: Optional[str] = None
# 统计信息
self._connection_start_time: Optional[datetime] = None
self._total_events_received = 0
self._last_event_time: Optional[datetime] = None
# 设置事件缓冲区大小
self.event_handler.set_buffer_size(config.event_buffer_size)
logger.info(f"初始化 MCP SSE 客户端: {config.server_url}")
async def connect(self) -> bool:
"""
连接到 MCP 服务器
Returns:
连接是否成功
"""
if self._connected:
logger.warning("客户端已连接")
return True
try:
# 创建 SSL 上下文
ssl_context = None
if self.config.server_url.startswith('https://'):
ssl_context = ssl.create_default_context()
if not self.config.verify_ssl:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
if self.config.ssl_cert_path and self.config.ssl_key_path:
ssl_context.load_cert_chain(
self.config.ssl_cert_path,
self.config.ssl_key_path
)
# 创建会话
timeout = aiohttp.ClientTimeout(
connect=self.config.connection_timeout,
sock_read=self.config.read_timeout
)
self._session = aiohttp.ClientSession(
timeout=timeout,
headers=self.config.get_headers()
)
# 建立连接
headers = {}
if self._last_event_id:
headers['Last-Event-ID'] = self._last_event_id
logger.info(f"连接到 MCP 服务器: {self.config.server_url}")
self._response = await self._session.get(
self.config.server_url,
headers=headers,
ssl=ssl_context
)
# 检查响应状态
if self._response.status == 401:
raise MCPAuthenticationError(
"认证失败",
url=self.config.server_url,
status_code=self._response.status
)
elif self._response.status != 200:
raise MCPConnectionError(
f"连接失败: HTTP {self._response.status}",
url=self.config.server_url,
status_code=self._response.status
)
# 检查内容类型
content_type = self._response.headers.get('Content-Type', '')
if 'text/event-stream' not in content_type:
raise MCPConnectionError(
f"无效的内容类型: {content_type}",
url=self.config.server_url
)
self._connected = True
self._connection_start_time = datetime.now()
self._reconnect_attempts = 0
logger.info("成功连接到 MCP 服务器")
return True
except asyncio.TimeoutError:
raise MCPTimeoutError(
"连接超时",
url=self.config.server_url
)
except Exception as e:
await self._cleanup_connection()
if isinstance(e, (MCPConnectionError, MCPAuthenticationError, MCPTimeoutError)):
raise
else:
raise MCPConnectionError(f"连接失败: {str(e)}", url=self.config.server_url)
async def disconnect(self):
"""断开连接"""
logger.info("断开 MCP 服务器连接")
self._running = False
await self._cleanup_connection()
async def _cleanup_connection(self):
"""清理连接资源"""
self._connected = False
if self._response:
self._response.close()
self._response = None
if self._session:
await self._session.close()
self._session = None
async def start_listening(self):
"""开始监听事件"""
if not self.config.enable:
logger.warning("MCP SSE 客户端未启用")
return
self._running = True
while self._running:
try:
if not self._connected:
await self.connect()
await self._listen_events()
except (MCPConnectionError, MCPTimeoutError) as e:
logger.error(f"连接错误: {e}")
await self._cleanup_connection()
if self.config.enable_reconnect:
await self._handle_reconnect()
else:
break
except Exception as e:
logger.error(f"监听事件时发生未知错误: {e}", exc_info=True)
await self._cleanup_connection()
if self.config.enable_reconnect:
await self._handle_reconnect()
else:
break
await self._cleanup_connection()
logger.info("停止监听 MCP 事件")
async def _listen_events(self):
"""监听事件流"""
if not self._response:
raise MCPConnectionError("没有活动的连接")
logger.info("开始监听 MCP 事件流")
buffer = ""
async for chunk in self._response.content.iter_chunked(1024):
if not self._running:
break
try:
# 解码数据
data = chunk.decode('utf-8')
buffer += data
# 处理完整的事件
while '\n\n' in buffer:
event_data, buffer = buffer.split('\n\n', 1)
if event_data.strip():
await self._process_event_data(event_data)
except UnicodeDecodeError as e:
logger.error(f"解码事件数据失败: {e}")
continue
except Exception as e:
logger.error(f"处理事件数据失败: {e}", exc_info=True)
continue
async def _process_event_data(self, event_data: str):
"""
处理事件数据
Args:
event_data: 原始事件数据
"""
try:
# 解析 SSE 事件
event = parse_sse_event(event_data)
if not event:
return
# 更新统计信息
self._total_events_received += 1
self._last_event_time = event.timestamp
if event.event_id:
self._last_event_id = event.event_id
# 检查事件订阅
if self.config.subscribed_events:
if event.event_type not in self.config.subscribed_events:
logger.debug(f"跳过未订阅的事件类型: {event.event_type}")
return
# 记录事件日志
if self.config.enable_event_logging:
logger.info(f"收到 MCP 事件: {event.event_type}")
logger.debug(f"事件数据: {event.data}")
# 处理事件
await self.event_handler.handle_event(event)
except Exception as e:
logger.error(f"处理事件失败: {e}", exc_info=True)
raise MCPEventError(f"处理事件失败: {str(e)}")
async def _handle_reconnect(self):
"""处理重连逻辑"""
if not self.config.enable_reconnect:
return
self._reconnect_attempts += 1
# 检查最大重连次数
if (self.config.max_reconnect_attempts > 0 and
self._reconnect_attempts > self.config.max_reconnect_attempts):
raise MCPReconnectError(
"超过最大重连次数",
attempts=self._reconnect_attempts,
max_attempts=self.config.max_reconnect_attempts
)
# 计算重连延迟(指数退避)
delay = min(
self.config.initial_reconnect_delay * (
self.config.reconnect_backoff_factor ** (self._reconnect_attempts - 1)
),
self.config.max_reconnect_delay
)
logger.info(f"{self._reconnect_attempts} 次重连尝试,延迟 {delay:.2f}")
await asyncio.sleep(delay)
def stop(self):
"""停止客户端"""
logger.info("停止 MCP SSE 客户端")
self._running = False
def is_connected(self) -> bool:
"""检查是否已连接"""
return self._connected
def is_running(self) -> bool:
"""检查是否正在运行"""
return self._running
def get_stats(self) -> Dict[str, Any]:
"""
获取客户端统计信息
Returns:
统计信息字典
"""
stats = {
"connected": self._connected,
"running": self._running,
"reconnect_attempts": self._reconnect_attempts,
"total_events_received": self._total_events_received,
"connection_start_time": self._connection_start_time,
"last_event_time": self._last_event_time,
"last_event_id": self._last_event_id,
}
if self._connection_start_time:
stats["uptime_seconds"] = (datetime.now() - self._connection_start_time).total_seconds()
# 添加事件处理器统计
stats["event_handlers"] = self.event_handler.get_handler_count()
return stats
def register_event_handler(self, event_type: str, handler: Callable[[MCPEvent], None]):
"""
注册事件处理器
Args:
event_type: 事件类型
handler: 事件处理函数
"""
self.event_handler.register_handler(event_type, handler)
def register_global_event_handler(self, handler: Callable[[MCPEvent], None]):
"""
注册全局事件处理器
Args:
handler: 事件处理函数
"""
self.event_handler.register_global_handler(handler)
def unregister_event_handler(self, event_type: str, handler: Callable[[MCPEvent], None]):
"""
取消注册事件处理器
Args:
event_type: 事件类型
handler: 事件处理函数
"""
self.event_handler.unregister_handler(event_type, handler)
def get_recent_events(self, count: int = 10):
"""
获取最近的事件
Args:
count: 获取的事件数量
Returns:
最近的事件列表
"""
return self.event_handler.get_recent_events(count)