feat: 更新代码中的日志信息和注释为中文,增强可读性,修改适配器注册流程

This commit is contained in:
Windpicker-owo
2025-11-24 14:35:20 +08:00
parent 36fce6ca98
commit 81a209ed87
13 changed files with 252 additions and 132 deletions

View File

@@ -98,7 +98,7 @@ class AdapterBase:
try:
self.core_sink.set_outgoing_handler(self._on_outgoing_from_core)
except Exception:
logger.exception("Failed to register outgoing handler on core sink")
logger.exception("注册 outgoing 处理程序到核心接收器失败")
if isinstance(self._transport_config, WebSocketAdapterOptions):
await self._start_ws_transport(self._transport_config)
elif isinstance(self._transport_config, HttpAdapterOptions):
@@ -112,12 +112,12 @@ class AdapterBase:
try:
remove(self._on_outgoing_from_core)
except Exception:
logger.exception("Failed to detach outgoing handler on core sink")
logger.exception("从核心接收器分离 outgoing 处理程序失败")
elif hasattr(self.core_sink, "set_outgoing_handler"):
try:
self.core_sink.set_outgoing_handler(None) # type: ignore[arg-type]
except Exception:
logger.exception("Failed to detach outgoing handler on core sink")
logger.exception("从核心接收器分离 outgoing 处理程序失败")
if self._ws_task:
self._ws_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
@@ -135,12 +135,12 @@ class AdapterBase:
async def on_platform_message(self, raw: Any) -> None:
"""处理平台下发的单条消息并交给核心。"""
envelope = await _maybe_await(self.from_platform_message(raw))
envelope = await self.from_platform_message(raw)
await self.core_sink.send(envelope)
async def on_platform_messages(self, raw_messages: list[Any]) -> None:
"""批量推送入口,内部自动批量或逐条送入核心。"""
envelopes = [await _maybe_await(self.from_platform_message(raw)) for raw in raw_messages]
envelopes = [await self.from_platform_message(raw) for raw in raw_messages]
await _send_many(self.core_sink, envelopes)
async def send_to_platform(self, envelope: MessageEnvelope) -> None:
@@ -159,7 +159,7 @@ class AdapterBase:
return
await self._send_platform_message(envelope)
def from_platform_message(self, raw: Any) -> MessageEnvelope | Awaitable[MessageEnvelope]:
async def from_platform_message(self, raw: Any) -> MessageEnvelope:
"""子类必须实现:将平台原始结构转换为统一 MessageEnvelope。"""
raise NotImplementedError
@@ -291,7 +291,7 @@ class ProcessCoreSink(CoreSink):
await self.send(message)
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
logger.debug("ProcessCoreSink.push_outgoing called in child; ignored")
logger.debug("ProcessCoreSink.push_outgoing 在子进程中调用; 被忽略")
async def close(self) -> None:
if self._closed:
@@ -318,9 +318,9 @@ class ProcessCoreSink(CoreSink):
try:
await self._outgoing_handler(envelope)
except Exception: # pragma: no cover
logger.exception("Failed to handle outgoing envelope in ProcessCoreSink")
logger.exception("处理 ProcessCoreSink 中的 outgoing 信封失败")
else:
logger.debug("ProcessCoreSink received unknown payload: %r", item)
logger.debug(f"ProcessCoreSink 接受到未知负载: {item}")
class ProcessCoreSinkServer:
@@ -362,9 +362,9 @@ class ProcessCoreSinkServer:
try:
await self._core_handler(envelope)
except Exception: # pragma: no cover
logger.exception("Failed to dispatch incoming envelope from %s", self._name)
logger.exception(f"处理来自 {self._name} 的 incoming 信封时失败")
else:
logger.debug("ProcessCoreSinkServer ignored unknown payload from %s: %r", self._name, item)
logger.debug(f"ProcessCoreSinkServer 忽略来自 {self._name} 的未知负载: {item}")
async def push_outgoing(self, envelope: MessageEnvelope) -> None:
await asyncio.to_thread(self._outgoing_queue.put, {"kind": "outgoing", "payload": envelope})
@@ -390,12 +390,6 @@ async def _send_many(sink: CoreMessageSink, envelopes: list[MessageEnvelope]) ->
await sink.send(env)
async def _maybe_await(result: Any) -> Any:
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
class BatchDispatcher:
"""
批量消息分发器,负责将消息批量发送到核心 sink。

View File

@@ -18,6 +18,16 @@ DisconnectCallback = Callable[[str, str], Awaitable[None] | None]
def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
"""
将原始字节数据附加到消息负载中
Args:
payload: 消息负载
raw_bytes: 原始字节数据
Returns:
附加了原始数据的消息负载
"""
if isinstance(payload, dict):
payload.setdefault("raw_bytes", raw_bytes)
elif isinstance(payload, list):
@@ -28,6 +38,16 @@ def _attach_raw_bytes(payload: Any, raw_bytes: bytes) -> Any:
def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[str | bytes, bool]:
"""
编码消息用于 WebSocket 发送
Args:
message: 要发送的消息
use_raw_bytes: 是否使用原始字节数据
Returns:
(编码后的数据, 是否为二进制格式)
"""
if isinstance(message, (bytes, bytearray)):
return bytes(message), True
if use_raw_bytes and isinstance(message, dict):
@@ -44,15 +64,29 @@ def _encode_for_ws_send(message: Any, *, use_raw_bytes: bool = False) -> tuple[s
class BaseMessageHandler:
"""基础消息处理器,提供消息处理和任务管理功能"""
def __init__(self) -> None:
self.message_handlers: list[MessageHandler] = []
self.background_tasks: set[asyncio.Task] = set()
def register_message_handler(self, handler: MessageHandler) -> None:
"""
注册消息处理器
Args:
handler: 消息处理函数
"""
if handler not in self.message_handlers:
self.message_handlers.append(handler)
async def process_message(self, message: MessagePayload) -> None:
"""
处理单条消息,并发执行所有注册的处理器
Args:
message: 消息负载
"""
tasks: list[asyncio.Task] = []
for handler in self.message_handlers:
try:
@@ -63,7 +97,7 @@ class BaseMessageHandler:
self.background_tasks.add(task)
task.add_done_callback(self.background_tasks.discard)
except Exception: # pragma: no cover - logging only
logging.getLogger("mofox_bus.server").exception("Failed to handle message")
logging.getLogger("mofox_bus.server").exception("消息处理失败")
if tasks:
await asyncio.gather(*tasks, return_exceptions=True)
@@ -236,8 +270,8 @@ class MessageServer(BaseMessageHandler):
) -> None:
ws = self._platform_connections.get(platform)
if ws is None:
raise RuntimeError(f"No active connection for platform {platform}")
payload: MessagePayload | bytes = message.to_dict() if isinstance(message, MessageBase) else message
raise RuntimeError(f"平台 {platform} 没有活跃的连接")
payload: MessagePayload | bytes = message
data, is_binary = _encode_for_ws_send(payload, use_raw_bytes=use_raw_bytes)
if is_binary:
await ws.send_bytes(data if isinstance(data, (bytes, bytearray)) else str(data).encode("utf-8"))
@@ -439,7 +473,7 @@ class MessageClient(BaseMessageHandler):
logging.getLogger("mofox_bus.client").exception("Disconnect callback failed")
async def _reconnect(self) -> None:
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
self._logger.info(f"WebSocket 连接断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
await asyncio.sleep(self._reconnect_interval)
await self._connect_once()

View File

@@ -9,9 +9,9 @@ from .types import GroupInfoPayload, MessageEnvelope, MessageInfoPayload, SegPay
class MessageBuilder:
"""
Fluent helper to build MessageEnvelope safely with type hints.
流式构建 MessageEnvelope 的助手工具,提供类型安全的构建方法。
Example:
使用示例:
msg = (
MessageBuilder()
.text("Hello")
@@ -85,9 +85,10 @@ class MessageBuilder:
return self
def build(self) -> MessageEnvelope:
# message_info defaults
"""构建最终的消息信封"""
# 设置 message_info 默认值
if not self._segments:
raise ValueError("message_segment is required, add at least one segment before build()")
raise ValueError("需要至少添加一个消息段才能构建消息")
if self._message_id is None:
self._message_id = str(uuid.uuid4())
info = dict(self._message_info)

View File

@@ -63,7 +63,7 @@ def loads_messages(data: bytes | str) -> List[MessageEnvelope]:
obj = _loads(data)
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
if version != DEFAULT_SCHEMA_VERSION:
raise ValueError(f"Unsupported schema_version={version}")
raise ValueError(f"不支持的 schema_version={version}")
return [_upgrade_schema_if_needed(item) for item in obj.get("items", [])]
@@ -74,7 +74,7 @@ def _upgrade_schema_if_needed(obj: Dict[str, Any]) -> MessageEnvelope:
version = obj.get("schema_version", DEFAULT_SCHEMA_VERSION)
if version == DEFAULT_SCHEMA_VERSION:
return obj # type: ignore[return-value]
raise ValueError(f"Unsupported schema_version={version}")
raise ValueError(f"不支持的 schema_version={version}")

View File

@@ -14,15 +14,18 @@ logger = logging.getLogger("mofox_bus.router")
@dataclass
class TargetConfig:
"""路由目标配置,包含连接信息和认证配置"""
url: str
token: str | None = None
ssl_verify: str | None = None
def to_dict(self) -> Dict[str, str | None]:
"""转换为字典格式"""
return asdict(self)
@classmethod
def from_dict(cls, data: Dict[str, str | None]) -> "TargetConfig":
"""从字典创建配置对象"""
return cls(
url=data.get("url", ""),
token=data.get("token"),
@@ -32,13 +35,16 @@ class TargetConfig:
@dataclass
class RouteConfig:
"""路由配置,包含多个平台的路由目标"""
route_config: Dict[str, TargetConfig]
def to_dict(self) -> Dict[str, Dict[str, str | None]]:
"""转换为字典格式"""
return {"route_config": {k: v.to_dict() for k, v in self.route_config.items()}}
@classmethod
def from_dict(cls, data: Dict[str, Dict[str, str | None]]) -> "RouteConfig":
"""从字典创建路由配置对象"""
cfg = {
platform: TargetConfig.from_dict(target)
for platform, target in data.get("route_config", {}).items()
@@ -47,7 +53,16 @@ class RouteConfig:
class Router:
"""消息路由器,负责管理多个平台的消息客户端连接"""
def __init__(self, config: RouteConfig, custom_logger: logging.Logger | None = None) -> None:
"""
初始化路由器
Args:
config: 路由配置
custom_logger: 自定义日志记录器
"""
if custom_logger:
logger.handlers = custom_logger.handlers
self.config = config
@@ -58,12 +73,22 @@ class Router:
self._stop_event: asyncio.Event | None = None
async def connect(self, platform: str) -> None:
"""
连接到指定平台
Args:
platform: 平台标识
Raises:
ValueError: 未知平台
NotImplementedError: 不支持的模式
"""
if platform not in self.config.route_config:
raise ValueError(f"Unknown platform {platform}")
raise ValueError(f"未知平台: {platform}")
target = self.config.route_config[platform]
mode = "tcp" if target.url.startswith(("tcp://", "tcps://")) else "ws"
if mode != "ws":
raise NotImplementedError("TCP mode is not implemented yet")
raise NotImplementedError("TCP 模式暂未实现")
client = MessageClient(mode="ws")
client.set_disconnect_callback(self._handle_client_disconnect)
await client.connect(
@@ -84,6 +109,7 @@ class Router:
client.register_message_handler(handler)
async def run(self) -> None:
"""启动路由器,连接所有配置的平台并开始运行"""
self._running = True
self._stop_event = asyncio.Event()
for platform in self.config.route_config:
@@ -98,6 +124,12 @@ class Router:
raise
async def remove_platform(self, platform: str) -> None:
"""
移除指定平台的连接
Args:
platform: 平台标识
"""
if platform in self._client_tasks:
task = self._client_tasks.pop(platform)
task.cancel()
@@ -108,7 +140,14 @@ class Router:
await client.stop()
async def _handle_client_disconnect(self, platform: str, reason: str) -> None:
logger.info("Client for %s disconnected: %s (auto-reconnect handled by client)", platform, reason)
"""
处理客户端断开连接
Args:
platform: 平台标识
reason: 断开原因
"""
logger.info(f"平台 {platform} 的客户端断开连接: {reason} (客户端将自动重连)")
task = self._client_tasks.get(platform)
if task is not None and not task.done():
return
@@ -117,6 +156,7 @@ class Router:
self._start_client_task(platform, client)
async def stop(self) -> None:
"""停止路由器,关闭所有连接"""
self._running = False
if self._stop_event:
self._stop_event.set()
@@ -125,23 +165,46 @@ class Router:
self.clients.clear()
def _start_client_task(self, platform: str, client: MessageClient) -> None:
"""
启动客户端任务
Args:
platform: 平台标识
client: 消息客户端
"""
task = asyncio.create_task(client.run())
task.add_done_callback(lambda t, plat=platform: asyncio.create_task(self._restart_if_needed(plat, t)))
self._client_tasks[platform] = task
async def _restart_if_needed(self, platform: str, task: asyncio.Task) -> None:
"""
必要时重启客户端任务
Args:
platform: 平台标识
task: 已完成的任务
"""
if not self._running:
return
if task.cancelled():
return
exc = task.exception()
if exc:
logger.warning("Client task for %s ended with exception: %s", platform, exc)
logger.warning(f"平台 {platform} 的客户端任务异常结束: {exc}")
client = self.clients.get(platform)
if client:
self._start_client_task(platform, client)
def get_target_url(self, message: MessageEnvelope) -> Optional[str]:
"""
根据消息获取目标 URL
Args:
message: 消息信封
Returns:
目标 URL 或 None
"""
platform = message.get("message_info", {}).get("platform")
if not platform:
return None
@@ -149,24 +212,48 @@ class Router:
return target.url if target else None
async def send_message(self, message: MessageEnvelope):
"""
发送消息到指定平台
Args:
message: 消息信封
Raises:
ValueError: 缺少平台信息
RuntimeError: 未找到对应平台的客户端
"""
platform = message.get("message_info", {}).get("platform")
if not platform:
raise ValueError("message_info.platform is required")
raise ValueError("消息中缺少必需的 message_info.platform 字段")
client = self.clients.get(platform)
if client is None:
raise RuntimeError(f"No client connected for platform {platform}")
raise RuntimeError(f"平台 {platform} 没有已连接的客户端")
return await client.send_message(message)
async def update_config(self, config_data: Dict[str, Dict[str, str | None]]) -> None:
"""
更新路由配置
Args:
config_data: 新的配置数据
"""
new_config = RouteConfig.from_dict(config_data)
await self._adjust_connections(new_config)
self.config = new_config
async def _adjust_connections(self, new_config: RouteConfig) -> None:
"""
调整连接以匹配新配置
Args:
new_config: 新的路由配置
"""
current = set(self.config.route_config.keys())
updated = set(new_config.route_config.keys())
# 移除不再存在的平台
for platform in current - updated:
await self.remove_platform(platform)
# 添加或更新平台
for platform in updated:
if platform not in current:
await self.connect(platform)

View File

@@ -25,13 +25,14 @@ class MessageProcessingError(RuntimeError):
def __init__(self, message: MessageEnvelope, original: BaseException):
detail = message.get("id", "<unknown>")
super().__init__(f"Failed to handle message {detail}: {original}")
super().__init__(f"处理消息 {detail} 时出错: {original}")
self.message_envelope = message
self.original = original
@dataclass
class MessageRoute:
"""消息路由配置,包含匹配条件和处理函数"""
predicate: Predicate
handler: MessageHandler
name: str | None = None
@@ -41,7 +42,7 @@ class MessageRoute:
class MessageRuntime:
"""
负责调度消息路由、执行前后 hook 以及批量处理
消息运行时环境,负责调度消息路由、执行前后处理钩子以及批量处理消息
"""
def __init__(self) -> None:
@@ -64,6 +65,16 @@ class MessageRuntime:
message_type: str | None = None,
event_types: Iterable[str] | None = None,
) -> None:
"""
添加消息路由
Args:
predicate: 路由匹配条件
handler: 消息处理函数
name: 路由名称(可选)
message_type: 消息类型(可选)
event_types: 事件类型列表(可选)
"""
with self._lock:
route = MessageRoute(
predicate=predicate,
@@ -245,14 +256,8 @@ class MessageRuntime:
return wrapped
async def _maybe_await(result):
if asyncio.iscoroutine(result) or isinstance(result, asyncio.Future):
return await result
return result
async def _invoke_callable(func: Callable[..., object], *args, prefer_thread: bool = False):
"""Support sync/async callables with optional thread offloading."""
"""支持 sync/async 调用,并可选择在线程中执行。"""
if inspect.iscoroutinefunction(func):
return await func(*args)
if prefer_thread:

View File

@@ -11,7 +11,7 @@ from ..types import MessageEnvelope
class HttpMessageClient:
"""
面向消息批量传输的 HTTP 客户端封装
面向消息批量传输的 HTTP 客户端封装
"""
def __init__(
@@ -39,14 +39,14 @@ class HttpMessageClient:
session = await self._ensure_session()
url = f"{self._base_url}{path}"
payload = dumps_messages(messages)
self._logger.debug("Sending %d message(s) -> %s", len(messages), url)
self._logger.debug(f"正在发送 {len(messages)} 条消息 -> {url}")
async with session.post(url, data=payload, timeout=self._timeout) as resp:
resp.raise_for_status()
if not expect_reply:
return None
raw = await resp.read()
replies = loads_messages(raw)
self._logger.debug("Received %d reply message(s)", len(replies))
self._logger.debug(f"接收到 {len(replies)} 条回复消息")
return replies
async def close(self) -> None:

View File

@@ -13,7 +13,7 @@ MessageHandler = Callable[[List[MessageEnvelope]], Awaitable[List[MessageEnvelop
class HttpMessageServer:
"""
轻量级 HTTP 消息入口可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下
轻量级 HTTP 消息入口可独立运行,也可挂载到现有 FastAPI / aiohttp 应用下
"""
def __init__(self, handler: MessageHandler, *, path: str = "/messages") -> None:
@@ -27,10 +27,10 @@ class HttpMessageServer:
try:
raw = await request.read()
envelopes = loads_messages(raw)
self._logger.debug("Received %d message(s)", len(envelopes))
self._logger.debug(f"接收到 {len(envelopes)} 条消息")
except Exception as exc: # pragma: no cover - network errors are integration tested
self._logger.exception("Failed to parse incoming messages: %s", exc)
raise web.HTTPBadRequest(reason=f"Invalid payload: {exc}") from exc
self._logger.exception(f"解析请求失败: {exc}")
raise web.HTTPBadRequest(reason=f"无效的负载: {exc}") from exc
result = await self._handler(envelopes)
if result is None:

View File

@@ -14,7 +14,7 @@ IncomingHandler = Callable[[MessageEnvelope], Awaitable[None]]
class WsMessageClient:
"""
管理 WebSocket 连接,提供 send/receive API并在后台读取消息
管理 WebSocket 连接,提供 send/receive API并在后台读取消息
"""
def __init__(
@@ -42,7 +42,7 @@ class WsMessageClient:
async def _connect_once(self) -> None:
assert self._session is not None
self._ws = await self._session.ws_connect(self._url)
self._logger.info("Connected to %s", self._url)
self._logger.info(f"已连接到 {self._url}")
self._receive_task = asyncio.create_task(self._receive_loop())
async def send_messages(self, messages: Sequence[MessageEnvelope]) -> None:
@@ -76,7 +76,7 @@ class WsMessageClient:
if self._handler is not None:
await self._handler(env)
elif msg.type == aiohttp.WSMsgType.ERROR:
self._logger.warning("WebSocket error: %s", msg.data)
self._logger.warning(f"WebSocket 错误: {msg.data}")
break
except asyncio.CancelledError: # pragma: no cover - cancellation path
return
@@ -85,7 +85,7 @@ class WsMessageClient:
await self._reconnect()
async def _reconnect(self) -> None:
self._logger.info("WebSocket disconnected, retrying in %.1fs", self._reconnect_interval)
self._logger.info(f"WebSocket 断开, 正在 {self._reconnect_interval:.1f} 秒后重试")
await asyncio.sleep(self._reconnect_interval)
await self._connect_once()

View File

@@ -15,7 +15,7 @@ WsMessageHandler = Callable[[MessageEnvelope], Awaitable[None]]
class WsMessageServer:
"""
封装 WebSocket 服务端逻辑,负责接收消息并广播响应
封装 WebSocket 服务端逻辑,负责接收消息并广播响应
"""
def __init__(self, handler: WsMessageHandler, *, path: str = "/ws") -> None:
@@ -30,7 +30,7 @@ class WsMessageServer:
async def _handle_ws(self, request: web.Request) -> web.WebSocketResponse:
ws = web.WebSocketResponse()
await ws.prepare(request)
self._logger.info("WebSocket connection opened: %s", request.remote)
self._logger.info(f"WebSocket 连接打开: {request.remote}")
async with self._track_connection(ws):
async for message in ws:
@@ -39,10 +39,10 @@ class WsMessageServer:
for env in envelopes:
await self._handler(env)
elif message.type == WSMsgType.ERROR:
self._logger.warning("WebSocket connection error: %s", ws.exception())
self._logger.warning(f"WebSocket 连接错误: {ws.exception()}")
break
self._logger.info("WebSocket connection closed: %s", request.remote)
self._logger.info(f"WebSocket 连接关闭: {request.remote}")
return ws
@asynccontextmanager

View File

@@ -12,7 +12,7 @@ from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope
from mofox_bus import AdapterBase as MoFoxAdapterBase, CoreSink, MessageEnvelope, ProcessCoreSink
if TYPE_CHECKING:
from src.plugin_system.base.base_plugin import BasePlugin
@@ -62,6 +62,28 @@ class BaseAdapter(MoFoxAdapterBase, ABC):
self._config: Dict[str, Any] = {}
self._health_check_task: Optional[asyncio.Task] = None
self._running = False
# 标记是否在子进程中运行(由核心管理器传入 ProcessCoreSink 时自动生效)
self._is_subprocess = isinstance(core_sink, ProcessCoreSink)
@classmethod
def from_process_queues(
cls,
to_core_queue,
from_core_queue,
plugin: Optional["BasePlugin"] = None,
**kwargs: Any,
) -> "BaseAdapter":
"""
子进程入口便捷构造:使用 multiprocessing.Queue 与核心建立 ProcessCoreSink 通讯。
Args:
to_core_queue: 发往核心的 multiprocessing.Queue
from_core_queue: 核心回传的 multiprocessing.Queue
plugin: 可选插件实例
**kwargs: 透传给适配器构造函数
"""
sink = ProcessCoreSink(to_core_queue=to_core_queue, from_core_queue=from_core_queue)
return cls(core_sink=sink, plugin=plugin, **kwargs)
@property
def config(self) -> Dict[str, Any]:

View File

@@ -61,24 +61,24 @@ def _adapter_process_entry(
class AdapterProcess:
"""适配器子进程包装器,负责适配器子进程的启动和生命周期管理"""
"""适配器子进程封装:管理子进程的生命周期与通信桥接"""
def __init__(self, adapter: "BaseAdapter", core_sink) -> None:
self.adapter = adapter
self.adapter_name = adapter.adapter_name
def __init__(self, adapter_cls: "type[BaseAdapter]", plugin, core_sink) -> None:
self.adapter_cls = adapter_cls
self.adapter_name = adapter_cls.adapter_name
self.plugin = plugin
self.process: mp.Process | None = None
self._ctx = mp.get_context("spawn")
self._incoming_queue: mp.Queue = self._ctx.Queue()
self._outgoing_queue: mp.Queue = self._ctx.Queue()
self._bridge: ProcessCoreSinkServer | None = None
self._core_sink = core_sink
self._adapter_path: tuple[str, str] = (adapter.__class__.__module__, adapter.__class__.__name__)
self._plugin_info = self._extract_plugin_info(adapter)
self._adapter_path: tuple[str, str] = (adapter_cls.__module__, adapter_cls.__name__)
self._plugin_info = self._extract_plugin_info(plugin)
self._outgoing_handler = None
@staticmethod
def _extract_plugin_info(adapter: "BaseAdapter") -> dict | None:
plugin = getattr(adapter, "plugin", None)
def _extract_plugin_info(plugin) -> dict | None:
if plugin is None:
return None
return {
@@ -158,57 +158,50 @@ class AdapterManager:
"""适配器管理器"""
def __init__(self):
self._adapters: Dict[str, BaseAdapter] = {}
# 注册信息name -> (adapter class, plugin instance | None)
self._adapter_defs: Dict[str, tuple[type[BaseAdapter], object | None]] = {}
self._adapter_processes: Dict[str, AdapterProcess] = {}
self._in_process_adapters: Dict[str, BaseAdapter] = {}
def register_adapter(self, adapter: BaseAdapter) -> None:
def register_adapter(self, adapter_cls: type[BaseAdapter], plugin=None) -> None:
"""
注册适配器
Args:
adapter: 要注册的适配器实例
adapter_cls: 适配器
plugin: 可选 Plugin 实例
"""
adapter_name = adapter.adapter_name
adapter_name = getattr(adapter_cls, 'adapter_name', adapter_cls.__name__)
if adapter_name in self._adapters:
logger.warning(f"适配器 {adapter_name}注册,将被覆盖")
if adapter_name in self._adapter_defs:
logger.warning(f"适配器 {adapter_name} 已注册,覆盖")
self._adapters[adapter_name] = adapter
logger.info(f"已注册适配器: {adapter_name} v{adapter.adapter_version}")
self._adapter_defs[adapter_name] = (adapter_cls, plugin)
adapter_version = getattr(adapter_cls, 'adapter_version', 'unknown')
logger.info(f"注册适配器: {adapter_name} v{adapter_version}")
async def start_adapter(self, adapter_name: str) -> bool:
"""
启动指定的适配器
Args:
adapter_name: 适配器名称
Returns:
bool: 是否成功启动
"""
adapter = self._adapters.get(adapter_name)
if not adapter:
"""启动指定适配器"""
definition = self._adapter_defs.get(adapter_name)
if not definition:
logger.error(f"适配器 {adapter_name} 未注册")
return False
adapter_cls, plugin = definition
run_in_subprocess = getattr(adapter_cls, "run_in_subprocess", False)
# 检查是否需要在子进程中运行
if adapter.run_in_subprocess:
return await self._start_adapter_subprocess(adapter)
else:
return await self._start_adapter_in_process(adapter)
if run_in_subprocess:
return await self._start_adapter_subprocess(adapter_name, adapter_cls, plugin)
return await self._start_adapter_in_process(adapter_name, adapter_cls, plugin)
async def _start_adapter_subprocess(self, adapter: BaseAdapter) -> bool:
"""启动适配器子进程"""
adapter_name = adapter.adapter_name
async def _start_adapter_subprocess(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
"""在子进程中启动适配器"""
try:
core_sink = get_core_sink()
except Exception as e:
logger.error(f"无法获取 core_sink启动适配器子进程 {adapter_name} 失败: {e}", exc_info=True)
logger.error(f"无法获取 core_sink启动子进程 {adapter_name} 失败: {e}", exc_info=True)
return False
adapter_process = AdapterProcess(adapter, core_sink)
adapter_process = AdapterProcess(adapter_cls, plugin, core_sink)
success = await adapter_process.start()
if success:
@@ -216,17 +209,17 @@ class AdapterManager:
return success
async def _start_adapter_in_process(self, adapter: BaseAdapter) -> bool:
"""进程中启动适配器"""
adapter_name = adapter.adapter_name
async def _start_adapter_in_process(self, adapter_name: str, adapter_cls: type[BaseAdapter], plugin) -> bool:
"""当前进程中启动适配器"""
try:
core_sink = get_core_sink()
adapter = adapter_cls(core_sink, plugin=plugin) # type: ignore[call-arg]
await adapter.start()
self._in_process_adapters[adapter_name] = adapter
logger.info(f"适配器 {adapter_name} 已在进程启动")
logger.info(f"适配器 {adapter_name} 已在当前进程启动")
return True
except Exception as e:
logger.error(f"在主进程中启动适配器 {adapter_name} 失败: {e}", exc_info=True)
logger.error(f"启动适配器 {adapter_name} 失败: {e}", exc_info=True)
return False
async def stop_adapter(self, adapter_name: str) -> None:
@@ -251,10 +244,10 @@ class AdapterManager:
logger.error(f"停止适配器 {adapter_name} 时出错: {e}", exc_info=True)
async def start_all_adapters(self) -> None:
"""启动所有注册的适配器"""
logger.info(f"开始启动 {len(self._adapters)} 个适配器...")
"""启动所有注册的适配器"""
logger.info(f"开始启动 {len(self._adapter_defs)} 个适配器...")
for adapter_name in list(self._adapters.keys()):
for adapter_name in list(self._adapter_defs.keys()):
await self.start_adapter(adapter_name)
async def stop_all_adapters(self) -> None:
@@ -285,32 +278,26 @@ class AdapterManager:
return self._in_process_adapters.get(adapter_name)
def list_adapters(self) -> Dict[str, Dict[str, any]]:
"""
列出所有适配器的状态
Returns:
Dict: 适配器状态信息
"""
"""列出适配器状态"""
result = {}
for adapter_name, adapter in self._adapters.items():
for adapter_name, definition in self._adapter_defs.items():
adapter_cls, _plugin = definition
status = {
"name": adapter_name,
"version": adapter.adapter_version,
"platform": adapter.platform,
"run_in_subprocess": adapter.run_in_subprocess,
"version": getattr(adapter_cls, "adapter_version", "unknown"),
"platform": getattr(adapter_cls, "platform", "unknown"),
"run_in_subprocess": getattr(adapter_cls, "run_in_subprocess", False),
"running": False,
"location": "unknown",
}
# 检查运行状态
if adapter_name in self._adapter_processes:
process = self._adapter_processes[adapter_name]
status["running"] = process.is_running()
status["location"] = "subprocess"
if process.process:
status["pid"] = process.process.pid
elif adapter_name in self._in_process_adapters:
status["running"] = True
status["location"] = "in-process"
@@ -332,4 +319,4 @@ def get_adapter_manager() -> AdapterManager:
return _adapter_manager
__all__ = ["AdapterManager", "AdapterProcess", "get_adapter_manager"]
__all__ = ["AdapterManager", "AdapterProcess", "get_adapter_manager"]

View File

@@ -224,18 +224,8 @@ class PluginManager:
continue
# 创建适配器实例,传入 core_sink 和 plugin
if self._core_sink is not None:
adapter_instance = adapter_class(self._core_sink, plugin=plugin_instance) # type: ignore
else:
logger.warning(
f"适配器 '{comp_info.name}' 未获得 core_sink"
"请在主程序中调用 plugin_manager.set_core_sink()"
)
# 尝试无参数创建(某些适配器可能不需要 core_sink
adapter_instance = adapter_class(plugin=plugin_instance) # type: ignore
# 注册到适配器管理器
adapter_manager.register_adapter(adapter_instance) # type: ignore
# 注册到适配器管理器,由管理器统一在运行时创建实例
adapter_manager.register_adapter(adapter_class, plugin_instance) # type: ignore
logger.info(
f"插件 '{plugin_name}' 注册了适配器组件: {comp_info.name} "
f"(平台: {comp_info.platform})"
@@ -708,4 +698,4 @@ class PluginManager:
# 全局插件管理器实例
plugin_manager = PluginManager()
plugin_manager = PluginManager()