diff --git a/src/common/server.py b/src/common/server.py new file mode 100644 index 000000000..fd1f3ff18 --- /dev/null +++ b/src/common/server.py @@ -0,0 +1,73 @@ +from fastapi import FastAPI, APIRouter +from typing import Optional, Union +from uvicorn import Config, Server as UvicornServer +import os + + +class Server: + def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"): + self.app = FastAPI(title=app_name) + self._host: str = "127.0.0.1" + self._port: int = 8080 + self._server: Optional[UvicornServer] = None + self.set_address(host, port) + + def register_router(self, router: APIRouter, prefix: str = ""): + """注册路由 + + APIRouter 用于对相关的路由端点进行分组和模块化管理: + 1. 可以将相关的端点组织在一起,便于管理 + 2. 支持添加统一的路由前缀 + 3. 可以为一组路由添加共同的依赖项、标签等 + + 示例: + router = APIRouter() + + @router.get("/users") + def get_users(): + return {"users": [...]} + + @router.post("/users") + def create_user(): + return {"msg": "user created"} + + # 注册路由,添加前缀 "/api/v1" + server.register_router(router, prefix="/api/v1") + """ + self.app.include_router(router, prefix=prefix) + + def set_address(self, host: Optional[str] = None, port: Optional[int] = None): + """设置服务器地址和端口""" + if host: + self._host = host + if port: + self._port = port + + async def run(self): + """启动服务器""" + config = Config(app=self.app, host=self._host, port=self._port) + self._server = UvicornServer(config=config) + try: + await self._server.serve() + except KeyboardInterrupt: + await self.shutdown() + raise + except Exception as e: + await self.shutdown() + raise RuntimeError(f"服务器运行错误: {str(e)}") from e + finally: + await self.shutdown() + + async def shutdown(self): + """安全关闭服务器""" + if self._server: + self._server.should_exit = True + await self._server.shutdown() + self._server = None + + def get_app(self) -> FastAPI: + """获取 FastAPI 实例""" + return self.app + + +global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"])) diff --git a/src/main.py b/src/main.py index aa6f908bf..d94cfce64 100644 --- a/src/main.py +++ b/src/main.py @@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot from .common.logger import get_module_logger from .plugins.remote import heartbeat_thread # noqa: F401 from .individuality.individuality import Individuality - +from .common.server import global_server logger = get_module_logger("main") @@ -33,6 +33,7 @@ class MainSystem: from .plugins.message import global_api self.app = global_api + self.server = global_server async def initialize(self): """初始化系统组件""" @@ -126,6 +127,7 @@ class MainSystem: emoji_manager.start_periodic_check_register(), # emoji_manager.start_periodic_register(), self.app.run(), + self.server.run(), ] await asyncio.gather(*tasks) diff --git a/src/plugins/message/__init__.py b/src/plugins/message/__init__.py index bee5c5e58..286ef2310 100644 --- a/src/plugins/message/__init__.py +++ b/src/plugins/message/__init__.py @@ -2,7 +2,7 @@ __version__ = "0.1.0" -from .api import BaseMessageAPI, global_api +from .api import global_api from .message_base import ( Seg, GroupInfo, @@ -14,7 +14,6 @@ from .message_base import ( ) __all__ = [ - "BaseMessageAPI", "Seg", "global_api", "GroupInfo", diff --git a/src/plugins/message/api.py b/src/plugins/message/api.py index 19457bbec..0c3e3a5a1 100644 --- a/src/plugins/message/api.py +++ b/src/plugins/message/api.py @@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect from typing import Dict, Any, Callable, List, Set, Optional from src.common.logger import get_module_logger from src.plugins.message.message_base import MessageBase +from src.common.server import global_server import aiohttp import asyncio import uvicorn @@ -242,105 +243,4 @@ class MessageServer(BaseMessageHandler): raise e -class BaseMessageAPI: - def __init__(self, host: str = "0.0.0.0", port: int = 18000): - self.app = FastAPI() - self.host = host - self.port = port - self.message_handlers: List[Callable] = [] - self.cache = [] - self._setup_routes() - self._running = False - - def _setup_routes(self): - """设置基础路由""" - - @self.app.post("/api/message") - async def handle_message(message: Dict[str, Any]): - try: - # 创建后台任务处理消息 - asyncio.create_task(self._background_message_handler(message)) - return {"status": "success"} - except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) from e - - async def _background_message_handler(self, message: Dict[str, Any]): - """后台处理单个消息""" - try: - await self.process_single_message(message) - except Exception as e: - logger.error(f"Background message processing failed: {str(e)}") - logger.error(traceback.format_exc()) - - def register_message_handler(self, handler: Callable): - """注册消息处理函数""" - self.message_handlers.append(handler) - - async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]: - """发送消息到指定端点""" - async with aiohttp.ClientSession() as session: - try: - async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response: - return await response.json() - except Exception: - # logger.error(f"发送消息失败: {str(e)}") - pass - - async def process_single_message(self, message: Dict[str, Any]): - """处理单条消息""" - tasks = [] - for handler in self.message_handlers: - try: - tasks.append(handler(message)) - except Exception as e: - logger.error(str(e)) - logger.error(traceback.format_exc()) - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) - - def run_sync(self): - """同步方式运行服务器""" - uvicorn.run(self.app, host=self.host, port=self.port) - - async def run(self): - """异步方式运行服务器""" - config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio") - self.server = uvicorn.Server(config) - try: - await self.server.serve() - except KeyboardInterrupt as e: - await self.stop() - raise KeyboardInterrupt from e - - async def start_server(self): - """启动服务器的异步方法""" - if not self._running: - self._running = True - await self.run() - - async def stop(self): - """停止服务器""" - if hasattr(self, "server"): - self._running = False - # 正确关闭 uvicorn 服务器 - self.server.should_exit = True - await self.server.shutdown() - # 等待服务器完全停止 - if hasattr(self.server, "started") and self.server.started: - await self.server.main_loop() - # 清理处理程序 - self.message_handlers.clear() - - def start(self): - """启动服务器的便捷方法""" - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - try: - loop.run_until_complete(self.start_server()) - except KeyboardInterrupt: - pass - finally: - loop.close() - - -global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"])) +global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())