refactor: 清理项目结构并修复类型注解问题

修复 SQLAlchemy 模型的类型注解,使用 Mapped 类型避免类型检查器错误
- 修正异步数据库操作中缺少 await 的问题
- 优化反注入统计系统的数值字段处理逻辑
- 添加缺失的导入语句修复模块依赖问题
This commit is contained in:
雅诺狐
2025-10-07 11:35:12 +08:00
parent 167e4d2520
commit 875ee4813c
19 changed files with 1466 additions and 3997 deletions

View File

@@ -265,7 +265,8 @@ class AntiPromptInjector:
async with get_db_session() as session:
# 删除对应的消息记录
stmt = delete(Messages).where(Messages.message_id == message_id)
result = session.execute(stmt)
# 注意: 异步会话需要 await 执行,否则 result 是 coroutine无法获取 rowcount
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:
@@ -296,7 +297,7 @@ class AntiPromptInjector:
.where(Messages.message_id == message_id)
.values(processed_plain_text=new_content, display_message=new_content)
)
result = session.execute(stmt)
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:

View File

@@ -5,9 +5,9 @@
"""
import datetime
from typing import Any
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast
from sqlalchemy import select
from sqlalchemy import select, delete
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.logger import get_logger
@@ -16,8 +16,30 @@ from src.config.config import global_config
logger = get_logger("anti_injector.statistics")
TNum = TypeVar("TNum", int, float)
def _add_optional(a: Optional[TNum], b: TNum) -> TNum:
"""安全相加:左值可能为 None。
Args:
a: 可能为 None 的当前值
b: 要累加的增量(非 None
Returns:
新的累加结果(与 b 同类型)
"""
if a is None:
return b
return cast(TNum, a + b) # a 不为 None此处显式 cast 便于类型检查
class AntiInjectionStatistics:
"""反注入系统统计管理类"""
"""反注入系统统计管理类
主要改进:
- 对 "可能为 None" 的数值字段做集中安全处理,减少在业务逻辑里反复判空。
- 补充类型注解便于静态检查器Pylance/Pyright识别。
"""
def __init__(self):
"""初始化统计管理器"""
@@ -25,8 +47,12 @@ class AntiInjectionStatistics:
"""当前会话开始时间"""
@staticmethod
async def get_or_create_stats():
"""获取或创建统计记录"""
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined]
"""获取或创建统计记录
Returns:
AntiInjectionStats | None: 成功返回模型实例,否则 None
"""
try:
async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建
@@ -46,8 +72,15 @@ class AntiInjectionStatistics:
return None
@staticmethod
async def update_stats(**kwargs):
"""更新统计数据"""
async def update_stats(**kwargs: Any) -> None:
"""更新统计数据(批量可选字段)
支持字段:
- processing_time_delta: float 累加到 processing_time_total
- last_processing_time: float 设置 last_process_time
- total_messages / detected_injections / blocked_messages / shielded_messages / error_count: 累加
- 其他任意字段:直接赋值(若模型存在该属性)
"""
try:
async with get_db_session() as session:
stats = (
@@ -62,14 +95,13 @@ class AntiInjectionStatistics:
# 更新统计字段
for key, value in kwargs.items():
if key == "processing_time_delta":
# 处理 时间累加 - 确保不为None
if stats.processing_time_total is None:
stats.processing_time_total = 0.0
stats.processing_time_total += value
# 处理时间累加 - 确保不为 None
delta = float(value)
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined]
continue
elif key == "last_processing_time":
# 直接设置最后处理时间
stats.last_process_time = value
stats.last_process_time = float(value)
continue
elif hasattr(stats, key):
if key in [
@@ -79,12 +111,10 @@ class AntiInjectionStatistics:
"shielded_messages",
"error_count",
]:
# 累加类型的字段 - 确保不为None
current_value = getattr(stats, key)
if current_value is None:
setattr(stats, key, value)
else:
setattr(stats, key, current_value + value)
# 累加类型的字段 - 统一用辅助函数
current_value = cast(Optional[int], getattr(stats, key))
increment = int(value)
setattr(stats, key, _add_optional(current_value, increment))
else:
# 直接设置的字段
setattr(stats, key, value)
@@ -114,10 +144,11 @@ class AntiInjectionStatistics:
stats = await self.get_or_create_stats()
# 计算派生统计信息 - 处理None值
total_messages = stats.total_messages or 0
detected_injections = stats.detected_injections or 0
processing_time_total = stats.processing_time_total or 0.0
# 计算派生统计信息 - 处理 None 值
total_messages = stats.total_messages or 0 # type: ignore[attr-defined]
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]
detection_rate = (detected_injections / total_messages * 100) if total_messages > 0 else 0
avg_processing_time = (processing_time_total / total_messages) if total_messages > 0 else 0
@@ -127,17 +158,22 @@ class AntiInjectionStatistics:
current_time = datetime.datetime.now()
uptime = current_time - self.session_start_time
last_proc = stats.last_process_time # type: ignore[attr-defined]
blocked_messages = stats.blocked_messages or 0 # type: ignore[attr-defined]
shielded_messages = stats.shielded_messages or 0 # type: ignore[attr-defined]
error_count = stats.error_count or 0 # type: ignore[attr-defined]
return {
"status": "enabled",
"uptime": str(uptime),
"total_messages": total_messages,
"detected_injections": detected_injections,
"blocked_messages": stats.blocked_messages or 0,
"shielded_messages": stats.shielded_messages or 0,
"blocked_messages": blocked_messages,
"shielded_messages": shielded_messages,
"detection_rate": f"{detection_rate:.2f}%",
"average_processing_time": f"{avg_processing_time:.3f}s",
"last_processing_time": f"{stats.last_process_time:.3f}s" if stats.last_process_time else "0.000s",
"error_count": stats.error_count or 0,
"last_processing_time": f"{last_proc:.3f}s" if last_proc else "0.000s",
"error_count": error_count,
}
except Exception as e:
logger.error(f"获取统计信息失败: {e}")
@@ -149,7 +185,7 @@ class AntiInjectionStatistics:
try:
async with get_db_session() as session:
# 删除现有统计记录
await session.execute(select(AntiInjectionStats).delete())
await session.execute(delete(AntiInjectionStats))
await session.commit()
logger.info("统计信息已重置")
except Exception as e:

View File

@@ -51,7 +51,7 @@ class UserBanManager:
remaining_time = ban_duration - (datetime.datetime.now() - ban_record.created_at)
return False, None, f"用户被封禁中,剩余时间: {remaining_time}"
else:
# 封禁已过期,重置违规次数
# 封禁已过期,重置违规次数与时间(模型已使用 Mapped 类型,可直接赋值)
ban_record.violation_num = 0
ban_record.created_at = datetime.datetime.now()
await session.commit()
@@ -92,7 +92,6 @@ class UserBanManager:
await session.commit()
# 检查是否需要自动封禁
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
logger.warning(f"用户 {platform}:{user_id} 违规次数达到 {ban_record.violation_num},触发自动封禁")
# 只有在首次达到阈值时才更新封禁开始时间

View File

@@ -377,11 +377,12 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], r
class EmojiManager:
_instance = None
_initialized: bool = False # 显式声明,避免属性未定义错误
def __new__(cls) -> "EmojiManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
# 类属性已声明,无需再次赋值
return cls._instance
def __init__(self) -> None:
@@ -399,7 +400,8 @@ class EmojiManager:
self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
self._initialized = True
logger.info("启动表情包管理器")
def shutdown(self) -> None:
@@ -752,8 +754,8 @@ class EmojiManager:
try:
emoji_record = await self.get_emoji_from_db(emoji_hash)
if emoji_record and emoji_record[0].emotion:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
return emoji_record.emotion
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") # type: ignore # type: ignore
return emoji_record.emotion # type: ignore
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")

View File

@@ -7,6 +7,7 @@ import asyncio
import time
from typing import Any
from chat.message_manager.adaptive_stream_manager import StreamPriority
from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager
from src.common.data_models.message_manager_data_model import StreamContext