更改部分类型注解

This commit is contained in:
John Richard
2025-10-02 21:10:36 +08:00
parent 7923eafef3
commit 047105e5e8
6 changed files with 17 additions and 45 deletions

View File

@@ -760,7 +760,7 @@ async def initialize_database():
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
async def get_db_session() -> AsyncGenerator[AsyncSession]:
"""
异步数据库会话上下文管理器。
在初始化失败时会yield None调用方需要检查会话是否为None。
@@ -770,13 +770,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession | None, None]:
try:
_, SessionLocal = await initialize_database()
if not SessionLocal:
logger.error("数据库会话工厂 (_SessionLocal) 未初始化。")
yield None
return
raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
except Exception as e:
logger.error(f"数据库初始化失败,无法创建会话: {e}")
yield None
return
raise
try:
session = SessionLocal()

View File

@@ -1,3 +1,4 @@
# Todo: 重构Action,这里现在只剩下了报错。
import asyncio
import time
from abc import ABC, abstractmethod
@@ -452,7 +453,7 @@ class BaseAction(ABC):
# 4. 执行Action
logger.debug(f"{log_prefix} 开始执行...")
execute_result = await action_instance.execute()
execute_result = await action_instance.execute() # Todo: 修复类型错误
# 确保返回类型符合 (bool, str) 格式
is_success = execute_result[0] if isinstance(execute_result, tuple) and len(execute_result) > 0 else False
message = execute_result[1] if isinstance(execute_result, tuple) and len(execute_result) > 1 else ""

View File

@@ -21,10 +21,6 @@ class BasePlugin(PluginBase):
- Command组件处理命令请求
- 未来可扩展Scheduler、Listener等
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@abstractmethod
def get_plugin_components(
self,
@@ -42,7 +38,7 @@ class BasePlugin(PluginBase):
Returns:
List[tuple[ComponentInfo, Type]]: [(组件信息, 组件类), ...]
"""
raise NotImplementedError("Subclasses must implement this method")
...
def register_plugin(self) -> bool:
"""注册插件及其所有组件"""

View File

@@ -27,40 +27,17 @@ class PluginBase(ABC):
"""
# 插件基本信息(子类必须定义)
@property
@abstractmethod
def plugin_name(self) -> str:
return "" # 插件内部标识符(如 "hello_world_plugin"
@property
@abstractmethod
def enable_plugin(self) -> bool:
return True # 是否启用插件
@property
@abstractmethod
def dependencies(self) -> list[str]:
return [] # 依赖的其他插件
@property
@abstractmethod
def python_dependencies(self) -> list[str | PythonDependency]:
return [] # Python包依赖支持字符串列表或PythonDependency对象列表
@property
@abstractmethod
def config_file_name(self) -> str:
return "" # 配置文件名
plugin_name: str
config_file_name: str
enable_plugin: bool = True
dependencies: list[str] = []
python_dependencies: list[str | PythonDependency] = []
# manifest文件相关
manifest_file_name: str = "_manifest.json" # manifest文件名
manifest_data: dict[str, Any] = {} # manifest数据
# 配置定义
@property
@abstractmethod
def config_schema(self) -> dict[str, dict[str, ConfigField] | str]:
return {}
config_schema: dict[str, dict[str, ConfigField] | str] = {}
config_section_descriptions: dict[str, str] = {}

View File

@@ -1,6 +1,7 @@
import asyncio
import datetime
import re
from typing import ClassVar
from dateutil.parser import parse as parse_datetime
@@ -542,7 +543,7 @@ class SetEmojiLikePlugin(BasePlugin):
config_section_descriptions = {"plugin": "插件基本信息", "components": "插件组件"}
# 配置Schema定义
config_schema: dict = {
config_schema: ClassVar[dict ]= {
"plugin": {
"name": ConfigField(type=str, default="set_emoji_like", description="插件名称"),
"version": ConfigField(type=str, default="1.0.0", description="插件版本"),

View File

@@ -70,7 +70,7 @@ async def get_active_plans_for_month(month: str) -> list[MonthlyPlan]:
.where(MonthlyPlan.target_month == month, MonthlyPlan.status == "active")
.order_by(MonthlyPlan.created_at.desc())
)
return result.scalars().all()
return list(result.scalars().all())
except Exception as e:
logger.error(f"查询 {month} 的有效月度计划时发生错误: {e}")
return []
@@ -225,7 +225,7 @@ async def get_smart_plans_for_daily_schedule(month: str, max_count: int = 3, avo
plans = random.sample(plans, max_count)
logger.info(f"智能抽取了 {len(plans)}{month} 的月度计划用于每日日程生成。")
return plans
return list(plans)
except Exception as e:
logger.error(f"智能抽取 {month} 的月度计划时发生错误: {e}")
@@ -269,7 +269,7 @@ async def get_archived_plans_for_month(month: str) -> list[MonthlyPlan]:
result = await session.execute(
select(MonthlyPlan).where(MonthlyPlan.target_month == month, MonthlyPlan.status == "archived")
)
return result.scalars().all()
return list(result.scalars().all())
except Exception as e:
logger.error(f"查询 {month} 的归档月度计划时发生错误: {e}")
return []