refactor: 初步重构为maimcore

This commit is contained in:
tcmofashi
2025-03-27 13:30:46 +08:00
parent 09c6500d79
commit 4c332d0b2f
26 changed files with 426 additions and 1213 deletions

View File

@@ -0,0 +1,26 @@
"""Maim Message - A message handling library"""
__version__ = "0.1.0"
from .api import BaseMessageAPI, global_api
from .message_base import (
Seg,
GroupInfo,
UserInfo,
FormatInfo,
TemplateInfo,
BaseMessageInfo,
MessageBase,
)
__all__ = [
"BaseMessageAPI",
"Seg",
"global_api",
"GroupInfo",
"UserInfo",
"FormatInfo",
"TemplateInfo",
"BaseMessageInfo",
"MessageBase",
]

View File

@@ -0,0 +1,86 @@
from fastapi import FastAPI, HTTPException
from typing import Optional, Dict, Any, Callable, List
import aiohttp
import asyncio
import uvicorn
import os
class BaseMessageAPI:
def __init__(self, host: str = "0.0.0.0", port: int = 18000):
self.app = FastAPI()
self.host = host
self.port = port
self.message_handlers: List[Callable] = []
self._setup_routes()
self._running = False
def _setup_routes(self):
"""设置基础路由"""
@self.app.post("/api/message")
async def handle_message(message: Dict[str, Any]):
# try:
for handler in self.message_handlers:
await handler(message)
return {"status": "success"}
# except Exception as e:
# raise HTTPException(status_code=500, detail=str(e)) from e
def register_message_handler(self, handler: Callable):
"""注册消息处理函数"""
self.message_handlers.append(handler)
async def send_message(self, url: str, data: Dict[str, Any]) -> Dict[str, Any]:
"""发送消息到指定端点"""
async with aiohttp.ClientSession() as session:
try:
async with session.post(url, json=data, headers={"Content-Type": "application/json"}) as response:
return await response.json()
except Exception as e:
# logger.error(f"发送消息失败: {str(e)}")
pass
def run_sync(self):
"""同步方式运行服务器"""
uvicorn.run(self.app, host=self.host, port=self.port)
async def run(self):
"""异步方式运行服务器"""
config = uvicorn.Config(self.app, host=self.host, port=self.port, loop="asyncio")
self.server = uvicorn.Server(config)
await self.server.serve()
async def start_server(self):
"""启动服务器的异步方法"""
if not self._running:
self._running = True
await self.run()
async def stop(self):
"""停止服务器"""
if hasattr(self, "server"):
self._running = False
# 正确关闭 uvicorn 服务器
self.server.should_exit = True
await self.server.shutdown()
# 等待服务器完全停止
if hasattr(self.server, "started") and self.server.started:
await self.server.main_loop()
# 清理处理程序
self.message_handlers.clear()
def start(self):
"""启动服务器的便捷方法"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
loop.run_until_complete(self.start_server())
except KeyboardInterrupt:
pass
finally:
loop.close()
global_api = BaseMessageAPI(host=os.environ["HOST"], port=os.environ["PORT"])

View File

@@ -0,0 +1,246 @@
from dataclasses import dataclass, asdict
from typing import List, Optional, Union, Dict
@dataclass
class Seg:
"""消息片段类,用于表示消息的不同部分
Attributes:
type: 片段类型,可以是 'text''image''seglist'
data: 片段的具体内容
- 对于 text 类型data 是字符串
- 对于 image 类型data 是 base64 字符串
- 对于 seglist 类型data 是 Seg 列表
translated_data: 经过翻译处理的数据(可选)
"""
type: str
data: Union[str, List["Seg"]]
# def __init__(self, type: str, data: Union[str, List['Seg']],):
# """初始化实例,确保字典和属性同步"""
# # 先初始化字典
# self.type = type
# self.data = data
@classmethod
def from_dict(cls, data: Dict) -> "Seg":
"""从字典创建Seg实例"""
type = data.get("type")
data = data.get("data")
if type == "seglist":
data = [Seg.from_dict(seg) for seg in data]
return cls(type=type, data=data)
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {"type": self.type}
if self.type == "seglist":
result["data"] = [seg.to_dict() for seg in self.data]
else:
result["data"] = self.data
return result
@dataclass
class GroupInfo:
"""群组信息类"""
platform: Optional[str] = None
group_id: Optional[int] = None
group_name: Optional[str] = None # 群名称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "GroupInfo":
"""从字典创建GroupInfo实例
Args:
data: 包含必要字段的字典
Returns:
GroupInfo: 新的实例
"""
if data.get("group_id") is None:
return None
return cls(
platform=data.get("platform"), group_id=data.get("group_id"), group_name=data.get("group_name", None)
)
@dataclass
class UserInfo:
"""用户信息类"""
platform: Optional[str] = None
user_id: Optional[int] = None
user_nickname: Optional[str] = None # 用户昵称
user_cardname: Optional[str] = None # 用户群昵称
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "UserInfo":
"""从字典创建UserInfo实例
Args:
data: 包含必要字段的字典
Returns:
UserInfo: 新的实例
"""
return cls(
platform=data.get("platform"),
user_id=data.get("user_id"),
user_nickname=data.get("user_nickname", None),
user_cardname=data.get("user_cardname", None),
)
@dataclass
class FormatInfo:
"""格式信息类"""
"""
目前maimcore可接受的格式为text,image,emoji
可发送的格式为text,emoji,reply
"""
content_format: Optional[str] = None
accept_format: Optional[str] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "FormatInfo":
"""从字典创建FormatInfo实例
Args:
data: 包含必要字段的字典
Returns:
FormatInfo: 新的实例
"""
return cls(
content_format=data.get("content_format"),
accept_format=data.get("accept_format"),
)
@dataclass
class TemplateInfo:
"""模板信息类"""
template_items: Optional[List[Dict]] = None
template_name: Optional[str] = None
template_default: bool = True
def to_dict(self) -> Dict:
"""转换为字典格式"""
return {k: v for k, v in asdict(self).items() if v is not None}
@classmethod
def from_dict(cls, data: Dict) -> "TemplateInfo":
"""从字典创建TemplateInfo实例
Args:
data: 包含必要字段的字典
Returns:
TemplateInfo: 新的实例
"""
return cls(
template_items=data.get("template_items"),
template_name=data.get("template_name"),
template_default=data.get("template_default", True),
)
@dataclass
class BaseMessageInfo:
"""消息信息类"""
platform: Optional[str] = None
message_id: Union[str, int, None] = None
time: Optional[int] = None
group_info: Optional[GroupInfo] = None
user_info: Optional[UserInfo] = None
format_info: Optional[FormatInfo] = None
template_info: Optional[TemplateInfo] = None
def to_dict(self) -> Dict:
"""转换为字典格式"""
result = {}
for field, value in asdict(self).items():
if value is not None:
if isinstance(value, (GroupInfo, UserInfo, FormatInfo, TemplateInfo)):
result[field] = value.to_dict()
else:
result[field] = value
return result
@classmethod
def from_dict(cls, data: Dict) -> "BaseMessageInfo":
"""从字典创建BaseMessageInfo实例
Args:
data: 包含必要字段的字典
Returns:
BaseMessageInfo: 新的实例
"""
group_info = GroupInfo.from_dict(data.get("group_info", {}))
user_info = UserInfo.from_dict(data.get("user_info", {}))
format_info = FormatInfo.from_dict(data.get("format_info", {}))
template_info = TemplateInfo.from_dict(data.get("template_info", {}))
return cls(
platform=data.get("platform"),
message_id=data.get("message_id"),
time=data.get("time"),
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
@dataclass
class MessageBase:
"""消息类"""
message_info: BaseMessageInfo
message_segment: Seg
raw_message: Optional[str] = None # 原始消息包含未解析的cq码
def to_dict(self) -> Dict:
"""转换为字典格式
Returns:
Dict: 包含所有非None字段的字典其中
- message_info: 转换为字典格式
- message_segment: 转换为字典格式
- raw_message: 如果存在则包含
"""
result = {"message_info": self.message_info.to_dict(), "message_segment": self.message_segment.to_dict()}
if self.raw_message is not None:
result["raw_message"] = self.raw_message
return result
@classmethod
def from_dict(cls, data: Dict) -> "MessageBase":
"""从字典创建MessageBase实例
Args:
data: 包含必要字段的字典
Returns:
MessageBase: 新的实例
"""
message_info = BaseMessageInfo.from_dict(data.get("message_info", {}))
message_segment = Seg(**data.get("message_segment", {}))
raw_message = data.get("raw_message", None)
return cls(message_info=message_info, message_segment=message_segment, raw_message=raw_message)

View File

@@ -0,0 +1,98 @@
import unittest
import asyncio
import aiohttp
from api import BaseMessageAPI
from message_base import (
BaseMessageInfo,
UserInfo,
GroupInfo,
FormatInfo,
TemplateInfo,
MessageBase,
Seg,
)
send_url = "http://localhost"
receive_port = 18002 # 接收消息的端口
send_port = 18000 # 发送消息的端口
test_endpoint = "/api/message"
# 创建并启动API实例
api = BaseMessageAPI(host="0.0.0.0", port=receive_port)
class TestLiveAPI(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
"""测试前的设置"""
self.received_messages = []
async def message_handler(message):
self.received_messages.append(message)
self.api = api
self.api.register_message_handler(message_handler)
self.server_task = asyncio.create_task(self.api.run())
try:
await asyncio.wait_for(asyncio.sleep(1), timeout=5)
except asyncio.TimeoutError:
self.skipTest("服务器启动超时")
async def asyncTearDown(self):
"""测试后的清理"""
if hasattr(self, "server_task"):
await self.api.stop() # 先调用正常的停止流程
if not self.server_task.done():
self.server_task.cancel()
try:
await asyncio.wait_for(self.server_task, timeout=100)
except (asyncio.CancelledError, asyncio.TimeoutError):
pass
async def test_send_and_receive_message(self):
"""测试向运行中的API发送消息并接收响应"""
# 准备测试消息
user_info = UserInfo(user_id=12345678, user_nickname="测试用户", platform="qq")
group_info = GroupInfo(group_id=12345678, group_name="测试群", platform="qq")
format_info = FormatInfo(
content_format=["text"], accept_format=["text", "emoji", "reply"]
)
template_info = None
message_info = BaseMessageInfo(
platform="qq",
message_id=12345678,
time=12345678,
group_info=group_info,
user_info=user_info,
format_info=format_info,
template_info=template_info,
)
message = MessageBase(
message_info=message_info,
raw_message="测试消息",
message_segment=Seg(type="text", data="测试消息"),
)
test_message = message.to_dict()
# 发送测试消息到发送端口
async with aiohttp.ClientSession() as session:
async with session.post(
f"{send_url}:{send_port}{test_endpoint}",
json=test_message,
) as response:
response_data = await response.json()
self.assertEqual(response.status, 200)
self.assertEqual(response_data["status"], "success")
try:
async with asyncio.timeout(5): # 设置5秒超时
while len(self.received_messages) == 0:
await asyncio.sleep(0.1)
received_message = self.received_messages[0]
print(received_message)
self.received_messages.clear()
except asyncio.TimeoutError:
self.fail("等待接收消息超时")
if __name__ == "__main__":
unittest.main()