refactor: 更换fastapi初始化位置
This commit is contained in:
73
src/common/server.py
Normal file
73
src/common/server.py
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
from fastapi import FastAPI, APIRouter
|
||||||
|
from typing import Optional, Union
|
||||||
|
from uvicorn import Config, Server as UvicornServer
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
class Server:
|
||||||
|
def __init__(self, host: Optional[str] = None, port: Optional[int] = None, app_name: str = "MaiMCore"):
|
||||||
|
self.app = FastAPI(title=app_name)
|
||||||
|
self._host: str = "127.0.0.1"
|
||||||
|
self._port: int = 8080
|
||||||
|
self._server: Optional[UvicornServer] = None
|
||||||
|
self.set_address(host, port)
|
||||||
|
|
||||||
|
def register_router(self, router: APIRouter, prefix: str = ""):
|
||||||
|
"""注册路由
|
||||||
|
|
||||||
|
APIRouter 用于对相关的路由端点进行分组和模块化管理:
|
||||||
|
1. 可以将相关的端点组织在一起,便于管理
|
||||||
|
2. 支持添加统一的路由前缀
|
||||||
|
3. 可以为一组路由添加共同的依赖项、标签等
|
||||||
|
|
||||||
|
示例:
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
@router.get("/users")
|
||||||
|
def get_users():
|
||||||
|
return {"users": [...]}
|
||||||
|
|
||||||
|
@router.post("/users")
|
||||||
|
def create_user():
|
||||||
|
return {"msg": "user created"}
|
||||||
|
|
||||||
|
# 注册路由,添加前缀 "/api/v1"
|
||||||
|
server.register_router(router, prefix="/api/v1")
|
||||||
|
"""
|
||||||
|
self.app.include_router(router, prefix=prefix)
|
||||||
|
|
||||||
|
def set_address(self, host: Optional[str] = None, port: Optional[int] = None):
|
||||||
|
"""设置服务器地址和端口"""
|
||||||
|
if host:
|
||||||
|
self._host = host
|
||||||
|
if port:
|
||||||
|
self._port = port
|
||||||
|
|
||||||
|
async def run(self):
|
||||||
|
"""启动服务器"""
|
||||||
|
config = Config(app=self.app, host=self._host, port=self._port)
|
||||||
|
self._server = UvicornServer(config=config)
|
||||||
|
try:
|
||||||
|
await self._server.serve()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
await self.shutdown()
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
await self.shutdown()
|
||||||
|
raise RuntimeError(f"服务器运行错误: {str(e)}") from e
|
||||||
|
finally:
|
||||||
|
await self.shutdown()
|
||||||
|
|
||||||
|
async def shutdown(self):
|
||||||
|
"""安全关闭服务器"""
|
||||||
|
if self._server:
|
||||||
|
self._server.should_exit = True
|
||||||
|
await self._server.shutdown()
|
||||||
|
self._server = None
|
||||||
|
|
||||||
|
def get_app(self) -> FastAPI:
|
||||||
|
"""获取 FastAPI 实例"""
|
||||||
|
return self.app
|
||||||
|
|
||||||
|
|
||||||
|
global_server = Server(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
||||||
@@ -16,7 +16,7 @@ from .plugins.chat.bot import chat_bot
|
|||||||
from .common.logger import get_module_logger
|
from .common.logger import get_module_logger
|
||||||
from .plugins.remote import heartbeat_thread # noqa: F401
|
from .plugins.remote import heartbeat_thread # noqa: F401
|
||||||
from .individuality.individuality import Individuality
|
from .individuality.individuality import Individuality
|
||||||
|
from .common.server import global_server
|
||||||
|
|
||||||
logger = get_module_logger("main")
|
logger = get_module_logger("main")
|
||||||
|
|
||||||
@@ -33,6 +33,7 @@ class MainSystem:
|
|||||||
from .plugins.message import global_api
|
from .plugins.message import global_api
|
||||||
|
|
||||||
self.app = global_api
|
self.app = global_api
|
||||||
|
self.server = global_server
|
||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
"""初始化系统组件"""
|
"""初始化系统组件"""
|
||||||
@@ -126,6 +127,7 @@ class MainSystem:
|
|||||||
emoji_manager.start_periodic_check_register(),
|
emoji_manager.start_periodic_check_register(),
|
||||||
# emoji_manager.start_periodic_register(),
|
# emoji_manager.start_periodic_register(),
|
||||||
self.app.run(),
|
self.app.run(),
|
||||||
|
self.server.run(),
|
||||||
]
|
]
|
||||||
await asyncio.gather(*tasks)
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
|||||||
@@ -2,7 +2,7 @@
|
|||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
|
|
||||||
from .api import BaseMessageAPI, global_api
|
from .api import global_api
|
||||||
from .message_base import (
|
from .message_base import (
|
||||||
Seg,
|
Seg,
|
||||||
GroupInfo,
|
GroupInfo,
|
||||||
@@ -14,7 +14,6 @@ from .message_base import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"BaseMessageAPI",
|
|
||||||
"Seg",
|
"Seg",
|
||||||
"global_api",
|
"global_api",
|
||||||
"GroupInfo",
|
"GroupInfo",
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
|
|||||||
from typing import Dict, Any, Callable, List, Set, Optional
|
from typing import Dict, Any, Callable, List, Set, Optional
|
||||||
from src.common.logger import get_module_logger
|
from src.common.logger import get_module_logger
|
||||||
from src.plugins.message.message_base import MessageBase
|
from src.plugins.message.message_base import MessageBase
|
||||||
|
from src.common.server import global_server
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import asyncio
|
import asyncio
|
||||||
import uvicorn
|
import uvicorn
|
||||||
@@ -242,105 +243,4 @@ class MessageServer(BaseMessageHandler):
|
|||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
class BaseMessageAPI:
|
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]), app=global_server.get_app())
|
||||||
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
|
|
||||||
self.app = FastAPI()
|
|
||||||
self.host = host
|
|
||||||
self.port = port
|
|
||||||
self.message_handlers: List[Callable] = []
|
|
||||||
self.cache = []
|
|
||||||
self._setup_routes()
|
|
||||||
self._running = False
|
|
||||||
|
|
||||||
def _setup_routes(self):
|
|
||||||
"""设置基础路由"""
|
|
||||||
|
|
||||||
@self.app.post("/api/message")
|
|
||||||
async def handle_message(message: Dict[str, Any]):
|
|
||||||
try:
|
|
||||||
# 创建后台任务处理消息
|
|
||||||
asyncio.create_task(self._background_message_handler(message))
|
|
||||||
return {"status": "success"}
|
|
||||||
except Exception as e:
|
|
||||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
|
||||||
|
|
||||||
async def _background_message_handler(self, message: Dict[str, Any]):
|
|
||||||
"""后台处理单个消息"""
|
|
||||||
try:
|
|
||||||
await self.process_single_message(message)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"Background message processing failed: {str(e)}")
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
|
|
||||||
def register_message_handler(self, handler: Callable):
|
|
||||||
"""注册消息处理函数"""
|
|
||||||
self.message_handlers.append(handler)
|
|
||||||
|
|
||||||
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
|
|
||||||
"""发送消息到指定端点"""
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
try:
|
|
||||||
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
|
|
||||||
return await response.json()
|
|
||||||
except Exception:
|
|
||||||
# logger.error(f"发送消息失败: {str(e)}")
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def process_single_message(self, message: Dict[str, Any]):
|
|
||||||
"""处理单条消息"""
|
|
||||||
tasks = []
|
|
||||||
for handler in self.message_handlers:
|
|
||||||
try:
|
|
||||||
tasks.append(handler(message))
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(str(e))
|
|
||||||
logger.error(traceback.format_exc())
|
|
||||||
if tasks:
|
|
||||||
await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
def run_sync(self):
|
|
||||||
"""同步方式运行服务器"""
|
|
||||||
uvicorn.run(self.app, host=self.host, port=self.port)
|
|
||||||
|
|
||||||
async def run(self):
|
|
||||||
"""异步方式运行服务器"""
|
|
||||||
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
|
|
||||||
self.server = uvicorn.Server(config)
|
|
||||||
try:
|
|
||||||
await self.server.serve()
|
|
||||||
except KeyboardInterrupt as e:
|
|
||||||
await self.stop()
|
|
||||||
raise KeyboardInterrupt from e
|
|
||||||
|
|
||||||
async def start_server(self):
|
|
||||||
"""启动服务器的异步方法"""
|
|
||||||
if not self._running:
|
|
||||||
self._running = True
|
|
||||||
await self.run()
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止服务器"""
|
|
||||||
if hasattr(self, "server"):
|
|
||||||
self._running = False
|
|
||||||
# 正确关闭 uvicorn 服务器
|
|
||||||
self.server.should_exit = True
|
|
||||||
await self.server.shutdown()
|
|
||||||
# 等待服务器完全停止
|
|
||||||
if hasattr(self.server, "started") and self.server.started:
|
|
||||||
await self.server.main_loop()
|
|
||||||
# 清理处理程序
|
|
||||||
self.message_handlers.clear()
|
|
||||||
|
|
||||||
def start(self):
|
|
||||||
"""启动服务器的便捷方法"""
|
|
||||||
loop = asyncio.new_event_loop()
|
|
||||||
asyncio.set_event_loop(loop)
|
|
||||||
try:
|
|
||||||
loop.run_until_complete(self.start_server())
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
pass
|
|
||||||
finally:
|
|
||||||
loop.close()
|
|
||||||
|
|
||||||
|
|
||||||
global_api = MessageServer(host=os.environ["HOST"], port=int(os.environ["PORT"]))
|
|
||||||
|
|||||||
Reference in New Issue
Block a user