This commit is contained in:
tt-P607
2025-12-05 02:15:59 +08:00
25 changed files with 573 additions and 851 deletions

View File

@@ -1,14 +1,25 @@
"""
消息管理器模块
提供统一的消息管理、上下文管理和流循环调度功能
基于 Generator + Tick 的事件驱动模式
"""
from .distribution_manager import StreamLoopManager, stream_loop_manager
from .distribution_manager import (
ConversationTick,
StreamLoopManager,
conversation_loop,
run_chat_stream,
stream_loop_manager,
)
from .message_manager import MessageManager, message_manager
__all__ = [
"ConversationTick",
"MessageManager",
"StreamLoopManager",
"conversation_loop",
"message_manager",
"run_chat_stream",
"stream_loop_manager",
]

View File

@@ -234,13 +234,6 @@ class BatchDatabaseWriter:
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
elif global_config.database.database_type == "mysql":
from sqlalchemy.dialects.mysql import insert as mysql_insert
stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
@@ -268,13 +261,6 @@ class BatchDatabaseWriter:
stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
elif global_config.database.database_type == "mysql":
from sqlalchemy.dialects.mysql import insert as mysql_insert
stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert

File diff suppressed because it is too large Load Diff

View File

@@ -3,7 +3,6 @@ import hashlib
import time
from rich.traceback import install
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
@@ -665,11 +664,6 @@ class ChatManager:
if global_config.database.database_type == "sqlite":
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
elif global_config.database.database_type == "mysql":
stmt = mysql_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
# PostgreSQL 需要使用 constraint 参数或正确的 index_elements

View File

@@ -9,7 +9,6 @@
支持的数据库:
- SQLite (默认)
- MySQL
- PostgreSQL
"""

View File

@@ -2,7 +2,6 @@
提供跨数据库兼容性支持,处理不同数据库之间的差异:
- SQLite: 轻量级本地数据库
- MySQL: 高性能关系型数据库
- PostgreSQL: 功能丰富的开源数据库
主要职责:
@@ -23,7 +22,6 @@ class DatabaseDialect(Enum):
"""数据库方言枚举"""
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
@@ -68,20 +66,6 @@ DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = {
}
},
),
DatabaseDialect.MYSQL: DialectConfig(
dialect=DatabaseDialect.MYSQL,
ping_query="SELECT 1",
supports_returning=False, # MySQL 8.0.21+ 有限支持
supports_native_json=True, # MySQL 5.7+
supports_arrays=False,
requires_length_for_index=True, # MySQL 索引需要指定长度
default_string_length=255,
isolation_level="READ COMMITTED",
engine_kwargs={
"pool_pre_ping": True,
"pool_recycle": 3600,
},
),
DatabaseDialect.POSTGRESQL: DialectConfig(
dialect=DatabaseDialect.POSTGRESQL,
ping_query="SELECT 1",
@@ -113,13 +97,13 @@ class DialectAdapter:
"""初始化适配器
Args:
db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql")
db_type: 数据库类型字符串 ("sqlite", "postgresql")
"""
try:
cls._current_dialect = DatabaseDialect(db_type.lower())
cls._config = DIALECT_CONFIGS[cls._current_dialect]
except ValueError:
raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql")
raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, postgresql")
@classmethod
def get_dialect(cls) -> DatabaseDialect:
@@ -153,15 +137,10 @@ class DialectAdapter:
"""
config = cls.get_config()
# MySQL 索引列需要指定长度
if config.requires_length_for_index and indexed:
return String(max_length)
# SQLite 和 PostgreSQL 可以使用 Text
if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL):
return Text() if not indexed else String(max_length)
# MySQL 使用 VARCHAR
return String(max_length)
@classmethod
@@ -189,11 +168,6 @@ class DialectAdapter:
"""是否为 SQLite"""
return cls.get_dialect() == DatabaseDialect.SQLITE
@classmethod
def is_mysql(cls) -> bool:
"""是否为 MySQL"""
return cls.get_dialect() == DatabaseDialect.MYSQL
@classmethod
def is_postgresql(cls) -> bool:
"""是否为 PostgreSQL"""
@@ -211,7 +185,7 @@ def get_indexed_string_field(max_length: int = 255) -> TypeEngine:
这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型
Args:
max_length: 最大长度(对于 MySQL 是必需的)
max_length: 最大长度
Returns:
SQLAlchemy 类型

View File

@@ -4,7 +4,6 @@
支持的数据库类型:
- SQLite: 轻量级本地数据库,使用 aiosqlite 驱动
- MySQL: 高性能关系型数据库,使用 aiomysql 驱动
- PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动
"""
@@ -66,9 +65,7 @@ async def get_engine() -> AsyncEngine:
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
# 根据数据库类型构建URL和引擎参数
if db_type == "mysql":
url, engine_kwargs = _build_mysql_config(config)
elif db_type == "postgresql":
if db_type == "postgresql":
url, engine_kwargs = _build_postgresql_config(config)
else:
url, engine_kwargs = _build_sqlite_config(config)
@@ -123,55 +120,6 @@ def _build_sqlite_config(config) -> tuple[str, dict]:
return url, engine_kwargs
def _build_mysql_config(config) -> tuple[str, dict]:
"""构建 MySQL 配置
Args:
config: 数据库配置对象
Returns:
(url, engine_kwargs) 元组
"""
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
)
else:
# TCP连接
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
"pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout,
"pool_recycle": 3600,
"pool_pre_ping": True,
"connect_args": {
"autocommit": config.mysql_autocommit,
"charset": config.mysql_charset,
"connect_timeout": config.connection_timeout,
},
}
logger.info(
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
return url, engine_kwargs
def _build_postgresql_config(config) -> tuple[str, dict]:
"""构建 PostgreSQL 配置

View File

@@ -119,9 +119,6 @@ async def check_and_migrate_database(existing_engine=None):
):
# SQLite 将布尔值存储为 0 或 1
default_value = "1" if default_arg else "0"
elif dialect.name == "mysql" and isinstance(default_arg, bool):
# MySQL 也使用 1/0 表示布尔值
default_value = "1" if default_arg else "0"
elif isinstance(default_arg, bool):
# PostgreSQL 使用 TRUE/FALSE
default_value = "TRUE" if default_arg else "FALSE"

View File

@@ -5,7 +5,6 @@
支持的数据库类型:
- SQLite: 使用 Text 类型
- MySQL: 使用 VARCHAR(max_length) 用于索引字段
- PostgreSQL: 使用 Text 类型PostgreSQL 的 Text 类型性能与 VARCHAR 相当)
所有模型使用统一的类型注解风格:
@@ -31,12 +30,11 @@ def get_string_field(max_length=255, **kwargs):
根据数据库类型返回合适的字符串字段类型
对于需要索引的字段:
- MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度
- PostgreSQL: 可以使用 Text但为了兼容性使用 VARCHAR
- SQLite: 可以使用 Text无长度限制
Args:
max_length: 最大长度(对于 MySQL 是必需的)
max_length: 最大长度
**kwargs: 传递给 String/Text 的额外参数
Returns:
@@ -47,11 +45,8 @@ def get_string_field(max_length=255, **kwargs):
assert global_config is not None
db_type = global_config.database.database_type
# MySQL 索引需要指定长度的 VARCHAR
if db_type == "mysql":
return String(max_length, **kwargs)
# PostgreSQL 可以使用 Text但为了跨数据库迁移兼容性使用 VARCHAR
elif db_type == "postgresql":
if db_type == "postgresql":
return String(max_length, **kwargs)
# SQLite 使用 Text无长度限制
else:

View File

@@ -4,7 +4,6 @@
支持的数据库类型:
- SQLite: 设置 PRAGMA 参数优化并发
- MySQL: 无特殊会话设置
- PostgreSQL: 可选设置 schema 搜索路径
"""
@@ -79,7 +78,6 @@ async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
schema = global_config.database.postgresql_schema
if schema and schema != "public":
await session.execute(text(f"SET search_path TO {schema}"))
# MySQL 通常不需要会话级别的特殊设置
except Exception:
# 复用连接时设置可能已存在,忽略错误
pass
@@ -93,7 +91,6 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
支持的数据库:
- SQLite: 自动设置 busy_timeout 和外键约束
- MySQL: 直接使用,无特殊设置
- PostgreSQL: 支持自定义 schema
使用示例:
@@ -132,7 +129,7 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
- 正常退出时自动提交事务
- 发生异常时自动回滚事务
- 如果用户代码已手动调用 commit/rollback再次调用是安全的
- 适用于所有数据库类型SQLite, MySQL, PostgreSQL
- 适用于所有数据库类型SQLite, PostgreSQL
Yields:
AsyncSession: SQLAlchemy异步会话对象

View File

@@ -128,7 +128,7 @@ class ConnectionPoolManager:
- 正常退出时自动提交事务
- 发生异常时自动回滚事务
- 如果用户代码已手动调用 commit/rollback再次调用是安全的空操作
- 支持所有数据库类型SQLite、MySQL、PostgreSQL
- 支持所有数据库类型SQLite、PostgreSQL
"""
connection_info = None
@@ -158,7 +158,7 @@ class ConnectionPoolManager:
yield connection_info.session
# 🔧 正常退出时提交事务
# 这对所有数据库SQLite、MySQL、PostgreSQL都很重要
# 这对所有数据库SQLite、PostgreSQL都很重要
# 因为 SQLAlchemy 默认使用事务模式,不会自动提交
# 注意:如果用户代码已调用 commit(),这里的 commit() 是安全的空操作
if connection_info and connection_info.session:

View File

@@ -65,7 +65,7 @@ TEMPLATE_DIR = os.path.join(PROJECT_ROOT, "template")
# 考虑到实际上配置文件中的mai_version是不会自动更新的,所以采用硬编码
# 对该字段的更新请严格参照语义化版本规范https://semver.org/lang/zh-CN/
MMC_VERSION = "0.13.0"
MMC_VERSION = "0.13.1-alpha.1"
# 全局配置变量
_CONFIG_INITIALIZED = False

View File

@@ -16,26 +16,9 @@ from src.config.config_base import ValidatedConfigBase
class DatabaseConfig(ValidatedConfigBase):
"""数据库配置类"""
database_type: Literal["sqlite", "mysql", "postgresql"] = Field(default="sqlite", description="数据库类型")
database_type: Literal["sqlite", "postgresql"] = Field(default="sqlite", description="数据库类型")
sqlite_path: str = Field(default="data/MaiBot.db", description="SQLite数据库文件路径")
# MySQL 配置
mysql_host: str = Field(default="localhost", description="MySQL服务器地址")
mysql_port: int = Field(default=3306, ge=1, le=65535, description="MySQL服务器端口")
mysql_database: str = Field(default="maibot", description="MySQL数据库名")
mysql_user: str = Field(default="root", description="MySQL用户名")
mysql_password: str = Field(default="", description="MySQL密码")
mysql_charset: str = Field(default="utf8mb4", description="MySQL字符集")
mysql_unix_socket: str = Field(default="", description="MySQL Unix套接字路径")
mysql_ssl_mode: Literal["DISABLED", "PREFERRED", "REQUIRED", "VERIFY_CA", "VERIFY_IDENTITY"] = Field(
default="DISABLED", description="SSL模式"
)
mysql_ssl_ca: str = Field(default="", description="SSL CA证书路径")
mysql_ssl_cert: str = Field(default="", description="SSL客户端证书路径")
mysql_ssl_key: str = Field(default="", description="SSL密钥路径")
mysql_autocommit: bool = Field(default=True, description="自动提交事务")
mysql_sql_mode: str = Field(default="TRADITIONAL", description="SQL模式")
# PostgreSQL 配置
postgresql_host: str = Field(default="localhost", description="PostgreSQL服务器地址")
postgresql_port: int = Field(default=5432, ge=1, le=65535, description="PostgreSQL服务器端口")

View File

@@ -61,7 +61,6 @@ INSTALL_NAME_TO_IMPORT_NAME = {
"passlib": "passlib", # 密码哈希库
"bcrypt": "bcrypt", # Bcrypt密码哈希
# ============== 数据库 (Database) ==============
"mysql-connector-python": "mysql.connector", # MySQL官方驱动
"psycopg2-binary": "psycopg2", # PostgreSQL驱动 (二进制)
"pymongo": "pymongo", # MongoDB驱动
"redis": "redis", # Redis客户端

View File

@@ -96,7 +96,6 @@ class ReplyAction(BaseAction):
# 发送回复
reply_text = await self._send_response(response_set)
logger.info(f"{self.log_prefix} reply 动作执行成功")
return True, reply_text
except asyncio.CancelledError:
@@ -218,8 +217,7 @@ class RespondAction(BaseAction):
# 发送回复respond 默认不引用)
reply_text = await self._send_response(response_set)
logger.info(f"{self.log_prefix} respond 动作执行成功")
return True, reply_text
except asyncio.CancelledError:

View File

@@ -126,7 +126,6 @@ class ChatStreamImpressionTool(BaseTool):
updates.append(f"兴趣分: {final_impression['stream_interest_score']:.2f}")
result_text = f"已更新聊天流 {stream_id} 的印象:\n" + "\n".join(updates)
logger.info(f"聊天流印象更新成功: {stream_id}")
return {"type": "chat_stream_impression_update", "id": stream_id, "content": result_text}
@@ -214,7 +213,7 @@ class ChatStreamImpressionTool(BaseTool):
await cache.delete(generate_cache_key("stream_impression", stream_id))
await cache.delete(generate_cache_key("chat_stream", stream_id))
logger.info(f"聊天流印象已更新到数据库: {stream_id}")
logger.debug(f"聊天流印象已更新到数据库: {stream_id}")
else:
error_msg = f"聊天流 {stream_id} 不存在于数据库中,无法更新印象"
logger.error(error_msg)

View File

@@ -88,6 +88,93 @@ class NapcatAdapter(BaseAdapter):
# 注册 utils 内部使用的适配器实例,便于工具方法自动获取 WS
handler_utils.register_adapter(self)
def _should_process_event(self, raw: Dict[str, Any]) -> bool:
"""
检查事件是否应该被处理(黑白名单过滤)
此方法在 from_platform_message 顶层调用,对所有类型的事件(消息、通知、元事件)进行过滤。
Args:
raw: OneBot 原始事件数据
Returns:
bool: True表示应该处理False表示应该过滤
"""
if not self.plugin:
return True
plugin_config = self.plugin.config
if not plugin_config:
return True # 如果没有配置,默认处理所有事件
features_config = plugin_config.get("features", {})
post_type = raw.get("post_type")
# 获取用户信息(根据事件类型从不同字段获取)
user_id: str = ""
if post_type == "message":
sender_info = raw.get("sender", {})
user_id = str(sender_info.get("user_id", ""))
elif post_type == "notice":
user_id = str(raw.get("user_id", ""))
else:
# 元事件或其他类型不需要过滤
return True
# 检查全局封禁用户列表
ban_user_ids = [str(item) for item in features_config.get("ban_user_id", [])]
if user_id and user_id in ban_user_ids:
logger.debug(f"用户 {user_id} 在全局封禁列表中,事件被过滤")
return False
# 检查是否屏蔽其他QQ机器人仅对消息事件生效
if post_type == "message" and features_config.get("ban_qq_bot", False):
sender_info = raw.get("sender", {})
role = sender_info.get("role", "")
if role == "admin" or "bot" in str(sender_info).lower():
logger.debug(f"检测到机器人消息 {user_id},事件被过滤")
return False
# 获取消息类型(消息事件使用 message_type通知事件根据 group_id 判断)
message_type = raw.get("message_type")
group_id = raw.get("group_id")
# 如果是通知事件,根据是否有 group_id 判断是群通知还是私聊通知
if post_type == "notice":
message_type = "group" if group_id else "private"
# 群聊/群通知过滤
if message_type == "group" and group_id:
group_id_str = str(group_id)
group_list_type = features_config.get("group_list_type", "blacklist")
group_list = [str(item) for item in features_config.get("group_list", [])]
if group_list_type == "blacklist":
if group_id_str in group_list:
logger.debug(f"群聊 {group_id_str} 在黑名单中,事件被过滤")
return False
else: # whitelist
if group_id_str not in group_list:
logger.debug(f"群聊 {group_id_str} 不在白名单中,事件被过滤")
return False
# 私聊/私聊通知过滤
elif message_type == "private":
private_list_type = features_config.get("private_list_type", "blacklist")
private_list = [str(item) for item in features_config.get("private_list", [])]
if private_list_type == "blacklist":
if user_id in private_list:
logger.debug(f"私聊用户 {user_id} 在黑名单中,事件被过滤")
return False
else: # whitelist
if user_id not in private_list:
logger.debug(f"私聊用户 {user_id} 不在白名单中,事件被过滤")
return False
# 通过所有过滤条件
return True
async def on_adapter_loaded(self) -> None:
"""适配器加载时的初始化"""
logger.info("Napcat 适配器正在启动...")
@@ -161,6 +248,8 @@ class NapcatAdapter(BaseAdapter):
- notice 事件 → 通知(戳一戳、表情回复等)
- meta_event 事件 → 元事件(心跳、生命周期)
- API 响应 → 存入响应池
注意:黑白名单等过滤机制在此方法最开始执行,确保所有类型的事件都能被过滤。
"""
post_type = raw.get("post_type")
@@ -171,6 +260,11 @@ class NapcatAdapter(BaseAdapter):
future = self._response_pool[echo]
if not future.done():
future.set_result(raw)
return None
# 顶层过滤:黑白名单等过滤机制
if not self._should_process_event(raw):
return None
try:
# 消息事件

View File

@@ -39,79 +39,6 @@ class MessageHandler:
"""设置插件配置"""
self.plugin_config = config
def _should_process_message(self, raw: Dict[str, Any]) -> bool:
"""
检查消息是否应该被处理(黑白名单过滤)
Args:
raw: OneBot 原始消息数据
Returns:
bool: True表示应该处理False表示应该过滤
"""
if not self.plugin_config:
return True # 如果没有配置,默认处理所有消息
features_config = self.plugin_config.get("features", {})
# 获取消息基本信息
message_type = raw.get("message_type")
sender_info = raw.get("sender", {})
user_id = str(sender_info.get("user_id", ""))
# 检查全局封禁用户列表
ban_user_ids = [str(item) for item in features_config.get("ban_user_id", [])]
if user_id in ban_user_ids:
logger.debug(f"用户 {user_id} 在全局封禁列表中,消息被过滤")
return False
# 检查是否屏蔽其他QQ机器人
if features_config.get("ban_qq_bot", False):
# 判断是否为机器人消息通常通过sender中的role字段或其他标识
role = sender_info.get("role", "")
if role == "admin" or "bot" in str(sender_info).lower():
logger.debug(f"检测到机器人消息 {user_id},消息被过滤")
return False
# 群聊消息处理
if message_type == "group":
group_id = str(raw.get("group_id", ""))
# 获取群聊配置
group_list_type = features_config.get("group_list_type", "blacklist")
group_list = [str(item) for item in features_config.get("group_list", [])]
if group_list_type == "blacklist":
# 黑名单模式:如果在黑名单中就过滤
if group_id in group_list:
logger.debug(f"群聊 {group_id} 在黑名单中,消息被过滤")
return False
else: # whitelist
# 白名单模式:如果不在白名单中就过滤
if group_id not in group_list:
logger.debug(f"群聊 {group_id} 不在白名单中,消息被过滤")
return False
# 私聊消息处理
elif message_type == "private":
# 获取私聊配置
private_list_type = features_config.get("private_list_type", "blacklist")
private_list = [str(item) for item in features_config.get("private_list", [])]
if private_list_type == "blacklist":
# 黑名单模式:如果在黑名单中就过滤
if user_id in private_list:
logger.debug(f"私聊用户 {user_id} 在黑名单中,消息被过滤")
return False
else: # whitelist
# 白名单模式:如果不在白名单中就过滤
if user_id not in private_list:
logger.debug(f"私聊用户 {user_id} 不在白名单中,消息被过滤")
return False
# 通过所有过滤条件
return True
async def handle_raw_message(self, raw: Dict[str, Any]):
"""
处理原始消息并转换为 MessageEnvelope
@@ -120,18 +47,17 @@ class MessageHandler:
raw: OneBot 原始消息数据
Returns:
MessageEnvelope (dict) or None (if message is filtered)
MessageEnvelope (dict) or None
Note:
黑白名单过滤已移动到 NapcatAdapter.from_platform_message 顶层执行,
确保所有类型的事件(消息、通知等)都能被统一过滤。
"""
message_type = raw.get("message_type")
message_id = str(raw.get("message_id", ""))
message_time = time.time()
# 黑白名单过滤
if not self._should_process_message(raw):
logger.debug(f"消息被黑白名单过滤丢弃: message_id={message_id}")
return None
msg_builder = MessageBuilder()
# 构造用户信息