feat: 更新代码中的日志信息和注释为中文,增强可读性,修改适配器注册流程
This commit is contained in:
@@ -98,7 +98,7 @@ class AdapterBase:
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("Failed to register outgoing handler on core sink")
|
||||
logger.exception("注册 outgoing 处理程序到核心接收器失败")
|
||||
if isinstance(self._transport_config, WebSocketAdapterOptions):
|
||||
await self._start_ws_transport(self._transport_config)
|
||||
elif isinstance(self._transport_config, HttpAdapterOptions):
|
||||
@@ -112,12 +112,12 @@ class AdapterBase:
|
||||
try:
|
||||
remove(self._on_outgoing_from_core)
|
||||
except Exception:
|
||||
logger.exception("Failed to detach outgoing handler on core sink")
|
||||
logger.exception("从核心接收器分离 outgoing 处理程序失败")
|
||||
elif hasattr(self.core_sink, "set_outgoing_handler"):
|
||||
try:
|
||||
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
logger.exception("Failed to detach outgoing handler on core sink")
|
||||
logger.exception("从核心接收器分离 outgoing 处理程序失败")
|
||||
if self._ws_task:
|
||||
self._ws_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
@@ -135,12 +135,12 @@ class AdapterBase:
|
||||
|
||||
async def on_platform_message(self, raw: Any) -> None:
|
||||
"""处理平台下发的单条消息并交给核心。"""
|
||||
envelope = await _maybe_await(self.from_platform_message(raw))
|
||||
envelope = await self.from_platform_message(raw)
|
||||
await self.core_sink.send(envelope)
|
||||
|
||||
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
|
||||
"""批量推送入口,内部自动批量或逐条送入核心。"""
|
||||
envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages]
|
||||
envelopes = [await self.from_platform_message(raw) for raw in raw_messages]
|
||||
await _send_many(self.core_sink, envelopes)
|
||||
|
||||
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
|
||||
@@ -159,7 +159,7 @@ class AdapterBase:
|
||||
return
|
||||
await self._send_platform_message(envelope)
|
||||
|
||||
def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]:
|
||||
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
|
||||
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -291,7 +291,7 @@ class ProcessCoreSink(CoreSink):
|
||||
await self.send(message)
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
logger.debug("ProcessCoreSink.push_outgoing called in child; ignored")
|
||||
logger.debug("ProcessCoreSink.push_outgoing 在子进程中调用; 被忽略")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self._closed:
|
||||
@@ -318,9 +318,9 @@ class ProcessCoreSink(CoreSink):
|
||||
try:
|
||||
await self._outgoing_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to handle outgoing envelope in ProcessCoreSink")
|
||||
logger.exception("处理 ProcessCoreSink 中的 outgoing 信封失败")
|
||||
else:
|
||||
logger.debug("ProcessCoreSink received unknown payload: %r", item)
|
||||
logger.debug(f"ProcessCoreSink 接受到未知负载: {item}")
|
||||
|
||||
|
||||
class ProcessCoreSinkServer:
|
||||
@@ -362,9 +362,9 @@ class ProcessCoreSinkServer:
|
||||
try:
|
||||
await self._core_handler(envelope)
|
||||
except Exception: # pragma: no cover
|
||||
logger.exception("Failed to dispatch incoming envelope from %s", self._name)
|
||||
logger.exception(f"处理来自 {self._name} 的 incoming 信封时失败")
|
||||
else:
|
||||
logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item)
|
||||
logger.debug(f"ProcessCoreSinkServer 忽略来自 {self._name} 的未知负载: {item}")
|
||||
|
||||
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
|
||||
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
|
||||
@@ -390,12 +390,6 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) ->
|
||||
await sink.send(env)
|
||||
|
||||
|
||||
async def _maybe_await(result: Any) -> Any:
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
class BatchDispatcher:
|
||||
"""
|
||||
批量消息分发器,负责将消息批量发送到核心 sink。
|
||||
|
||||
@@ -18,6 +18,16 @@ DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
|
||||
|
||||
|
||||
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
|
||||
"""
|
||||
将原始字节数据附加到消息负载中
|
||||
|
||||
Args:
|
||||
payload: 消息负载
|
||||
raw_bytes: 原始字节数据
|
||||
|
||||
Returns:
|
||||
附加了原始数据的消息负载
|
||||
"""
|
||||
if isinstance(payload, dict):
|
||||
payload.setdefault("raw_bytes", raw_bytes)
|
||||
elif isinstance(payload, list):
|
||||
@@ -28,6 +38,16 @@ def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
|
||||
|
||||
|
||||
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
|
||||
"""
|
||||
编码消息用于 WebSocket 发送
|
||||
|
||||
Args:
|
||||
message: 要发送的消息
|
||||
use_raw_bytes: 是否使用原始字节数据
|
||||
|
||||
Returns:
|
||||
(编码后的数据, 是否为二进制格式)
|
||||
"""
|
||||
if isinstance(message, (bytes, bytearray)):
|
||||
return bytes(message), True
|
||||
if use_raw_bytes and isinstance(message, dict):
|
||||
@@ -44,15 +64,29 @@ def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[s
|
||||
|
||||
|
||||
class BaseMessageHandler:
|
||||
"""基础消息处理器,提供消息处理和任务管理功能"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.message_handlers: list[MessageHandler] = []
|
||||
self.background_tasks: set[asyncio.Task] = set()
|
||||
|
||||
def register_message_handler(self, handler: MessageHandler) -> None:
|
||||
"""
|
||||
注册消息处理器
|
||||
|
||||
Args:
|
||||
handler: 消息处理函数
|
||||
"""
|
||||
if handler not in self.message_handlers:
|
||||
self.message_handlers.append(handler)
|
||||
|
||||
async def process_message(self, message: MessagePayload) -> None:
|
||||
"""
|
||||
处理单条消息,并发执行所有注册的处理器
|
||||
|
||||
Args:
|
||||
message: 消息负载
|
||||
"""
|
||||
tasks: list[asyncio.Task] = []
|
||||
for handler in self.message_handlers:
|
||||
try:
|
||||
@@ -63,7 +97,7 @@ class BaseMessageHandler:
|
||||
self.background_tasks.add(task)
|
||||
task.add_done_callback(self.background_tasks.discard)
|
||||
except Exception: # pragma: no cover - logging only
|
||||
logging.getLogger("mofox_bus.server").exception("Failed to handle message")
|
||||
logging.getLogger("mofox_bus.server").exception("消息处理失败")
|
||||
if tasks:
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
@@ -236,8 +270,8 @@ class MessageServer(BaseMessageHandler):
|
||||
) -> None:
|
||||
ws = self._platform_connections.get(platform)
|
||||
if ws is None:
|
||||
raise RuntimeError(f"No active connection for platform {platform}")
|
||||
payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message
|
||||
raise RuntimeError(f"平台 {platform} 没有活跃的连接")
|
||||
payload: MessagePayload | bytes = message
|
||||
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
|
||||
if is_binary:
|
||||
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
|
||||
@@ -439,7 +473,7 @@ class MessageClient(BaseMessageHandler):
|
||||
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
|
||||
self._logger.info(f"WebSocket 连接断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
|
||||
@@ -9,9 +9,9 @@ from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPay
|
||||
|
||||
class MessageBuilder:
|
||||
"""
|
||||
Fluent helper to build MessageEnvelope safely with type hints.
|
||||
流式构建 MessageEnvelope 的助手工具,提供类型安全的构建方法。
|
||||
|
||||
Example:
|
||||
使用示例:
|
||||
msg = (
|
||||
MessageBuilder()
|
||||
.text("Hello")
|
||||
@@ -85,9 +85,10 @@ class MessageBuilder:
|
||||
return self
|
||||
|
||||
def build(self) -> MessageEnvelope:
|
||||
# message_info defaults
|
||||
"""构建最终的消息信封"""
|
||||
# 设置 message_info 默认值
|
||||
if not self._segments:
|
||||
raise ValueError("message_segment is required, add at least one segment before build()")
|
||||
raise ValueError("需要至少添加一个消息段才能构建消息")
|
||||
if self._message_id is None:
|
||||
self._message_id = str(uuid.uuid4())
|
||||
info = dict(self._message_info)
|
||||
|
||||
@@ -63,7 +63,7 @@ def loads_messages(data: bytes | str) -> List[MessageEnvelope]:
|
||||
obj = _loads(data)
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version != DEFAULT_SCHEMA_VERSION:
|
||||
raise ValueError(f"Unsupported schema_version={version}")
|
||||
raise ValueError(f"不支持的 schema_version={version}")
|
||||
return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])]
|
||||
|
||||
|
||||
@@ -74,7 +74,7 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
|
||||
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
|
||||
if version == DEFAULT_SCHEMA_VERSION:
|
||||
return obj # type: ignore[return-value]
|
||||
raise ValueError(f"Unsupported schema_version={version}")
|
||||
raise ValueError(f"不支持的 schema_version={version}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -14,15 +14,18 @@ logger = logging.getLogger("mofox_bus.router")
|
||||
|
||||
@dataclass
|
||||
class TargetConfig:
|
||||
"""路由目标配置,包含连接信息和认证配置"""
|
||||
url: str
|
||||
token: str | None = None
|
||||
ssl_verify: str | None = None
|
||||
|
||||
def to_dict(self) -> Dict[str, str | None]:
|
||||
"""转换为字典格式"""
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig":
|
||||
"""从字典创建配置对象"""
|
||||
return cls(
|
||||
url=data.get("url", ""),
|
||||
token=data.get("token"),
|
||||
@@ -32,13 +35,16 @@ class TargetConfig:
|
||||
|
||||
@dataclass
|
||||
class RouteConfig:
|
||||
"""路由配置,包含多个平台的路由目标"""
|
||||
route_config: Dict[str, TargetConfig]
|
||||
|
||||
def to_dict(self) -> Dict[str, Dict[str, str | None]]:
|
||||
"""转换为字典格式"""
|
||||
return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig":
|
||||
"""从字典创建路由配置对象"""
|
||||
cfg = {
|
||||
platform: TargetConfig.from_dict(target)
|
||||
for platform, target in data.get("route_config", {}).items()
|
||||
@@ -47,7 +53,16 @@ class RouteConfig:
|
||||
|
||||
|
||||
class Router:
|
||||
"""消息路由器,负责管理多个平台的消息客户端连接"""
|
||||
|
||||
def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None:
|
||||
"""
|
||||
初始化路由器
|
||||
|
||||
Args:
|
||||
config: 路由配置
|
||||
custom_logger: 自定义日志记录器
|
||||
"""
|
||||
if custom_logger:
|
||||
logger.handlers = custom_logger.handlers
|
||||
self.config = config
|
||||
@@ -58,12 +73,22 @@ class Router:
|
||||
self._stop_event: asyncio.Event | None = None
|
||||
|
||||
async def connect(self, platform: str) -> None:
|
||||
"""
|
||||
连接到指定平台
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
|
||||
Raises:
|
||||
ValueError: 未知平台
|
||||
NotImplementedError: 不支持的模式
|
||||
"""
|
||||
if platform not in self.config.route_config:
|
||||
raise ValueError(f"Unknown platform {platform}")
|
||||
raise ValueError(f"未知平台: {platform}")
|
||||
target = self.config.route_config[platform]
|
||||
mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws"
|
||||
if mode != "ws":
|
||||
raise NotImplementedError("TCP mode is not implemented yet")
|
||||
raise NotImplementedError("TCP 模式暂未实现")
|
||||
client = MessageClient(mode="ws")
|
||||
client.set_disconnect_callback(self._handle_client_disconnect)
|
||||
await client.connect(
|
||||
@@ -84,6 +109,7 @@ class Router:
|
||||
client.register_message_handler(handler)
|
||||
|
||||
async def run(self) -> None:
|
||||
"""启动路由器,连接所有配置的平台并开始运行"""
|
||||
self._running = True
|
||||
self._stop_event = asyncio.Event()
|
||||
for platform in self.config.route_config:
|
||||
@@ -98,6 +124,12 @@ class Router:
|
||||
raise
|
||||
|
||||
async def remove_platform(self, platform: str) -> None:
|
||||
"""
|
||||
移除指定平台的连接
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
"""
|
||||
if platform in self._client_tasks:
|
||||
task = self._client_tasks.pop(platform)
|
||||
task.cancel()
|
||||
@@ -108,7 +140,14 @@ class Router:
|
||||
await client.stop()
|
||||
|
||||
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
|
||||
logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason)
|
||||
"""
|
||||
处理客户端断开连接
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
reason: 断开原因
|
||||
"""
|
||||
logger.info(f"平台 {platform} 的客户端断开连接: {reason} (客户端将自动重连)")
|
||||
task = self._client_tasks.get(platform)
|
||||
if task is not None and not task.done():
|
||||
return
|
||||
@@ -117,6 +156,7 @@ class Router:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
async def stop(self) -> None:
|
||||
"""停止路由器,关闭所有连接"""
|
||||
self._running = False
|
||||
if self._stop_event:
|
||||
self._stop_event.set()
|
||||
@@ -125,23 +165,46 @@ class Router:
|
||||
self.clients.clear()
|
||||
|
||||
def _start_client_task(self, platform: str, client: MessageClient) -> None:
|
||||
"""
|
||||
启动客户端任务
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
client: 消息客户端
|
||||
"""
|
||||
task = asyncio.create_task(client.run())
|
||||
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
|
||||
self._client_tasks[platform] = task
|
||||
|
||||
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
|
||||
"""
|
||||
必要时重启客户端任务
|
||||
|
||||
Args:
|
||||
platform: 平台标识
|
||||
task: 已完成的任务
|
||||
"""
|
||||
if not self._running:
|
||||
return
|
||||
if task.cancelled():
|
||||
return
|
||||
exc = task.exception()
|
||||
if exc:
|
||||
logger.warning("Client task for %s ended with exception: %s", platform, exc)
|
||||
logger.warning(f"平台 {platform} 的客户端任务异常结束: {exc}")
|
||||
client = self.clients.get(platform)
|
||||
if client:
|
||||
self._start_client_task(platform, client)
|
||||
|
||||
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
|
||||
"""
|
||||
根据消息获取目标 URL
|
||||
|
||||
Args:
|
||||
message: 消息信封
|
||||
|
||||
Returns:
|
||||
目标 URL 或 None
|
||||
"""
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
return None
|
||||
@@ -149,24 +212,48 @@ class Router:
|
||||
return target.url if target else None
|
||||
|
||||
async def send_message(self, message: MessageEnvelope):
|
||||
"""
|
||||
发送消息到指定平台
|
||||
|
||||
Args:
|
||||
message: 消息信封
|
||||
|
||||
Raises:
|
||||
ValueError: 缺少平台信息
|
||||
RuntimeError: 未找到对应平台的客户端
|
||||
"""
|
||||
platform = message.get("message_info", {}).get("platform")
|
||||
if not platform:
|
||||
raise ValueError("message_info.platform is required")
|
||||
raise ValueError("消息中缺少必需的 message_info.platform 字段")
|
||||
client = self.clients.get(platform)
|
||||
if client is None:
|
||||
raise RuntimeError(f"No client connected for platform {platform}")
|
||||
raise RuntimeError(f"平台 {platform} 没有已连接的客户端")
|
||||
return await client.send_message(message)
|
||||
|
||||
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
|
||||
"""
|
||||
更新路由配置
|
||||
|
||||
Args:
|
||||
config_data: 新的配置数据
|
||||
"""
|
||||
new_config = RouteConfig.from_dict(config_data)
|
||||
await self._adjust_connections(new_config)
|
||||
self.config = new_config
|
||||
|
||||
async def _adjust_connections(self, new_config: RouteConfig) -> None:
|
||||
"""
|
||||
调整连接以匹配新配置
|
||||
|
||||
Args:
|
||||
new_config: 新的路由配置
|
||||
"""
|
||||
current = set(self.config.route_config.keys())
|
||||
updated = set(new_config.route_config.keys())
|
||||
# 移除不再存在的平台
|
||||
for platform in current - updated:
|
||||
await self.remove_platform(platform)
|
||||
# 添加或更新平台
|
||||
for platform in updated:
|
||||
if platform not in current:
|
||||
await self.connect(platform)
|
||||
|
||||
@@ -25,13 +25,14 @@ class MessageProcessingError(RuntimeError):
|
||||
|
||||
def __init__(self, message: MessageEnvelope, original: BaseException):
|
||||
detail = message.get("id", "<unknown>")
|
||||
super().__init__(f"Failed to handle message {detail}: {original}")
|
||||
super().__init__(f"处理消息 {detail} 时出错: {original}")
|
||||
self.message_envelope = message
|
||||
self.original = original
|
||||
|
||||
|
||||
@dataclass
|
||||
class MessageRoute:
|
||||
"""消息路由配置,包含匹配条件和处理函数"""
|
||||
predicate: Predicate
|
||||
handler: MessageHandler
|
||||
name: str | None = None
|
||||
@@ -41,7 +42,7 @@ class MessageRoute:
|
||||
|
||||
class MessageRuntime:
|
||||
"""
|
||||
负责调度消息路由、执行前后 hook 以及批量处理。
|
||||
消息运行时环境,负责调度消息路由、执行前后处理钩子以及批量处理消息
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -64,6 +65,16 @@ class MessageRuntime:
|
||||
message_type: str | None = None,
|
||||
event_types: Iterable[str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
添加消息路由
|
||||
|
||||
Args:
|
||||
predicate: 路由匹配条件
|
||||
handler: 消息处理函数
|
||||
name: 路由名称(可选)
|
||||
message_type: 消息类型(可选)
|
||||
event_types: 事件类型列表(可选)
|
||||
"""
|
||||
with self._lock:
|
||||
route = MessageRoute(
|
||||
predicate=predicate,
|
||||
@@ -245,14 +256,8 @@ class MessageRuntime:
|
||||
return wrapped
|
||||
|
||||
|
||||
async def _maybe_await(result):
|
||||
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
|
||||
return await result
|
||||
return result
|
||||
|
||||
|
||||
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
|
||||
"""Support sync/async callables with optional thread offloading."""
|
||||
"""支持 sync/async 调用,并可选择在线程中执行。"""
|
||||
if inspect.iscoroutinefunction(func):
|
||||
return await func(*args)
|
||||
if prefer_thread:
|
||||
|
||||
@@ -11,7 +11,7 @@ from ..types import MessageEnvelope
|
||||
|
||||
class HttpMessageClient:
|
||||
"""
|
||||
面向消息批量传输的 HTTP 客户端封装。
|
||||
面向消息批量传输的 HTTP 客户端封装
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -39,14 +39,14 @@ class HttpMessageClient:
|
||||
session = await self._ensure_session()
|
||||
url = f"{self._base_url}{path}"
|
||||
payload = dumps_messages(messages)
|
||||
self._logger.debug("Sending %d message(s) -> %s", len(messages), url)
|
||||
self._logger.debug(f"正在发送 {len(messages)} 条消息 -> {url}")
|
||||
async with session.post(url, data=payload, timeout=self._timeout) as resp:
|
||||
resp.raise_for_status()
|
||||
if not expect_reply:
|
||||
return None
|
||||
raw = await resp.read()
|
||||
replies = loads_messages(raw)
|
||||
self._logger.debug("Received %d reply message(s)", len(replies))
|
||||
self._logger.debug(f"接收到 {len(replies)} 条回复消息")
|
||||
return replies
|
||||
|
||||
async def close(self) -> None:
|
||||
|
||||
@@ -13,7 +13,7 @@ MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelop
|
||||
|
||||
class HttpMessageServer:
|
||||
"""
|
||||
轻量级 HTTP 消息入口。可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下。
|
||||
轻量级 HTTP 消息入口,可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下
|
||||
"""
|
||||
|
||||
def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None:
|
||||
@@ -27,10 +27,10 @@ class HttpMessageServer:
|
||||
try:
|
||||
raw = await request.read()
|
||||
envelopes = loads_messages(raw)
|
||||
self._logger.debug("Received %d message(s)", len(envelopes))
|
||||
self._logger.debug(f"接收到 {len(envelopes)} 条消息")
|
||||
except Exception as exc: # pragma: no cover - network errors are integration tested
|
||||
self._logger.exception("Failed to parse incoming messages: %s", exc)
|
||||
raise web.HTTPBadRequest(reason=f"Invalid payload: {exc}") from exc
|
||||
self._logger.exception(f"解析请求失败: {exc}")
|
||||
raise web.HTTPBadRequest(reason=f"无效的负载: {exc}") from exc
|
||||
|
||||
result = await self._handler(envelopes)
|
||||
if result is None:
|
||||
|
||||
@@ -14,7 +14,7 @@ IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
class WsMessageClient:
|
||||
"""
|
||||
管理 WebSocket 连接,提供 send/receive API,并在后台读取消息。
|
||||
管理 WebSocket 连接,提供 send/receive API,并在后台读取消息
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -42,7 +42,7 @@ class WsMessageClient:
|
||||
async def _connect_once(self) -> None:
|
||||
assert self._session is not None
|
||||
self._ws = await self._session.ws_connect(self._url)
|
||||
self._logger.info("Connected to %s", self._url)
|
||||
self._logger.info(f"已连接到 {self._url}")
|
||||
self._receive_task = asyncio.create_task(self._receive_loop())
|
||||
|
||||
async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None:
|
||||
@@ -76,7 +76,7 @@ class WsMessageClient:
|
||||
if self._handler is not None:
|
||||
await self._handler(env)
|
||||
elif msg.type == aiohttp.WSMsgType.ERROR:
|
||||
self._logger.warning("WebSocket error: %s", msg.data)
|
||||
self._logger.warning(f"WebSocket 错误: {msg.data}")
|
||||
break
|
||||
except asyncio.CancelledError: # pragma: no cover - cancellation path
|
||||
return
|
||||
@@ -85,7 +85,7 @@ class WsMessageClient:
|
||||
await self._reconnect()
|
||||
|
||||
async def _reconnect(self) -> None:
|
||||
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
|
||||
self._logger.info(f"WebSocket 断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
|
||||
await asyncio.sleep(self._reconnect_interval)
|
||||
await self._connect_once()
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]]
|
||||
|
||||
class WsMessageServer:
|
||||
"""
|
||||
封装 WebSocket 服务端逻辑,负责接收消息并广播响应。
|
||||
封装 WebSocket 服务端逻辑,负责接收消息并广播响应
|
||||
"""
|
||||
|
||||
def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None:
|
||||
@@ -30,7 +30,7 @@ class WsMessageServer:
|
||||
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
|
||||
ws = web.WebSocketResponse()
|
||||
await ws.prepare(request)
|
||||
self._logger.info("WebSocket connection opened: %s", request.remote)
|
||||
self._logger.info(f"WebSocket 连接打开: {request.remote}")
|
||||
|
||||
async with self._track_connection(ws):
|
||||
async for message in ws:
|
||||
@@ -39,10 +39,10 @@ class WsMessageServer:
|
||||
for env in envelopes:
|
||||
await self._handler(env)
|
||||
elif message.type == WSMsgType.ERROR:
|
||||
self._logger.warning("WebSocket connection error: %s", ws.exception())
|
||||
self._logger.warning(f"WebSocket 连接错误: {ws.exception()}")
|
||||
break
|
||||
|
||||
self._logger.info("WebSocket connection closed: %s", request.remote)
|
||||
self._logger.info(f"WebSocket 连接关闭: {request.remote}")
|
||||
return ws
|
||||
|
||||
@asynccontextmanager
|
||||
|
||||
@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, Dict, Optional
|
||||
|
||||
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope
|
||||
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope, ProcessCoreSink
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.plugin_system.base.base_plugin import BasePlugin
|
||||
@@ -62,6 +62,28 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
|
||||
self._config: Dict[str, Any] = {}
|
||||
self._health_check_task: Optional[asyncio.Task] = None
|
||||
self._running = False
|
||||
# 标记是否在子进程中运行(由核心管理器传入 ProcessCoreSink 时自动生效)
|
||||
self._is_subprocess = isinstance(core_sink, ProcessCoreSink)
|
||||
|
||||
@classmethod
|
||||
def from_process_queues(
|
||||
cls,
|
||||
to_core_queue,
|
||||
from_core_queue,
|
||||
plugin: Optional["BasePlugin"] = None,
|
||||
**kwargs: Any,
|
||||
) -> "BaseAdapter":
|
||||
"""
|
||||
子进程入口便捷构造:使用 multiprocessing.Queue 与核心建立 ProcessCoreSink 通讯。
|
||||
|
||||
Args:
|
||||
to_core_queue: 发往核心的 multiprocessing.Queue
|
||||
from_core_queue: 核心回传的 multiprocessing.Queue
|
||||
plugin: 可选插件实例
|
||||
**kwargs: 透传给适配器构造函数
|
||||
"""
|
||||
sink = ProcessCoreSink(to_core_queue=to_core_queue, from_core_queue=from_core_queue)
|
||||
return cls(core_sink=sink, plugin=plugin, **kwargs)
|
||||
|
||||
@property
|
||||
def config(self) -> Dict[str, Any]:
|
||||
|
||||
@@ -61,24 +61,24 @@ def _adapter_process_entry(
|
||||
|
||||
|
||||
class AdapterProcess:
|
||||
"""适配器子进程包装器,负责适配器子进程的启动和生命周期管理"""
|
||||
"""适配器子进程封装:管理子进程的生命周期与通信桥接"""
|
||||
|
||||
def __init__(self, adapter: "BaseAdapter", core_sink) -> None:
|
||||
self.adapter = adapter
|
||||
self.adapter_name = adapter.adapter_name
|
||||
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin, core_sink) -> None:
|
||||
self.adapter_cls = adapter_cls
|
||||
self.adapter_name = adapter_cls.adapter_name
|
||||
self.plugin = plugin
|
||||
self.process: mp.Process | None = None
|
||||
self._ctx = mp.get_context("spawn")
|
||||
self._incoming_queue: mp.Queue = self._ctx.Queue()
|
||||
self._outgoing_queue: mp.Queue = self._ctx.Queue()
|
||||
self._bridge: ProcessCoreSinkServer | None = None
|
||||
self._core_sink = core_sink
|
||||
self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__)
|
||||
self._plugin_info = self._extract_plugin_info(adapter)
|
||||
self._adapter_path: tuple[str, str] = (adapter_cls.__module__, adapter_cls.__name__)
|
||||
self._plugin_info = self._extract_plugin_info(plugin)
|
||||
self._outgoing_handler = None
|
||||
|
||||
@staticmethod
|
||||
def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None:
|
||||
plugin = getattr(adapter, "plugin", None)
|
||||
def _extract_plugin_info(plugin) -> dict | None:
|
||||
if plugin is None:
|
||||
return None
|
||||
return {
|
||||
@@ -158,57 +158,50 @@ class AdapterManager:
|
||||
"""适配器管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._adapters: Dict[str, BaseAdapter] = {}
|
||||
# 注册信息:name -> (adapter class, plugin instance | None)
|
||||
self._adapter_defs: Dict[str, tuple[type[BaseAdapter], object | None]] = {}
|
||||
self._adapter_processes: Dict[str, AdapterProcess] = {}
|
||||
self._in_process_adapters: Dict[str, BaseAdapter] = {}
|
||||
|
||||
def register_adapter(self, adapter: BaseAdapter) -> None:
|
||||
def register_adapter(self, adapter_cls: type[BaseAdapter], plugin=None) -> None:
|
||||
"""
|
||||
注册适配器
|
||||
|
||||
|
||||
Args:
|
||||
adapter: 要注册的适配器实例
|
||||
adapter_cls: 适配器类
|
||||
plugin: 可选 Plugin 实例
|
||||
"""
|
||||
adapter_name = adapter.adapter_name
|
||||
adapter_name = getattr(adapter_cls, 'adapter_name', adapter_cls.__name__)
|
||||
|
||||
if adapter_name in self._adapters:
|
||||
logger.warning(f"适配器 {adapter_name} 已经注册,将被覆盖")
|
||||
if adapter_name in self._adapter_defs:
|
||||
logger.warning(f"适配器 {adapter_name} 已注册,已覆盖")
|
||||
|
||||
self._adapters[adapter_name] = adapter
|
||||
logger.info(f"已注册适配器: {adapter_name} v{adapter.adapter_version}")
|
||||
self._adapter_defs[adapter_name] = (adapter_cls, plugin)
|
||||
adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown')
|
||||
logger.info(f"注册适配器: {adapter_name} v{adapter_version}")
|
||||
|
||||
async def start_adapter(self, adapter_name: str) -> bool:
|
||||
"""
|
||||
启动指定的适配器
|
||||
|
||||
Args:
|
||||
adapter_name: 适配器名称
|
||||
|
||||
Returns:
|
||||
bool: 是否成功启动
|
||||
"""
|
||||
adapter = self._adapters.get(adapter_name)
|
||||
if not adapter:
|
||||
"""启动指定适配器"""
|
||||
definition = self._adapter_defs.get(adapter_name)
|
||||
if not definition:
|
||||
logger.error(f"适配器 {adapter_name} 未注册")
|
||||
return False
|
||||
adapter_cls, plugin = definition
|
||||
run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False)
|
||||
|
||||
# 检查是否需要在子进程中运行
|
||||
if adapter.run_in_subprocess:
|
||||
return await self._start_adapter_subprocess(adapter)
|
||||
else:
|
||||
return await self._start_adapter_in_process(adapter)
|
||||
if run_in_subprocess:
|
||||
return await self._start_adapter_subprocess(adapter_name, adapter_cls, plugin)
|
||||
return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin)
|
||||
|
||||
|
||||
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
|
||||
"""启动适配器子进程"""
|
||||
adapter_name = adapter.adapter_name
|
||||
async def _start_adapter_subprocess(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
|
||||
"""在子进程中启动适配器"""
|
||||
try:
|
||||
core_sink = get_core_sink()
|
||||
except Exception as e:
|
||||
logger.error(f"无法获取 core_sink,启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True)
|
||||
logger.error(f"无法获取 core_sink,启动子进程 {adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
adapter_process = AdapterProcess(adapter, core_sink)
|
||||
adapter_process = AdapterProcess(adapter_cls, plugin, core_sink)
|
||||
success = await adapter_process.start()
|
||||
|
||||
if success:
|
||||
@@ -216,17 +209,17 @@ class AdapterManager:
|
||||
|
||||
return success
|
||||
|
||||
async def _start_adapter_in_process(self, adapter: BaseAdapter) -> bool:
|
||||
"""在主进程中启动适配器"""
|
||||
adapter_name = adapter.adapter_name
|
||||
|
||||
async def _start_adapter_in_process(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
|
||||
"""在当前进程中启动适配器"""
|
||||
try:
|
||||
core_sink = get_core_sink()
|
||||
adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg]
|
||||
await adapter.start()
|
||||
self._in_process_adapters[adapter_name] = adapter
|
||||
logger.info(f"适配器 {adapter_name} 已在主进程中启动")
|
||||
logger.info(f"适配器 {adapter_name} 已在当前进程启动")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"在主进程中启动适配器 {adapter_name} 失败: {e}", exc_info=True)
|
||||
logger.error(f"启动适配器 {adapter_name} 失败: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def stop_adapter(self, adapter_name: str) -> None:
|
||||
@@ -251,10 +244,10 @@ class AdapterManager:
|
||||
logger.error(f"停止适配器 {adapter_name} 时出错: {e}", exc_info=True)
|
||||
|
||||
async def start_all_adapters(self) -> None:
|
||||
"""启动所有注册的适配器"""
|
||||
logger.info(f"开始启动 {len(self._adapters)} 个适配器...")
|
||||
"""启动所有已注册的适配器"""
|
||||
logger.info(f"开始启动 {len(self._adapter_defs)} 个适配器...")
|
||||
|
||||
for adapter_name in list(self._adapters.keys()):
|
||||
for adapter_name in list(self._adapter_defs.keys()):
|
||||
await self.start_adapter(adapter_name)
|
||||
|
||||
async def stop_all_adapters(self) -> None:
|
||||
@@ -285,32 +278,26 @@ class AdapterManager:
|
||||
return self._in_process_adapters.get(adapter_name)
|
||||
|
||||
def list_adapters(self) -> Dict[str, Dict[str, any]]:
|
||||
"""
|
||||
列出所有适配器的状态
|
||||
|
||||
Returns:
|
||||
Dict: 适配器状态信息
|
||||
"""
|
||||
"""列出适配器状态"""
|
||||
result = {}
|
||||
|
||||
for adapter_name, adapter in self._adapters.items():
|
||||
for adapter_name, definition in self._adapter_defs.items():
|
||||
adapter_cls, _plugin = definition
|
||||
status = {
|
||||
"name": adapter_name,
|
||||
"version": adapter.adapter_version,
|
||||
"platform": adapter.platform,
|
||||
"run_in_subprocess": adapter.run_in_subprocess,
|
||||
"version": getattr(adapter_cls, "adapter_version", "unknown"),
|
||||
"platform": getattr(adapter_cls, "platform", "unknown"),
|
||||
"run_in_subprocess": getattr(adapter_cls, "run_in_subprocess", False),
|
||||
"running": False,
|
||||
"location": "unknown",
|
||||
}
|
||||
|
||||
# 检查运行状态
|
||||
if adapter_name in self._adapter_processes:
|
||||
process = self._adapter_processes[adapter_name]
|
||||
status["running"] = process.is_running()
|
||||
status["location"] = "subprocess"
|
||||
if process.process:
|
||||
status["pid"] = process.process.pid
|
||||
|
||||
elif adapter_name in self._in_process_adapters:
|
||||
status["running"] = True
|
||||
status["location"] = "in-process"
|
||||
@@ -332,4 +319,4 @@ def get_adapter_manager() -> AdapterManager:
|
||||
return _adapter_manager
|
||||
|
||||
|
||||
__all__ = ["AdapterManager", "AdapterProcess", "get_adapter_manager"]
|
||||
__all__ = ["AdapterManager", "AdapterProcess", "get_adapter_manager"]
|
||||
@@ -224,18 +224,8 @@ class PluginManager:
|
||||
continue
|
||||
|
||||
# 创建适配器实例,传入 core_sink 和 plugin
|
||||
if self._core_sink is not None:
|
||||
adapter_instance = adapter_class(self._core_sink, plugin=plugin_instance) # type: ignore
|
||||
else:
|
||||
logger.warning(
|
||||
f"适配器 '{comp_info.name}' 未获得 core_sink,"
|
||||
"请在主程序中调用 plugin_manager.set_core_sink()"
|
||||
)
|
||||
# 尝试无参数创建(某些适配器可能不需要 core_sink)
|
||||
adapter_instance = adapter_class(plugin=plugin_instance) # type: ignore
|
||||
|
||||
# 注册到适配器管理器
|
||||
adapter_manager.register_adapter(adapter_instance) # type: ignore
|
||||
# 注册到适配器管理器,由管理器统一在运行时创建实例
|
||||
adapter_manager.register_adapter(adapter_class, plugin_instance) # type: ignore
|
||||
logger.info(
|
||||
f"插件 '{plugin_name}' 注册了适配器组件: {comp_info.name} "
|
||||
f"(平台: {comp_info.platform})"
|
||||
@@ -708,4 +698,4 @@ class PluginManager:
|
||||
|
||||
|
||||
# 全局插件管理器实例
|
||||
plugin_manager = PluginManager()
|
||||
plugin_manager = PluginManager()
|
||||
Reference in New Issue
Block a user