From 37a0725d99b4f538ba76dfd21082c6aa091f4f0f Mon Sep 17 00:00:00 2001 From: Windpicker-owo <3431391539@qq.com> Date: Sat, 1 Nov 2025 15:39:43 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E6=B8=85=E7=90=86=E6=97=A7=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E5=AE=9E=E7=8E=B0=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 删除old/目录下的旧实现文件 - 删除sqlalchemy_models.py.bak备份文件 - 完成数据库重构代码清理工作 --- src/common/database/old/database.py | 109 --- src/common/database/old/db_batch_scheduler.py | 462 --------- src/common/database/old/db_migration.py | 140 --- .../database/old/sqlalchemy_database_api.py | 426 --------- src/common/database/old/sqlalchemy_init.py | 124 --- src/common/database/old/sqlalchemy_models.py | 892 ------------------ src/common/database/sqlalchemy_models.py.bak | 872 ----------------- 7 files changed, 3025 deletions(-) delete mode 100644 src/common/database/old/database.py delete mode 100644 src/common/database/old/db_batch_scheduler.py delete mode 100644 src/common/database/old/db_migration.py delete mode 100644 src/common/database/old/sqlalchemy_database_api.py delete mode 100644 src/common/database/old/sqlalchemy_init.py delete mode 100644 src/common/database/old/sqlalchemy_models.py delete mode 100644 src/common/database/sqlalchemy_models.py.bak diff --git a/src/common/database/old/database.py b/src/common/database/old/database.py deleted file mode 100644 index 681304f02..000000000 --- a/src/common/database/old/database.py +++ /dev/null @@ -1,109 +0,0 @@ -import os - -from rich.traceback import install - -from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool - -# 数据库批量调度器和连接池 -from src.common.database.db_batch_scheduler import get_db_batch_scheduler - -# SQLAlchemy相关导入 -from src.common.database.sqlalchemy_init import initialize_database_compat -from src.common.database.sqlalchemy_models import get_engine -from src.common.logger import get_logger - -install(extra_lines=3) - -_sql_engine = None - -logger = get_logger("database") - - -# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy -class DatabaseProxy: - """数据库代理类""" - - def __init__(self): - self._engine = None - self._session = None - - @staticmethod - async def initialize(*args, **kwargs): - """初始化数据库连接""" - result = await initialize_database_compat() - - # 启动数据库优化系统 - try: - # 启动数据库批量调度器 - batch_scheduler = get_db_batch_scheduler() - await batch_scheduler.start() - logger.info("🚀 数据库批量调度器启动成功") - - # 启动连接池管理器 - await start_connection_pool() - logger.info("🚀 连接池管理器启动成功") - except Exception as e: - logger.error(f"启动数据库优化系统失败: {e}") - - return result - - -# 创建全局数据库代理实例 -db = DatabaseProxy() - - -async def initialize_sql_database(database_config): - """ - 根据配置初始化SQL数据库连接(SQLAlchemy版本) - - Args: - database_config: DatabaseConfig对象 - """ - global _sql_engine - - try: - logger.info("使用SQLAlchemy初始化SQL数据库...") - - # 记录数据库配置信息 - if database_config.database_type == "mysql": - connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}" - logger.info("MySQL数据库连接配置:") - logger.info(f" 连接信息: {connection_info}") - logger.info(f" 字符集: {database_config.mysql_charset}") - else: - ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - if not os.path.isabs(database_config.sqlite_path): - db_path = os.path.join(ROOT_PATH, database_config.sqlite_path) - else: - db_path = database_config.sqlite_path - logger.info("SQLite数据库连接配置:") - logger.info(f" 数据库文件: {db_path}") - - # 使用SQLAlchemy初始化 - success = await initialize_database_compat() - if success: - _sql_engine = await get_engine() - logger.info("SQLAlchemy数据库初始化成功") - else: - logger.error("SQLAlchemy数据库初始化失败") - - return _sql_engine - - except Exception as e: - logger.error(f"初始化SQL数据库失败: {e}") - return None - - -async def stop_database(): - """停止数据库相关服务""" - try: - # 停止连接池管理器 - await stop_connection_pool() - logger.info("🛑 连接池管理器已停止") - - # 停止数据库批量调度器 - batch_scheduler = get_db_batch_scheduler() - await batch_scheduler.stop() - logger.info("🛑 数据库批量调度器已停止") - except Exception as e: - logger.error(f"停止数据库优化系统时出错: {e}") diff --git a/src/common/database/old/db_batch_scheduler.py b/src/common/database/old/db_batch_scheduler.py deleted file mode 100644 index a09f7fb84..000000000 --- a/src/common/database/old/db_batch_scheduler.py +++ /dev/null @@ -1,462 +0,0 @@ -""" -数据库批量调度器 -实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争 -""" - -import asyncio -import time -from collections import defaultdict, deque -from collections.abc import Callable -from contextlib import asynccontextmanager -from dataclasses import dataclass -from typing import Any, TypeVar - -from sqlalchemy import delete, insert, select, update - -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.logger import get_logger - -logger = get_logger("db_batch_scheduler") - -T = TypeVar("T") - - -@dataclass -class BatchOperation: - """批量操作基础类""" - - operation_type: str # 'select', 'insert', 'update', 'delete' - model_class: Any - conditions: dict[str, Any] - data: dict[str, Any] | None = None - callback: Callable | None = None - future: asyncio.Future | None = None - timestamp: float = 0.0 - - def __post_init__(self): - if self.timestamp == 0.0: - self.timestamp = time.time() - - -@dataclass -class BatchResult: - """批量操作结果""" - - success: bool - data: Any = None - error: str | None = None - - -class DatabaseBatchScheduler: - """数据库批量调度器""" - - def __init__( - self, - batch_size: int = 50, - max_wait_time: float = 0.1, # 100ms - max_queue_size: int = 1000, - ): - self.batch_size = batch_size - self.max_wait_time = max_wait_time - self.max_queue_size = max_queue_size - - # 操作队列,按操作类型和模型分类 - self.operation_queues: dict[str, deque] = defaultdict(deque) - - # 调度控制 - self._scheduler_task: asyncio.Task | None = None - self._is_running = False - self._lock = asyncio.Lock() - - # 统计信息 - self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0} - - # 简单的结果缓存(用于频繁的查询) - self._result_cache: dict[str, tuple[Any, float]] = {} - self._cache_ttl = 5.0 # 5秒缓存 - - async def start(self): - """启动调度器""" - if self._is_running: - return - - self._is_running = True - self._scheduler_task = asyncio.create_task(self._scheduler_loop()) - logger.info("数据库批量调度器已启动") - - async def stop(self): - """停止调度器""" - if not self._is_running: - return - - self._is_running = False - if self._scheduler_task: - self._scheduler_task.cancel() - try: - await self._scheduler_task - except asyncio.CancelledError: - pass - - # 处理剩余的操作 - await self._flush_all_queues() - logger.info("数据库批量调度器已停止") - - def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str: - """生成缓存键""" - # 简单的缓存键生成,实际可以根据需要优化 - key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))] - return "|".join(key_parts) - - def _get_from_cache(self, cache_key: str) -> Any | None: - """从缓存获取结果""" - if cache_key in self._result_cache: - result, timestamp = self._result_cache[cache_key] - if time.time() - timestamp < self._cache_ttl: - self.stats["cache_hits"] += 1 - return result - else: - # 清理过期缓存 - del self._result_cache[cache_key] - return None - - def _set_cache(self, cache_key: str, result: Any): - """设置缓存""" - self._result_cache[cache_key] = (result, time.time()) - - async def add_operation(self, operation: BatchOperation) -> asyncio.Future: - """添加操作到队列""" - # 检查是否可以立即返回缓存结果 - if operation.operation_type == "select": - cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions) - cached_result = self._get_from_cache(cache_key) - if cached_result is not None: - if operation.callback: - operation.callback(cached_result) - future = asyncio.get_event_loop().create_future() - future.set_result(cached_result) - return future - - # 创建future用于返回结果 - future = asyncio.get_event_loop().create_future() - operation.future = future - - # 添加到队列 - queue_key = f"{operation.operation_type}_{operation.model_class.__name__}" - - async with self._lock: - if len(self.operation_queues[queue_key]) >= self.max_queue_size: - # 队列满了,直接执行 - await self._execute_operations([operation]) - else: - self.operation_queues[queue_key].append(operation) - self.stats["total_operations"] += 1 - - return future - - async def _scheduler_loop(self): - """调度器主循环""" - while self._is_running: - try: - await asyncio.sleep(self.max_wait_time) - await self._flush_all_queues() - except asyncio.CancelledError: - break - except Exception as e: - logger.error(f"调度器循环异常: {e}", exc_info=True) - - async def _flush_all_queues(self): - """刷新所有队列""" - async with self._lock: - if not any(self.operation_queues.values()): - return - - # 复制队列内容,避免长时间占用锁 - queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()} - # 清空原队列 - for queue in self.operation_queues.values(): - queue.clear() - - # 批量执行各队列的操作 - for operations in queues_copy.values(): - if operations: - await self._execute_operations(list(operations)) - - async def _execute_operations(self, operations: list[BatchOperation]): - """执行批量操作""" - if not operations: - return - - start_time = time.time() - - try: - # 按操作类型分组 - op_groups = defaultdict(list) - for op in operations: - op_groups[op.operation_type].append(op) - - # 为每种操作类型创建批量执行任务 - tasks = [] - for op_type, ops in op_groups.items(): - if op_type == "select": - tasks.append(self._execute_select_batch(ops)) - elif op_type == "insert": - tasks.append(self._execute_insert_batch(ops)) - elif op_type == "update": - tasks.append(self._execute_update_batch(ops)) - elif op_type == "delete": - tasks.append(self._execute_delete_batch(ops)) - - # 并发执行所有操作 - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 处理结果 - for i, result in enumerate(results): - operation = operations[i] - if isinstance(result, Exception): - if operation.future and not operation.future.done(): - operation.future.set_exception(result) - else: - if operation.callback: - try: - operation.callback(result) - except Exception as e: - logger.warning(f"操作回调执行失败: {e}") - - if operation.future and not operation.future.done(): - operation.future.set_result(result) - - # 缓存查询结果 - if operation.operation_type == "select": - cache_key = self._generate_cache_key( - operation.operation_type, operation.model_class, operation.conditions - ) - self._set_cache(cache_key, result) - - self.stats["batched_operations"] += len(operations) - - except Exception as e: - logger.error(f"批量操作执行失败: {e}", exc_info="") - # 设置所有future的异常状态 - for operation in operations: - if operation.future and not operation.future.done(): - operation.future.set_exception(e) - finally: - self.stats["execution_time"] += time.time() - start_time - - async def _execute_select_batch(self, operations: list[BatchOperation]): - """批量执行查询操作""" - # 合并相似的查询条件 - merged_conditions = self._merge_select_conditions(operations) - - async with get_db_session() as session: - results = [] - for conditions, ops in merged_conditions.items(): - try: - # 构建查询 - query = select(ops[0].model_class) - for field_name, value in conditions.items(): - model_attr = getattr(ops[0].model_class, field_name) - if isinstance(value, list | tuple | set): - query = query.where(model_attr.in_(value)) - else: - query = query.where(model_attr == value) - - # 执行查询 - result = await session.execute(query) - data = result.scalars().all() - - # 分发结果到各个操作 - for op in ops: - if len(conditions) == 1 and len(ops) == 1: - # 单个查询,直接返回所有结果 - op_result = data - else: - # 需要根据条件过滤结果 - op_result = [ - item - for item in data - if all(getattr(item, k) == v for k, v in op.conditions.items() if hasattr(item, k)) - ] - results.append(op_result) - - except Exception as e: - logger.error(f"批量查询失败: {e}", exc_info=True) - results.append([]) - - return results if len(results) > 1 else results[0] if results else [] - - async def _execute_insert_batch(self, operations: list[BatchOperation]): - """批量执行插入操作""" - async with get_db_session() as session: - try: - # 收集所有要插入的数据 - all_data = [op.data for op in operations if op.data] - if not all_data: - return [] - - # 批量插入 - stmt = insert(operations[0].model_class).values(all_data) - result = await session.execute(stmt) - await session.commit() - - return [result.rowcount] * len(operations) - - except Exception as e: - await session.rollback() - logger.error(f"批量插入失败: {e}", exc_info=True) - return [0] * len(operations) - - async def _execute_update_batch(self, operations: list[BatchOperation]): - """批量执行更新操作""" - async with get_db_session() as session: - try: - results = [] - for op in operations: - if not op.data or not op.conditions: - results.append(0) - continue - - stmt = update(op.model_class) - for field_name, value in op.conditions.items(): - model_attr = getattr(op.model_class, field_name) - if isinstance(value, list | tuple | set): - stmt = stmt.where(model_attr.in_(value)) - else: - stmt = stmt.where(model_attr == value) - - stmt = stmt.values(**op.data) - result = await session.execute(stmt) - results.append(result.rowcount) - - await session.commit() - return results - - except Exception as e: - await session.rollback() - logger.error(f"批量更新失败: {e}", exc_info=True) - return [0] * len(operations) - - async def _execute_delete_batch(self, operations: list[BatchOperation]): - """批量执行删除操作""" - async with get_db_session() as session: - try: - results = [] - for op in operations: - if not op.conditions: - results.append(0) - continue - - stmt = delete(op.model_class) - for field_name, value in op.conditions.items(): - model_attr = getattr(op.model_class, field_name) - if isinstance(value, list | tuple | set): - stmt = stmt.where(model_attr.in_(value)) - else: - stmt = stmt.where(model_attr == value) - - result = await session.execute(stmt) - results.append(result.rowcount) - - await session.commit() - return results - - except Exception as e: - await session.rollback() - logger.error(f"批量删除失败: {e}", exc_info=True) - return [0] * len(operations) - - def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]: - """合并相似的查询条件""" - merged = {} - - for op in operations: - # 生成条件键 - condition_key = tuple(sorted(op.conditions.keys())) - - if condition_key not in merged: - merged[condition_key] = {} - - # 尝试合并相同字段的值 - for field_name, value in op.conditions.items(): - if field_name not in merged[condition_key]: - merged[condition_key][field_name] = [] - - if isinstance(value, list | tuple | set): - merged[condition_key][field_name].extend(value) - else: - merged[condition_key][field_name].append(value) - - # 记录操作 - if condition_key not in merged: - merged[condition_key] = {"_operations": []} - if "_operations" not in merged[condition_key]: - merged[condition_key]["_operations"] = [] - merged[condition_key]["_operations"].append(op) - - # 去重并构建最终条件 - final_merged = {} - for condition_key, conditions in merged.items(): - operations = conditions.pop("_operations") - - # 去重 - for field_name, values in conditions.items(): - conditions[field_name] = list(set(values)) - - final_merged[condition_key] = operations - - return final_merged - - def get_stats(self) -> dict[str, Any]: - """获取统计信息""" - return { - **self.stats, - "cache_size": len(self._result_cache), - "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()}, - "is_running": self._is_running, - } - - -# 全局数据库批量调度器实例 -db_batch_scheduler = DatabaseBatchScheduler() - - -@asynccontextmanager -async def get_batch_session(): - """获取批量会话上下文管理器""" - if not db_batch_scheduler._is_running: - await db_batch_scheduler.start() - - try: - yield db_batch_scheduler - finally: - pass - - -# 便捷函数 -async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any: - """批量查询""" - operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions) - return await db_batch_scheduler.add_operation(operation) - - -async def batch_insert(model_class: Any, data: dict[str, Any]) -> int: - """批量插入""" - operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data) - return await db_batch_scheduler.add_operation(operation) - - -async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int: - """批量更新""" - operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data) - return await db_batch_scheduler.add_operation(operation) - - -async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int: - """批量删除""" - operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions) - return await db_batch_scheduler.add_operation(operation) - - -def get_db_batch_scheduler() -> DatabaseBatchScheduler: - """获取数据库批量调度器实例""" - return db_batch_scheduler diff --git a/src/common/database/old/db_migration.py b/src/common/database/old/db_migration.py deleted file mode 100644 index d699964ac..000000000 --- a/src/common/database/old/db_migration.py +++ /dev/null @@ -1,140 +0,0 @@ -# mmc/src/common/database/db_migration.py - -from sqlalchemy import inspect -from sqlalchemy.sql import text - -from src.common.database.sqlalchemy_models import Base, get_engine -from src.common.logger import get_logger - -logger = get_logger("db_migration") - - -async def check_and_migrate_database(existing_engine=None): - """ - 异步检查数据库结构并自动迁移。 - - 自动创建不存在的表。 - - 自动为现有表添加缺失的列。 - - 自动为现有表创建缺失的索引。 - - Args: - existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。 - """ - logger.info("正在检查数据库结构并执行自动迁移...") - engine = existing_engine if existing_engine is not None else await get_engine() - - async with engine.connect() as connection: - # 在同步上下文中运行inspector操作 - def get_inspector(sync_conn): - return inspect(sync_conn) - - inspector = await connection.run_sync(get_inspector) - - # 在同步lambda中传递inspector - db_table_names = await connection.run_sync(lambda conn: set(inspector.get_table_names())) - - # 1. 首先处理表的创建 - tables_to_create = [] - for table_name, table in Base.metadata.tables.items(): - if table_name not in db_table_names: - tables_to_create.append(table) - - if tables_to_create: - logger.info(f"发现 {len(tables_to_create)} 个不存在的表,正在创建...") - try: - # 一次性创建所有缺失的表 - await connection.run_sync( - lambda sync_conn: Base.metadata.create_all(sync_conn, tables=tables_to_create) - ) - for table in tables_to_create: - logger.info(f"表 '{table.name}' 创建成功。") - db_table_names.add(table.name) # 将新创建的表添加到集合中 - except Exception as e: - logger.error(f"创建表时失败: {e}", exc_info=True) - - # 2. 然后处理现有表的列和索引的添加 - for table_name, table in Base.metadata.tables.items(): - if table_name not in db_table_names: - logger.warning(f"跳过检查表 '{table_name}',因为它在创建步骤中可能已失败。") - continue - - logger.debug(f"正在检查表 '{table_name}' 的列和索引...") - - try: - # 检查并添加缺失的列 - db_columns = await connection.run_sync( - lambda conn: {col["name"] for col in inspector.get_columns(table_name)} - ) - model_columns = {col.name for col in table.c} - missing_columns = model_columns - db_columns - - if missing_columns: - logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}") - - def add_columns_sync(conn): - dialect = conn.dialect - compiler = dialect.ddl_compiler(dialect, None) - - for column_name in missing_columns: - column = table.c[column_name] - column_type = compiler.get_column_specification(column) - sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}" - - if column.default: - # 手动处理不同方言的默认值 - default_arg = column.default.arg - if dialect.name == "sqlite" and isinstance(default_arg, bool): - # SQLite 将布尔值存储为 0 或 1 - default_value = "1" if default_arg else "0" - elif hasattr(compiler, "render_literal_value"): - try: - # 尝试使用 render_literal_value - default_value = compiler.render_literal_value(default_arg, column.type) - except AttributeError: - # 如果失败,则回退到简单的字符串转换 - default_value = ( - f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) - ) - else: - # 对于没有 render_literal_value 的旧版或特定方言 - default_value = ( - f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg) - ) - - sql += f" DEFAULT {default_value}" - - if not column.nullable: - sql += " NOT NULL" - - conn.execute(text(sql)) - logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。") - - await connection.run_sync(add_columns_sync) - else: - logger.info(f"表 '{table_name}' 的列结构一致。") - - # 检查并创建缺失的索引 - db_indexes = await connection.run_sync( - lambda conn: {idx["name"] for idx in inspector.get_indexes(table_name)} - ) - model_indexes = {idx.name for idx in table.indexes} - missing_indexes = model_indexes - db_indexes - - if missing_indexes: - logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}") - - def add_indexes_sync(conn): - for index_name in missing_indexes: - index_obj = next((idx for idx in table.indexes if idx.name == index_name), None) - if index_obj is not None: - index_obj.create(conn) - logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。") - - await connection.run_sync(add_indexes_sync) - else: - logger.debug(f"表 '{table_name}' 的索引一致。") - - except Exception as e: - logger.error(f"在处理表 '{table_name}' 时发生意外错误: {e}", exc_info=True) - continue - - logger.info("数据库结构检查与自动迁移完成。") diff --git a/src/common/database/old/sqlalchemy_database_api.py b/src/common/database/old/sqlalchemy_database_api.py deleted file mode 100644 index 38c972236..000000000 --- a/src/common/database/old/sqlalchemy_database_api.py +++ /dev/null @@ -1,426 +0,0 @@ -"""SQLAlchemy数据库API模块 - -提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题 -支持自动重连、连接池管理和更好的错误处理 -""" - -import time -import traceback -from typing import Any - -from sqlalchemy import and_, asc, desc, func, select -from sqlalchemy.exc import SQLAlchemyError - -from src.common.database.sqlalchemy_models import ( - ActionRecords, - CacheEntries, - ChatStreams, - Emoji, - Expression, - GraphEdges, - GraphNodes, - ImageDescriptions, - Images, - LLMUsage, - MaiZoneScheduleStatus, - Memory, - Messages, - OnlineTime, - PersonInfo, - Schedule, - ThinkingLog, - UserRelationships, - get_db_session, -) -from src.common.logger import get_logger - -logger = get_logger("sqlalchemy_database_api") - -# 模型映射表,用于通过名称获取模型类 -MODEL_MAPPING = { - "Messages": Messages, - "ActionRecords": ActionRecords, - "PersonInfo": PersonInfo, - "ChatStreams": ChatStreams, - "LLMUsage": LLMUsage, - "Emoji": Emoji, - "Images": Images, - "ImageDescriptions": ImageDescriptions, - "OnlineTime": OnlineTime, - "Memory": Memory, - "Expression": Expression, - "ThinkingLog": ThinkingLog, - "GraphNodes": GraphNodes, - "GraphEdges": GraphEdges, - "Schedule": Schedule, - "MaiZoneScheduleStatus": MaiZoneScheduleStatus, - "CacheEntries": CacheEntries, - "UserRelationships": UserRelationships, -} - - -async def build_filters(model_class, filters: dict[str, Any]): - """构建查询过滤条件""" - conditions = [] - - for field_name, value in filters.items(): - if not hasattr(model_class, field_name): - logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'") - continue - - field = getattr(model_class, field_name) - - if isinstance(value, dict): - # 处理 MongoDB 风格的操作符 - for op, op_value in value.items(): - if op == "$gt": - conditions.append(field > op_value) - elif op == "$lt": - conditions.append(field < op_value) - elif op == "$gte": - conditions.append(field >= op_value) - elif op == "$lte": - conditions.append(field <= op_value) - elif op == "$ne": - conditions.append(field != op_value) - elif op == "$in": - conditions.append(field.in_(op_value)) - elif op == "$nin": - conditions.append(~field.in_(op_value)) - else: - logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')") - else: - # 直接相等比较 - conditions.append(field == value) - - return conditions - - -async def db_query( - model_class, - data: dict[str, Any] | None = None, - query_type: str | None = "get", - filters: dict[str, Any] | None = None, - limit: int | None = None, - order_by: list[str] | None = None, - single_result: bool | None = False, -) -> list[dict[str, Any]] | dict[str, Any] | None: - """执行异步数据库查询操作 - - Args: - model_class: SQLAlchemy模型类 - data: 用于创建或更新的数据字典 - query_type: 查询类型 ("get", "create", "update", "delete", "count") - filters: 过滤条件字典 - limit: 限制结果数量 - order_by: 排序字段,前缀'-'表示降序 - single_result: 是否只返回单个结果 - - Returns: - 根据查询类型返回相应结果 - """ - try: - if query_type not in ["get", "create", "update", "delete", "count"]: - raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'") - - async with get_db_session() as session: - if not session: - logger.error("[SQLAlchemy] 无法获取数据库会话") - return None if single_result else [] - - if query_type == "get": - query = select(model_class) - - # 应用过滤条件 - if filters: - conditions = await build_filters(model_class, filters) - if conditions: - query = query.where(and_(*conditions)) - - # 应用排序 - if order_by: - for field_name in order_by: - if field_name.startswith("-"): - field_name = field_name[1:] - if hasattr(model_class, field_name): - query = query.order_by(desc(getattr(model_class, field_name))) - else: - if hasattr(model_class, field_name): - query = query.order_by(asc(getattr(model_class, field_name))) - - # 应用限制 - if limit and limit > 0: - query = query.limit(limit) - - # 执行查询 - result = await session.execute(query) - results = result.scalars().all() - - # 转换为字典格式 - result_dicts = [] - for result_obj in results: - result_dict = {} - for column in result_obj.__table__.columns: - result_dict[column.name] = getattr(result_obj, column.name) - result_dicts.append(result_dict) - - if single_result: - return result_dicts[0] if result_dicts else None - return result_dicts - - elif query_type == "create": - if not data: - raise ValueError("创建记录需要提供data参数") - - # 创建新记录 - new_record = model_class(**data) - session.add(new_record) - await session.flush() # 获取自动生成的ID - - # 转换为字典格式返回 - result_dict = {} - for column in new_record.__table__.columns: - result_dict[column.name] = getattr(new_record, column.name) - return result_dict - - elif query_type == "update": - if not data: - raise ValueError("更新记录需要提供data参数") - - query = select(model_class) - - # 应用过滤条件 - if filters: - conditions = await build_filters(model_class, filters) - if conditions: - query = query.where(and_(*conditions)) - - # 首先获取要更新的记录 - result = await session.execute(query) - records_to_update = result.scalars().all() - - # 更新每个记录 - affected_rows = 0 - for record in records_to_update: - for field, value in data.items(): - if hasattr(record, field): - setattr(record, field, value) - affected_rows += 1 - - return affected_rows - - elif query_type == "delete": - query = select(model_class) - - # 应用过滤条件 - if filters: - conditions = await build_filters(model_class, filters) - if conditions: - query = query.where(and_(*conditions)) - - # 首先获取要删除的记录 - result = await session.execute(query) - records_to_delete = result.scalars().all() - - # 删除记录 - affected_rows = 0 - for record in records_to_delete: - await session.delete(record) - affected_rows += 1 - - return affected_rows - - elif query_type == "count": - query = select(func.count(model_class.id)) - - # 应用过滤条件 - if filters: - conditions = await build_filters(model_class, filters) - if conditions: - query = query.where(and_(*conditions)) - - result = await session.execute(query) - return result.scalar() - - except SQLAlchemyError as e: - logger.error(f"[SQLAlchemy] 数据库操作出错: {e}") - traceback.print_exc() - - # 根据查询类型返回合适的默认值 - if query_type == "get": - return None if single_result else [] - elif query_type in ["create", "update", "delete", "count"]: - return None - return None - - except Exception as e: - logger.error(f"[SQLAlchemy] 意外错误: {e}") - traceback.print_exc() - - if query_type == "get": - return None if single_result else [] - return None - - -async def db_save( - model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None -) -> dict[str, Any] | None: - """异步保存数据到数据库(创建或更新) - - Args: - model_class: SQLAlchemy模型类 - data: 要保存的数据字典 - key_field: 用于查找现有记录的字段名 - key_value: 用于查找现有记录的字段值 - - Returns: - 保存后的记录数据或None - """ - try: - async with get_db_session() as session: - if not session: - logger.error("[SQLAlchemy] 无法获取数据库会话") - return None - # 如果提供了key_field和key_value,尝试更新现有记录 - if key_field and key_value is not None: - if hasattr(model_class, key_field): - query = select(model_class).where(getattr(model_class, key_field) == key_value) - result = await session.execute(query) - existing_record = result.scalars().first() - - if existing_record: - # 更新现有记录 - for field, value in data.items(): - if hasattr(existing_record, field): - setattr(existing_record, field, value) - - await session.flush() - - # 转换为字典格式返回 - result_dict = {} - for column in existing_record.__table__.columns: - result_dict[column.name] = getattr(existing_record, column.name) - return result_dict - - # 创建新记录 - new_record = model_class(**data) - session.add(new_record) - await session.flush() - - # 转换为字典格式返回 - result_dict = {} - for column in new_record.__table__.columns: - result_dict[column.name] = getattr(new_record, column.name) - return result_dict - - except SQLAlchemyError as e: - logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}") - traceback.print_exc() - return None - except Exception as e: - logger.error(f"[SQLAlchemy] 保存时意外错误: {e}") - traceback.print_exc() - return None - - -async def db_get( - model_class, - filters: dict[str, Any] | None = None, - limit: int | None = None, - order_by: str | None = None, - single_result: bool | None = False, -) -> list[dict[str, Any]] | dict[str, Any] | None: - """异步从数据库获取记录 - - Args: - model_class: SQLAlchemy模型类 - filters: 过滤条件 - limit: 结果数量限制 - order_by: 排序字段,前缀'-'表示降序 - single_result: 是否只返回单个结果 - - Returns: - 记录数据或None - """ - order_by_list = [order_by] if order_by else None - return await db_query( - model_class=model_class, - query_type="get", - filters=filters, - limit=limit, - order_by=order_by_list, - single_result=single_result, - ) - - -async def store_action_info( - chat_stream=None, - action_build_into_prompt: bool = False, - action_prompt_display: str = "", - action_done: bool = True, - thinking_id: str = "", - action_data: dict | None = None, - action_name: str = "", -) -> dict[str, Any] | None: - """异步存储动作信息到数据库 - - Args: - chat_stream: 聊天流对象 - action_build_into_prompt: 是否将此动作构建到提示中 - action_prompt_display: 动作的提示显示文本 - action_done: 动作是否完成 - thinking_id: 关联的思考ID - action_data: 动作数据字典 - action_name: 动作名称 - - Returns: - 保存的记录数据或None - """ - try: - import orjson - - # 构建动作记录数据 - record_data = { - "action_id": thinking_id or str(int(time.time() * 1000000)), - "time": time.time(), - "action_name": action_name, - "action_data": orjson.dumps(action_data or {}).decode("utf-8"), - "action_done": action_done, - "action_build_into_prompt": action_build_into_prompt, - "action_prompt_display": action_prompt_display, - } - - # 从chat_stream获取聊天信息 - if chat_stream: - record_data.update( - { - "chat_id": getattr(chat_stream, "stream_id", ""), - "chat_info_stream_id": getattr(chat_stream, "stream_id", ""), - "chat_info_platform": getattr(chat_stream, "platform", ""), - } - ) - else: - record_data.update( - { - "chat_id": "", - "chat_info_stream_id": "", - "chat_info_platform": "", - } - ) - - # 保存记录 - saved_record = await db_save( - ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"] - ) - - if saved_record: - logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})") - else: - logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}") - - return saved_record - - except Exception as e: - logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}") - traceback.print_exc() - return None diff --git a/src/common/database/old/sqlalchemy_init.py b/src/common/database/old/sqlalchemy_init.py deleted file mode 100644 index daf61f3a5..000000000 --- a/src/common/database/old/sqlalchemy_init.py +++ /dev/null @@ -1,124 +0,0 @@ -"""SQLAlchemy数据库初始化模块 - -替换Peewee的数据库初始化逻辑 -提供统一的异步数据库初始化接口 -""" - -from sqlalchemy.exc import SQLAlchemyError - -from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database -from src.common.logger import get_logger - -logger = get_logger("sqlalchemy_init") - - -async def initialize_sqlalchemy_database() -> bool: - """ - 初始化SQLAlchemy异步数据库 - 创建所有表结构 - - Returns: - bool: 初始化是否成功 - """ - try: - logger.info("开始初始化SQLAlchemy异步数据库...") - - # 初始化数据库引擎和会话 - engine, session_local = await initialize_database() - - if engine is None: - logger.error("数据库引擎初始化失败") - return False - - logger.info("SQLAlchemy异步数据库初始化成功") - return True - - except SQLAlchemyError as e: - logger.error(f"SQLAlchemy数据库初始化失败: {e}") - return False - except Exception as e: - logger.error(f"数据库初始化过程中发生未知错误: {e}") - return False - - -async def create_all_tables() -> bool: - """ - 异步创建所有数据库表 - - Returns: - bool: 创建是否成功 - """ - try: - logger.info("开始创建数据库表...") - - engine = await get_engine() - if engine is None: - logger.error("无法获取数据库引擎") - return False - - # 异步创建所有表 - async with engine.begin() as conn: - await conn.run_sync(Base.metadata.create_all) - - logger.info("数据库表创建成功") - return True - - except SQLAlchemyError as e: - logger.error(f"创建数据库表失败: {e}") - return False - except Exception as e: - logger.error(f"创建数据库表过程中发生未知错误: {e}") - return False - - -async def get_database_info() -> dict | None: - """ - 异步获取数据库信息 - - Returns: - dict: 数据库信息字典,包含引擎信息等 - """ - try: - engine = await get_engine() - if engine is None: - return None - - info = { - "engine_name": engine.name, - "driver": engine.driver, - "url": str(engine.url).replace(engine.url.password or "", "***"), # 隐藏密码 - "pool_size": getattr(engine.pool, "size", None), - "max_overflow": getattr(engine.pool, "max_overflow", None), - } - - return info - - except Exception as e: - logger.error(f"获取数据库信息失败: {e}") - return None - - -_database_initialized = False - - -async def initialize_database_compat() -> bool: - """ - 兼容性异步数据库初始化函数 - 用于替换原有的Peewee初始化代码 - - Returns: - bool: 初始化是否成功 - """ - global _database_initialized - - if _database_initialized: - return True - - success = await initialize_sqlalchemy_database() - if success: - success = await create_all_tables() - - if success: - _database_initialized = True - - return success diff --git a/src/common/database/old/sqlalchemy_models.py b/src/common/database/old/sqlalchemy_models.py deleted file mode 100644 index 287f0fc29..000000000 --- a/src/common/database/old/sqlalchemy_models.py +++ /dev/null @@ -1,892 +0,0 @@ -"""SQLAlchemy数据库模型定义 - -替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 - -说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到 -SQLAlchemy 2.0 推荐的带类型注解的声明式风格: - - field_name: Mapped[PyType] = mapped_column(Type, ...) - -这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。 -当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。 -""" - -import datetime -import os -import time -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from typing import Any - -from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text, text -from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Mapped, mapped_column - -from src.common.database.connection_pool_manager import get_connection_pool_manager -from src.common.logger import get_logger - -logger = get_logger("sqlalchemy_models") - -# 创建基类 -Base = declarative_base() - -# 全局异步引擎与会话工厂占位(延迟初始化) -_engine: AsyncEngine | None = None -_SessionLocal: async_sessionmaker[AsyncSession] | None = None - - -async def enable_sqlite_wal_mode(engine): - """为 SQLite 启用 WAL 模式以提高并发性能""" - try: - async with engine.begin() as conn: - # 启用 WAL 模式 - await conn.execute(text("PRAGMA journal_mode = WAL")) - # 设置适中的同步级别,平衡性能和安全性 - await conn.execute(text("PRAGMA synchronous = NORMAL")) - # 启用外键约束 - await conn.execute(text("PRAGMA foreign_keys = ON")) - # 设置 busy_timeout,避免锁定错误 - await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒 - - logger.info("[SQLite] WAL 模式已启用,并发性能已优化") - except Exception as e: - logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置") - - -async def maintain_sqlite_database(): - """定期维护 SQLite 数据库性能""" - try: - engine, SessionLocal = await initialize_database() - if not engine: - return - - async with engine.begin() as conn: - # 检查并确保 WAL 模式仍然启用 - result = await conn.execute(text("PRAGMA journal_mode")) - journal_mode = result.scalar() - - if journal_mode != "wal": - await conn.execute(text("PRAGMA journal_mode = WAL")) - logger.info("[SQLite] WAL 模式已重新启用") - - # 优化数据库性能 - await conn.execute(text("PRAGMA synchronous = NORMAL")) - await conn.execute(text("PRAGMA busy_timeout = 60000")) - await conn.execute(text("PRAGMA foreign_keys = ON")) - - # 定期清理(可选,根据需要启用) - # await conn.execute(text("PRAGMA optimize")) - - logger.info("[SQLite] 数据库维护完成") - except Exception as e: - logger.warning(f"[SQLite] 数据库维护失败: {e}") - - -def get_sqlite_performance_config(): - """获取 SQLite 性能优化配置""" - return { - "journal_mode": "WAL", # 提高并发性能 - "synchronous": "NORMAL", # 平衡性能和安全性 - "busy_timeout": 60000, # 60秒超时 - "foreign_keys": "ON", # 启用外键约束 - "cache_size": -10000, # 10MB 缓存 - "temp_store": "MEMORY", # 临时存储使用内存 - "mmap_size": 268435456, # 256MB 内存映射 - } - - -# MySQL兼容的字段类型辅助函数 -def get_string_field(max_length=255, **kwargs): - """ - 根据数据库类型返回合适的字符串字段 - MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text - """ - from src.config.config import global_config - - if global_config.database.database_type == "mysql": - return String(max_length, **kwargs) - else: - return Text(**kwargs) - - -class ChatStreams(Base): - """聊天流模型""" - - __tablename__ = "chat_streams" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - stream_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, unique=True, index=True) - create_time: Mapped[float] = mapped_column(Float, nullable=False) - group_platform: Mapped[str | None] = mapped_column(Text, nullable=True) - group_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) - group_name: Mapped[str | None] = mapped_column(Text, nullable=True) - last_active_time: Mapped[float] = mapped_column(Float, nullable=False) - platform: Mapped[str] = mapped_column(Text, nullable=False) - user_platform: Mapped[str] = mapped_column(Text, nullable=False) - user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - user_nickname: Mapped[str] = mapped_column(Text, nullable=False) - user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) - energy_value: Mapped[float | None] = mapped_column(Float, nullable=True, default=5.0) - sleep_pressure: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0) - focus_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) - # 动态兴趣度系统字段 - base_interest_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) - message_interest_total: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0) - message_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) - action_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) - reply_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) - last_interaction_time: Mapped[float | None] = mapped_column(Float, nullable=True, default=None) - consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) - # 消息打断系统字段 - interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) - # 聊天流印象字段 - stream_impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 对聊天流的主观印象描述 - stream_chat_style: Mapped[str | None] = mapped_column(Text, nullable=True) # 聊天流的总体风格 - stream_topic_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 话题关键词,逗号分隔 - stream_interest_score: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 对聊天流的兴趣程度(0-1) - - __table_args__ = ( - Index("idx_chatstreams_stream_id", "stream_id"), - Index("idx_chatstreams_user_id", "user_id"), - Index("idx_chatstreams_group_id", "group_id"), - ) - - -class LLMUsage(Base): - """LLM使用记录模型""" - - __tablename__ = "llm_usage" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - model_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - model_assign_name: Mapped[str] = mapped_column(get_string_field(100), index=True) - model_api_provider: Mapped[str] = mapped_column(get_string_field(100), index=True) - user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - request_type: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - endpoint: Mapped[str] = mapped_column(Text, nullable=False) - prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False) - completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False) - time_cost: Mapped[float | None] = mapped_column(Float, nullable=True) - total_tokens: Mapped[int] = mapped_column(Integer, nullable=False) - cost: Mapped[float] = mapped_column(Float, nullable=False) - status: Mapped[str] = mapped_column(Text, nullable=False) - timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True, default=datetime.datetime.now) - - __table_args__ = ( - Index("idx_llmusage_model_name", "model_name"), - Index("idx_llmusage_model_assign_name", "model_assign_name"), - Index("idx_llmusage_model_api_provider", "model_api_provider"), - Index("idx_llmusage_time_cost", "time_cost"), - Index("idx_llmusage_user_id", "user_id"), - Index("idx_llmusage_request_type", "request_type"), - Index("idx_llmusage_timestamp", "timestamp"), - ) - - -class Emoji(Base): - """表情包模型""" - - __tablename__ = "emoji" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - full_path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True) - format: Mapped[str] = mapped_column(Text, nullable=False) - emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - description: Mapped[str] = mapped_column(Text, nullable=False) - query_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - is_registered: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - is_banned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - emotion: Mapped[str | None] = mapped_column(Text, nullable=True) - record_time: Mapped[float] = mapped_column(Float, nullable=False) - register_time: Mapped[float | None] = mapped_column(Float, nullable=True) - usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - last_used_time: Mapped[float | None] = mapped_column(Float, nullable=True) - - __table_args__ = ( - Index("idx_emoji_full_path", "full_path"), - Index("idx_emoji_hash", "emoji_hash"), - ) - - -class Messages(Base): - """消息模型""" - - __tablename__ = "messages" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - message_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - time: Mapped[float] = mapped_column(Float, nullable=False) - chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - reply_to: Mapped[str | None] = mapped_column(Text, nullable=True) - interest_value: Mapped[float | None] = mapped_column(Float, nullable=True) - key_words: Mapped[str | None] = mapped_column(Text, nullable=True) - key_words_lite: Mapped[str | None] = mapped_column(Text, nullable=True) - is_mentioned: Mapped[bool | None] = mapped_column(Boolean, nullable=True) - - # 从 chat_info 扁平化而来的字段 - chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_user_platform: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_user_id: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_user_nickname: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_info_group_platform: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_info_group_id: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_info_group_name: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_info_create_time: Mapped[float] = mapped_column(Float, nullable=False) - chat_info_last_active_time: Mapped[float] = mapped_column(Float, nullable=False) - - # 从顶层 user_info 扁平化而来的字段 - user_platform: Mapped[str | None] = mapped_column(Text, nullable=True) - user_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) - user_nickname: Mapped[str | None] = mapped_column(Text, nullable=True) - user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) - - processed_plain_text: Mapped[str | None] = mapped_column(Text, nullable=True) - display_message: Mapped[str | None] = mapped_column(Text, nullable=True) - memorized_times: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - priority_mode: Mapped[str | None] = mapped_column(Text, nullable=True) - priority_info: Mapped[str | None] = mapped_column(Text, nullable=True) - additional_config: Mapped[str | None] = mapped_column(Text, nullable=True) - is_emoji: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - is_picid: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - is_command: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - is_notify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - is_public_notice: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - notice_type: Mapped[str | None] = mapped_column(String(50), nullable=True) - - # 兴趣度系统字段 - actions: Mapped[str | None] = mapped_column(Text, nullable=True) - should_reply: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False) - should_act: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False) - - __table_args__ = ( - Index("idx_messages_message_id", "message_id"), - Index("idx_messages_chat_id", "chat_id"), - Index("idx_messages_time", "time"), - Index("idx_messages_user_id", "user_id"), - Index("idx_messages_should_reply", "should_reply"), - Index("idx_messages_should_act", "should_act"), - ) - - -class ActionRecords(Base): - """动作记录模型""" - - __tablename__ = "action_records" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - action_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - time: Mapped[float] = mapped_column(Float, nullable=False) - action_name: Mapped[str] = mapped_column(Text, nullable=False) - action_data: Mapped[str] = mapped_column(Text, nullable=False) - action_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - action_build_into_prompt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - action_prompt_display: Mapped[str] = mapped_column(Text, nullable=False) - chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False) - chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False) - - __table_args__ = ( - Index("idx_actionrecords_action_id", "action_id"), - Index("idx_actionrecords_chat_id", "chat_id"), - Index("idx_actionrecords_time", "time"), - ) - - -class Images(Base): - """图像信息模型""" - - __tablename__ = "images" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - image_id: Mapped[str] = mapped_column(Text, nullable=False, default="") - emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - description: Mapped[str | None] = mapped_column(Text, nullable=True) - path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True) - count: Mapped[int] = mapped_column(Integer, nullable=False, default=1) - timestamp: Mapped[float] = mapped_column(Float, nullable=False) - type: Mapped[str] = mapped_column(Text, nullable=False) - vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - - __table_args__ = ( - Index("idx_images_emoji_hash", "emoji_hash"), - Index("idx_images_path", "path"), - ) - - -class ImageDescriptions(Base): - """图像描述信息模型""" - - __tablename__ = "image_descriptions" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - type: Mapped[str] = mapped_column(Text, nullable=False) - image_description_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - description: Mapped[str] = mapped_column(Text, nullable=False) - timestamp: Mapped[float] = mapped_column(Float, nullable=False) - - __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),) - - -class Videos(Base): - """视频信息模型""" - - __tablename__ = "videos" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - video_id: Mapped[str] = mapped_column(Text, nullable=False, default="") - video_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True, unique=True) - description: Mapped[str | None] = mapped_column(Text, nullable=True) - count: Mapped[int] = mapped_column(Integer, nullable=False, default=1) - timestamp: Mapped[float] = mapped_column(Float, nullable=False) - vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - - # 视频特有属性 - duration: Mapped[float | None] = mapped_column(Float, nullable=True) - frame_count: Mapped[int | None] = mapped_column(Integer, nullable=True) - fps: Mapped[float | None] = mapped_column(Float, nullable=True) - resolution: Mapped[str | None] = mapped_column(Text, nullable=True) - file_size: Mapped[int | None] = mapped_column(Integer, nullable=True) - - __table_args__ = ( - Index("idx_videos_video_hash", "video_hash"), - Index("idx_videos_timestamp", "timestamp"), - ) - - -class OnlineTime(Base): - """在线时长记录模型""" - - __tablename__ = "online_time" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - timestamp: Mapped[str] = mapped_column(Text, nullable=False, default=str(datetime.datetime.now)) - duration: Mapped[int] = mapped_column(Integer, nullable=False) - start_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - end_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True) - - __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),) - - -class PersonInfo(Base): - """人物信息模型""" - - __tablename__ = "person_info" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - person_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True) - person_name: Mapped[str | None] = mapped_column(Text, nullable=True) - name_reason: Mapped[str | None] = mapped_column(Text, nullable=True) - platform: Mapped[str] = mapped_column(Text, nullable=False) - user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - nickname: Mapped[str | None] = mapped_column(Text, nullable=True) - impression: Mapped[str | None] = mapped_column(Text, nullable=True) - short_impression: Mapped[str | None] = mapped_column(Text, nullable=True) - points: Mapped[str | None] = mapped_column(Text, nullable=True) - forgotten_points: Mapped[str | None] = mapped_column(Text, nullable=True) - info_list: Mapped[str | None] = mapped_column(Text, nullable=True) - know_times: Mapped[float | None] = mapped_column(Float, nullable=True) - know_since: Mapped[float | None] = mapped_column(Float, nullable=True) - last_know: Mapped[float | None] = mapped_column(Float, nullable=True) - attitude: Mapped[int | None] = mapped_column(Integer, nullable=True, default=50) - - __table_args__ = ( - Index("idx_personinfo_person_id", "person_id"), - Index("idx_personinfo_user_id", "user_id"), - ) - - -class BotPersonalityInterests(Base): - """机器人人格兴趣标签模型""" - - __tablename__ = "bot_personality_interests" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - personality_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - personality_description: Mapped[str] = mapped_column(Text, nullable=False) - interest_tags: Mapped[str] = mapped_column(Text, nullable=False) - embedding_model: Mapped[str] = mapped_column(get_string_field(100), nullable=False, default="text-embedding-ada-002") - version: Mapped[int] = mapped_column(Integer, nullable=False, default=1) - last_updated: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, index=True) - - __table_args__ = ( - Index("idx_botpersonality_personality_id", "personality_id"), - Index("idx_botpersonality_version", "version"), - Index("idx_botpersonality_last_updated", "last_updated"), - ) - - -class Memory(Base): - """记忆模型""" - - __tablename__ = "memory" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - memory_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - chat_id: Mapped[str | None] = mapped_column(Text, nullable=True) - memory_text: Mapped[str | None] = mapped_column(Text, nullable=True) - keywords: Mapped[str | None] = mapped_column(Text, nullable=True) - create_time: Mapped[float | None] = mapped_column(Float, nullable=True) - last_view_time: Mapped[float | None] = mapped_column(Float, nullable=True) - - __table_args__ = (Index("idx_memory_memory_id", "memory_id"),) - - -class Expression(Base): - """表达风格模型""" - - __tablename__ = "expression" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - situation: Mapped[str] = mapped_column(Text, nullable=False) - style: Mapped[str] = mapped_column(Text, nullable=False) - count: Mapped[float] = mapped_column(Float, nullable=False) - last_active_time: Mapped[float] = mapped_column(Float, nullable=False) - chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - type: Mapped[str] = mapped_column(Text, nullable=False) - create_date: Mapped[float | None] = mapped_column(Float, nullable=True) - - __table_args__ = (Index("idx_expression_chat_id", "chat_id"),) - - -class ThinkingLog(Base): - """思考日志模型""" - - __tablename__ = "thinking_logs" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - trigger_text: Mapped[str | None] = mapped_column(Text, nullable=True) - response_text: Mapped[str | None] = mapped_column(Text, nullable=True) - trigger_info_json: Mapped[str | None] = mapped_column(Text, nullable=True) - response_info_json: Mapped[str | None] = mapped_column(Text, nullable=True) - timing_results_json: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_history_json: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_history_in_thinking_json: Mapped[str | None] = mapped_column(Text, nullable=True) - chat_history_after_response_json: Mapped[str | None] = mapped_column(Text, nullable=True) - heartflow_data_json: Mapped[str | None] = mapped_column(Text, nullable=True) - reasoning_data_json: Mapped[str | None] = mapped_column(Text, nullable=True) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - - __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),) - - -class GraphNodes(Base): - """记忆图节点模型""" - - __tablename__ = "graph_nodes" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - concept: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True) - memory_items: Mapped[str] = mapped_column(Text, nullable=False) - hash: Mapped[str] = mapped_column(Text, nullable=False) - weight: Mapped[float] = mapped_column(Float, nullable=False, default=1.0) - created_time: Mapped[float] = mapped_column(Float, nullable=False) - last_modified: Mapped[float] = mapped_column(Float, nullable=False) - - __table_args__ = (Index("idx_graphnodes_concept", "concept"),) - - -class GraphEdges(Base): - """记忆图边模型""" - - __tablename__ = "graph_edges" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - source: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) - target: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) - strength: Mapped[int] = mapped_column(Integer, nullable=False) - hash: Mapped[str] = mapped_column(Text, nullable=False) - created_time: Mapped[float] = mapped_column(Float, nullable=False) - last_modified: Mapped[float] = mapped_column(Float, nullable=False) - - __table_args__ = ( - Index("idx_graphedges_source", "source"), - Index("idx_graphedges_target", "target"), - ) - - -class Schedule(Base): - """日程模型""" - - __tablename__ = "schedule" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - date: Mapped[str] = mapped_column(get_string_field(10), nullable=False, unique=True, index=True) - schedule_data: Mapped[str] = mapped_column(Text, nullable=False) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - __table_args__ = (Index("idx_schedule_date", "date"),) - - -class MaiZoneScheduleStatus(Base): - """麦麦空间日程处理状态模型""" - - __tablename__ = "maizone_schedule_status" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - datetime_hour: Mapped[str] = mapped_column(get_string_field(13), nullable=False, unique=True, index=True) - activity: Mapped[str] = mapped_column(Text, nullable=False) - is_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - processed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True) - story_content: Mapped[str | None] = mapped_column(Text, nullable=True) - send_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - __table_args__ = ( - Index("idx_maizone_datetime_hour", "datetime_hour"), - Index("idx_maizone_is_processed", "is_processed"), - ) - - -class BanUser(Base): - """被禁用用户模型 - - 使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型, - 避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。 - """ - - __tablename__ = "ban_users" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - platform: Mapped[str] = mapped_column(Text, nullable=False) - user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) - reason: Mapped[str] = mapped_column(Text, nullable=False) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - - __table_args__ = ( - Index("idx_violation_num", "violation_num"), - Index("idx_banuser_user_id", "user_id"), - Index("idx_banuser_platform", "platform"), - Index("idx_banuser_platform_user_id", "platform", "user_id"), - ) - - -class AntiInjectionStats(Base): - """反注入系统统计模型""" - - __tablename__ = "anti_injection_stats" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """总处理消息数""" - - detected_injections: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """检测到的注入攻击数""" - - blocked_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """被阻止的消息数""" - - shielded_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """被加盾的消息数""" - - processing_time_total: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) - """总处理时间""" - - total_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) - """累计总处理时间""" - - last_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) - """最近一次处理时间""" - - error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """错误计数""" - - start_time: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - """统计开始时间""" - - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - """记录创建时间""" - - updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - """记录更新时间""" - - __table_args__ = ( - Index("idx_anti_injection_stats_created_at", "created_at"), - Index("idx_anti_injection_stats_updated_at", "updated_at"), - ) - - -class CacheEntries(Base): - """工具缓存条目模型""" - - __tablename__ = "cache_entries" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - cache_key: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True) - """缓存键,包含工具名、参数和代码哈希""" - - cache_value: Mapped[str] = mapped_column(Text, nullable=False) - """缓存的数据,JSON格式""" - - expires_at: Mapped[float] = mapped_column(Float, nullable=False, index=True) - """过期时间戳""" - - tool_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - """工具名称""" - - created_at: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time()) - """创建时间戳""" - - last_accessed: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time()) - """最后访问时间戳""" - - access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - """访问次数""" - - __table_args__ = ( - Index("idx_cache_entries_key", "cache_key"), - Index("idx_cache_entries_expires_at", "expires_at"), - Index("idx_cache_entries_tool_name", "tool_name"), - Index("idx_cache_entries_created_at", "created_at"), - ) - - -class MonthlyPlan(Base): - """月度计划模型""" - - __tablename__ = "monthly_plans" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - plan_text: Mapped[str] = mapped_column(Text, nullable=False) - target_month: Mapped[str] = mapped_column(String(7), nullable=False, index=True) - status: Mapped[str] = mapped_column(get_string_field(20), nullable=False, default="active", index=True) - usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) - last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True) - - __table_args__ = ( - Index("idx_monthlyplan_target_month_status", "target_month", "status"), - Index("idx_monthlyplan_last_used_date", "last_used_date"), - Index("idx_monthlyplan_usage_count", "usage_count"), - ) - - -def get_database_url(): - """获取数据库连接URL""" - from src.config.config import global_config - - config = global_config.database - - if config.database_type == "mysql": - # 对用户名和密码进行URL编码,处理特殊字符 - from urllib.parse import quote_plus - - encoded_user = quote_plus(config.mysql_user) - encoded_password = quote_plus(config.mysql_password) - - # 检查是否配置了Unix socket连接 - if config.mysql_unix_socket: - # 使用Unix socket连接 - encoded_socket = quote_plus(config.mysql_unix_socket) - return ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@/{config.mysql_database}" - f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" - ) - else: - # 使用标准TCP连接 - return ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - f"?charset={config.mysql_charset}" - ) - else: # SQLite - # 如果是相对路径,则相对于项目根目录 - if not os.path.isabs(config.sqlite_path): - ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - db_path = os.path.join(ROOT_PATH, config.sqlite_path) - else: - db_path = config.sqlite_path - - # 确保数据库目录存在 - os.makedirs(os.path.dirname(db_path), exist_ok=True) - - return f"sqlite+aiosqlite:///{db_path}" - - -_initializing: bool = False # 防止递归初始化 - -async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]: - """初始化异步数据库引擎和会话 - - Returns: - tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。 - - 说明: - 显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象, - 避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。 - """ - global _engine, _SessionLocal, _initializing - - # 已经初始化直接返回 - if _engine is not None and _SessionLocal is not None: - return _engine, _SessionLocal - - # 正在初始化的并发调用等待主初始化完成,避免递归 - if _initializing: - import asyncio - for _ in range(1000): # 最多等待约10秒 - await asyncio.sleep(0.01) - if _engine is not None and _SessionLocal is not None: - return _engine, _SessionLocal - raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)") - - _initializing = True - try: - database_url = get_database_url() - from src.config.config import global_config - - config = global_config.database - - # 配置引擎参数 - engine_kwargs: dict[str, Any] = { - "echo": False, # 生产环境关闭SQL日志 - "future": True, - } - - if config.database_type == "mysql": - engine_kwargs.update( - { - "pool_size": config.connection_pool_size, - "max_overflow": config.connection_pool_size * 2, - "pool_timeout": config.connection_timeout, - "pool_recycle": 3600, - "pool_pre_ping": True, - "connect_args": { - "autocommit": config.mysql_autocommit, - "charset": config.mysql_charset, - "connect_timeout": config.connection_timeout, - }, - } - ) - else: - engine_kwargs.update( - { - "connect_args": { - "check_same_thread": False, - "timeout": 60, - }, - } - ) - - _engine = create_async_engine(database_url, **engine_kwargs) - _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) - - # 迁移 - from src.common.database.db_migration import check_and_migrate_database - await check_and_migrate_database(existing_engine=_engine) - - if config.database_type == "sqlite": - await enable_sqlite_wal_mode(_engine) - - logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") - return _engine, _SessionLocal - finally: - _initializing = False - - -@asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession]: - """ - 异步数据库会话上下文管理器。 - 在初始化失败时会yield None,调用方需要检查会话是否为None。 - - 现在使用透明的连接池管理器来复用现有连接,提高并发性能。 - """ - SessionLocal = None - try: - _, SessionLocal = await initialize_database() - if not SessionLocal: - raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。") - except Exception as e: - logger.error(f"数据库初始化失败,无法创建会话: {e}") - raise - - # 使用连接池管理器获取会话 - pool_manager = get_connection_pool_manager() - - async with pool_manager.get_session(SessionLocal) as session: - # 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接) - from src.config.config import global_config - - if global_config.database.database_type == "sqlite": - try: - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) - except Exception as e: - logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}") - - yield session - - -async def get_engine(): - """获取异步数据库引擎""" - engine, _ = await initialize_database() - return engine - - -class PermissionNodes(Base): - """权限节点模型""" - - __tablename__ = "permission_nodes" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - node_name: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True) - description: Mapped[str] = mapped_column(Text, nullable=False) - plugin_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - default_granted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) - - __table_args__ = ( - Index("idx_permission_plugin", "plugin_name"), - Index("idx_permission_node", "node_name"), - ) - - -class UserPermissions(Base): - """用户权限模型""" - - __tablename__ = "user_permissions" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) - permission_node: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) - granted: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) - granted_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) - granted_by: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) - - __table_args__ = ( - Index("idx_user_platform_id", "platform", "user_id"), - Index("idx_user_permission", "platform", "user_id", "permission_node"), - Index("idx_permission_granted", "permission_node", "granted"), - ) - - -class UserRelationships(Base): - """用户关系模型 - 存储用户与bot的关系数据""" - - __tablename__ = "user_relationships" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True) - user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) - user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔 - relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True) - preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔 - relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1) - last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) - - __table_args__ = ( - Index("idx_user_relationship_id", "user_id"), - Index("idx_relationship_score", "relationship_score"), - Index("idx_relationship_updated", "last_updated"), - ) diff --git a/src/common/database/sqlalchemy_models.py.bak b/src/common/database/sqlalchemy_models.py.bak deleted file mode 100644 index 061ac6fad..000000000 --- a/src/common/database/sqlalchemy_models.py.bak +++ /dev/null @@ -1,872 +0,0 @@ -"""SQLAlchemy数据库模型定义 - -替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 - -说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到 -SQLAlchemy 2.0 推荐的带类型注解的声明式风格: - - field_name: Mapped[PyType] = mapped_column(Type, ...) - -这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。 -当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。 -""" - -import datetime -import os -import time -from collections.abc import AsyncGenerator -from contextlib import asynccontextmanager -from typing import Any - -from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text -from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Mapped, mapped_column - -from src.common.database.connection_pool_manager import get_connection_pool_manager -from src.common.logger import get_logger - -logger = get_logger("sqlalchemy_models") - -# 创建基类 -Base = declarative_base() - - -async def enable_sqlite_wal_mode(engine): - """为 SQLite 启用 WAL 模式以提高并发性能""" - try: - async with engine.begin() as conn: - # 启用 WAL 模式 - await conn.execute(text("PRAGMA journal_mode = WAL")) - # 设置适中的同步级别,平衡性能和安全性 - await conn.execute(text("PRAGMA synchronous = NORMAL")) - # 启用外键约束 - await conn.execute(text("PRAGMA foreign_keys = ON")) - # 设置 busy_timeout,避免锁定错误 - await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒 - - logger.info("[SQLite] WAL 模式已启用,并发性能已优化") - except Exception as e: - logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置") - - -async def maintain_sqlite_database(): - """定期维护 SQLite 数据库性能""" - try: - engine, SessionLocal = await initialize_database() - if not engine: - return - - async with engine.begin() as conn: - # 检查并确保 WAL 模式仍然启用 - result = await conn.execute(text("PRAGMA journal_mode")) - journal_mode = result.scalar() - - if journal_mode != "wal": - await conn.execute(text("PRAGMA journal_mode = WAL")) - logger.info("[SQLite] WAL 模式已重新启用") - - # 优化数据库性能 - await conn.execute(text("PRAGMA synchronous = NORMAL")) - await conn.execute(text("PRAGMA busy_timeout = 60000")) - await conn.execute(text("PRAGMA foreign_keys = ON")) - - # 定期清理(可选,根据需要启用) - # await conn.execute(text("PRAGMA optimize")) - - logger.info("[SQLite] 数据库维护完成") - except Exception as e: - logger.warning(f"[SQLite] 数据库维护失败: {e}") - - -def get_sqlite_performance_config(): - """获取 SQLite 性能优化配置""" - return { - "journal_mode": "WAL", # 提高并发性能 - "synchronous": "NORMAL", # 平衡性能和安全性 - "busy_timeout": 60000, # 60秒超时 - "foreign_keys": "ON", # 启用外键约束 - "cache_size": -10000, # 10MB 缓存 - "temp_store": "MEMORY", # 临时存储使用内存 - "mmap_size": 268435456, # 256MB 内存映射 - } - - -# MySQL兼容的字段类型辅助函数 -def get_string_field(max_length=255, **kwargs): - """ - 根据数据库类型返回合适的字符串字段 - MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text - """ - from src.config.config import global_config - - if global_config.database.database_type == "mysql": - return String(max_length, **kwargs) - else: - return Text(**kwargs) - - -class ChatStreams(Base): - """聊天流模型""" - - __tablename__ = "chat_streams" - - id = Column(Integer, primary_key=True, autoincrement=True) - stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True) - create_time = Column(Float, nullable=False) - group_platform = Column(Text, nullable=True) - group_id = Column(get_string_field(100), nullable=True, index=True) - group_name = Column(Text, nullable=True) - last_active_time = Column(Float, nullable=False) - platform = Column(Text, nullable=False) - user_platform = Column(Text, nullable=False) - user_id = Column(get_string_field(100), nullable=False, index=True) - user_nickname = Column(Text, nullable=False) - user_cardname = Column(Text, nullable=True) - energy_value = Column(Float, nullable=True, default=5.0) - sleep_pressure = Column(Float, nullable=True, default=0.0) - focus_energy = Column(Float, nullable=True, default=0.5) - # 动态兴趣度系统字段 - base_interest_energy = Column(Float, nullable=True, default=0.5) - message_interest_total = Column(Float, nullable=True, default=0.0) - message_count = Column(Integer, nullable=True, default=0) - action_count = Column(Integer, nullable=True, default=0) - reply_count = Column(Integer, nullable=True, default=0) - last_interaction_time = Column(Float, nullable=True, default=None) - consecutive_no_reply = Column(Integer, nullable=True, default=0) - # 消息打断系统字段 - interruption_count = Column(Integer, nullable=True, default=0) - - __table_args__ = ( - Index("idx_chatstreams_stream_id", "stream_id"), - Index("idx_chatstreams_user_id", "user_id"), - Index("idx_chatstreams_group_id", "group_id"), - ) - - -class LLMUsage(Base): - """LLM使用记录模型""" - - __tablename__ = "llm_usage" - - id = Column(Integer, primary_key=True, autoincrement=True) - model_name = Column(get_string_field(100), nullable=False, index=True) - model_assign_name = Column(get_string_field(100), index=True) # 添加索引 - model_api_provider = Column(get_string_field(100), index=True) # 添加索引 - user_id = Column(get_string_field(50), nullable=False, index=True) - request_type = Column(get_string_field(50), nullable=False, index=True) - endpoint = Column(Text, nullable=False) - prompt_tokens = Column(Integer, nullable=False) - completion_tokens = Column(Integer, nullable=False) - time_cost = Column(Float, nullable=True) - total_tokens = Column(Integer, nullable=False) - cost = Column(Float, nullable=False) - status = Column(Text, nullable=False) - timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now) - - __table_args__ = ( - Index("idx_llmusage_model_name", "model_name"), - Index("idx_llmusage_model_assign_name", "model_assign_name"), - Index("idx_llmusage_model_api_provider", "model_api_provider"), - Index("idx_llmusage_time_cost", "time_cost"), - Index("idx_llmusage_user_id", "user_id"), - Index("idx_llmusage_request_type", "request_type"), - Index("idx_llmusage_timestamp", "timestamp"), - ) - - -class Emoji(Base): - """表情包模型""" - - __tablename__ = "emoji" - - id = Column(Integer, primary_key=True, autoincrement=True) - full_path = Column(get_string_field(500), nullable=False, unique=True, index=True) - format = Column(Text, nullable=False) - emoji_hash = Column(get_string_field(64), nullable=False, index=True) - description = Column(Text, nullable=False) - query_count = Column(Integer, nullable=False, default=0) - is_registered = Column(Boolean, nullable=False, default=False) - is_banned = Column(Boolean, nullable=False, default=False) - emotion = Column(Text, nullable=True) - record_time = Column(Float, nullable=False) - register_time = Column(Float, nullable=True) - usage_count = Column(Integer, nullable=False, default=0) - last_used_time = Column(Float, nullable=True) - - __table_args__ = ( - Index("idx_emoji_full_path", "full_path"), - Index("idx_emoji_hash", "emoji_hash"), - ) - - -class Messages(Base): - """消息模型""" - - __tablename__ = "messages" - - id = Column(Integer, primary_key=True, autoincrement=True) - message_id = Column(get_string_field(100), nullable=False, index=True) - time = Column(Float, nullable=False) - chat_id = Column(get_string_field(64), nullable=False, index=True) - reply_to = Column(Text, nullable=True) - interest_value = Column(Float, nullable=True) - key_words = Column(Text, nullable=True) - key_words_lite = Column(Text, nullable=True) - is_mentioned = Column(Boolean, nullable=True) - - # 从 chat_info 扁平化而来的字段 - chat_info_stream_id = Column(Text, nullable=False) - chat_info_platform = Column(Text, nullable=False) - chat_info_user_platform = Column(Text, nullable=False) - chat_info_user_id = Column(Text, nullable=False) - chat_info_user_nickname = Column(Text, nullable=False) - chat_info_user_cardname = Column(Text, nullable=True) - chat_info_group_platform = Column(Text, nullable=True) - chat_info_group_id = Column(Text, nullable=True) - chat_info_group_name = Column(Text, nullable=True) - chat_info_create_time = Column(Float, nullable=False) - chat_info_last_active_time = Column(Float, nullable=False) - - # 从顶层 user_info 扁平化而来的字段 - user_platform = Column(Text, nullable=True) - user_id = Column(get_string_field(100), nullable=True, index=True) - user_nickname = Column(Text, nullable=True) - user_cardname = Column(Text, nullable=True) - - processed_plain_text = Column(Text, nullable=True) - display_message = Column(Text, nullable=True) - memorized_times = Column(Integer, nullable=False, default=0) - priority_mode = Column(Text, nullable=True) - priority_info = Column(Text, nullable=True) - additional_config = Column(Text, nullable=True) - is_emoji = Column(Boolean, nullable=False, default=False) - is_picid = Column(Boolean, nullable=False, default=False) - is_command = Column(Boolean, nullable=False, default=False) - is_notify = Column(Boolean, nullable=False, default=False) - - # 兴趣度系统字段 - actions = Column(Text, nullable=True) # JSON格式存储动作列表 - should_reply = Column(Boolean, nullable=True, default=False) - should_act = Column(Boolean, nullable=True, default=False) - - __table_args__ = ( - Index("idx_messages_message_id", "message_id"), - Index("idx_messages_chat_id", "chat_id"), - Index("idx_messages_time", "time"), - Index("idx_messages_user_id", "user_id"), - Index("idx_messages_should_reply", "should_reply"), - Index("idx_messages_should_act", "should_act"), - ) - - -class ActionRecords(Base): - """动作记录模型""" - - __tablename__ = "action_records" - - id = Column(Integer, primary_key=True, autoincrement=True) - action_id = Column(get_string_field(100), nullable=False, index=True) - time = Column(Float, nullable=False) - action_name = Column(Text, nullable=False) - action_data = Column(Text, nullable=False) - action_done = Column(Boolean, nullable=False, default=False) - action_build_into_prompt = Column(Boolean, nullable=False, default=False) - action_prompt_display = Column(Text, nullable=False) - chat_id = Column(get_string_field(64), nullable=False, index=True) - chat_info_stream_id = Column(Text, nullable=False) - chat_info_platform = Column(Text, nullable=False) - - __table_args__ = ( - Index("idx_actionrecords_action_id", "action_id"), - Index("idx_actionrecords_chat_id", "chat_id"), - Index("idx_actionrecords_time", "time"), - ) - - -class Images(Base): - """图像信息模型""" - - __tablename__ = "images" - - id = Column(Integer, primary_key=True, autoincrement=True) - image_id = Column(Text, nullable=False, default="") - emoji_hash = Column(get_string_field(64), nullable=False, index=True) - description = Column(Text, nullable=True) - path = Column(get_string_field(500), nullable=False, unique=True) - count = Column(Integer, nullable=False, default=1) - timestamp = Column(Float, nullable=False) - type = Column(Text, nullable=False) - vlm_processed = Column(Boolean, nullable=False, default=False) - - __table_args__ = ( - Index("idx_images_emoji_hash", "emoji_hash"), - Index("idx_images_path", "path"), - ) - - -class ImageDescriptions(Base): - """图像描述信息模型""" - - __tablename__ = "image_descriptions" - - id = Column(Integer, primary_key=True, autoincrement=True) - type = Column(Text, nullable=False) - image_description_hash = Column(get_string_field(64), nullable=False, index=True) - description = Column(Text, nullable=False) - timestamp = Column(Float, nullable=False) - - __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),) - - -class Videos(Base): - """视频信息模型""" - - __tablename__ = "videos" - - id = Column(Integer, primary_key=True, autoincrement=True) - video_id = Column(Text, nullable=False, default="") - video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True) - description = Column(Text, nullable=True) - count = Column(Integer, nullable=False, default=1) - timestamp = Column(Float, nullable=False) - vlm_processed = Column(Boolean, nullable=False, default=False) - - # 视频特有属性 - duration = Column(Float, nullable=True) # 视频时长(秒) - frame_count = Column(Integer, nullable=True) # 总帧数 - fps = Column(Float, nullable=True) # 帧率 - resolution = Column(Text, nullable=True) # 分辨率 - file_size = Column(Integer, nullable=True) # 文件大小(字节) - - __table_args__ = ( - Index("idx_videos_video_hash", "video_hash"), - Index("idx_videos_timestamp", "timestamp"), - ) - - -class OnlineTime(Base): - """在线时长记录模型""" - - __tablename__ = "online_time" - - id = Column(Integer, primary_key=True, autoincrement=True) - timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now)) - duration = Column(Integer, nullable=False) - start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now) - end_timestamp = Column(DateTime, nullable=False, index=True) - - __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),) - - -class PersonInfo(Base): - """人物信息模型""" - - __tablename__ = "person_info" - - id = Column(Integer, primary_key=True, autoincrement=True) - person_id = Column(get_string_field(100), nullable=False, unique=True, index=True) - person_name = Column(Text, nullable=True) - name_reason = Column(Text, nullable=True) - platform = Column(Text, nullable=False) - user_id = Column(get_string_field(50), nullable=False, index=True) - nickname = Column(Text, nullable=True) - impression = Column(Text, nullable=True) - short_impression = Column(Text, nullable=True) - points = Column(Text, nullable=True) - forgotten_points = Column(Text, nullable=True) - info_list = Column(Text, nullable=True) - know_times = Column(Float, nullable=True) - know_since = Column(Float, nullable=True) - last_know = Column(Float, nullable=True) - attitude = Column(Integer, nullable=True, default=50) - - __table_args__ = ( - Index("idx_personinfo_person_id", "person_id"), - Index("idx_personinfo_user_id", "user_id"), - ) - - -class BotPersonalityInterests(Base): - """机器人人格兴趣标签模型""" - - __tablename__ = "bot_personality_interests" - - id = Column(Integer, primary_key=True, autoincrement=True) - personality_id = Column(get_string_field(100), nullable=False, index=True) - personality_description = Column(Text, nullable=False) - interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表 - embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002") - version = Column(Integer, nullable=False, default=1) - last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True) - - __table_args__ = ( - Index("idx_botpersonality_personality_id", "personality_id"), - Index("idx_botpersonality_version", "version"), - Index("idx_botpersonality_last_updated", "last_updated"), - ) - - -class Memory(Base): - """记忆模型""" - - __tablename__ = "memory" - - id = Column(Integer, primary_key=True, autoincrement=True) - memory_id = Column(get_string_field(64), nullable=False, index=True) - chat_id = Column(Text, nullable=True) - memory_text = Column(Text, nullable=True) - keywords = Column(Text, nullable=True) - create_time = Column(Float, nullable=True) - last_view_time = Column(Float, nullable=True) - - __table_args__ = (Index("idx_memory_memory_id", "memory_id"),) - - -class Expression(Base): - """表达风格模型""" - - __tablename__ = "expression" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - situation: Mapped[str] = mapped_column(Text, nullable=False) - style: Mapped[str] = mapped_column(Text, nullable=False) - count: Mapped[float] = mapped_column(Float, nullable=False) - last_active_time: Mapped[float] = mapped_column(Float, nullable=False) - chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) - type: Mapped[str] = mapped_column(Text, nullable=False) - create_date: Mapped[float | None] = mapped_column(Float, nullable=True) - - __table_args__ = (Index("idx_expression_chat_id", "chat_id"),) - - -class ThinkingLog(Base): - """思考日志模型""" - - __tablename__ = "thinking_logs" - - id = Column(Integer, primary_key=True, autoincrement=True) - chat_id = Column(get_string_field(64), nullable=False, index=True) - trigger_text = Column(Text, nullable=True) - response_text = Column(Text, nullable=True) - trigger_info_json = Column(Text, nullable=True) - response_info_json = Column(Text, nullable=True) - timing_results_json = Column(Text, nullable=True) - chat_history_json = Column(Text, nullable=True) - chat_history_in_thinking_json = Column(Text, nullable=True) - chat_history_after_response_json = Column(Text, nullable=True) - heartflow_data_json = Column(Text, nullable=True) - reasoning_data_json = Column(Text, nullable=True) - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - - __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),) - - -class GraphNodes(Base): - """记忆图节点模型""" - - __tablename__ = "graph_nodes" - - id = Column(Integer, primary_key=True, autoincrement=True) - concept = Column(get_string_field(255), nullable=False, unique=True, index=True) - memory_items = Column(Text, nullable=False) - hash = Column(Text, nullable=False) - weight = Column(Float, nullable=False, default=1.0) - created_time = Column(Float, nullable=False) - last_modified = Column(Float, nullable=False) - - __table_args__ = (Index("idx_graphnodes_concept", "concept"),) - - -class GraphEdges(Base): - """记忆图边模型""" - - __tablename__ = "graph_edges" - - id = Column(Integer, primary_key=True, autoincrement=True) - source = Column(get_string_field(255), nullable=False, index=True) - target = Column(get_string_field(255), nullable=False, index=True) - strength = Column(Integer, nullable=False) - hash = Column(Text, nullable=False) - created_time = Column(Float, nullable=False) - last_modified = Column(Float, nullable=False) - - __table_args__ = ( - Index("idx_graphedges_source", "source"), - Index("idx_graphedges_target", "target"), - ) - - -class Schedule(Base): - """日程模型""" - - __tablename__ = "schedule" - - id = Column(Integer, primary_key=True, autoincrement=True) - date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式 - schedule_data = Column(Text, nullable=False) # JSON格式的日程数据 - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - __table_args__ = (Index("idx_schedule_date", "date"),) - - -class MaiZoneScheduleStatus(Base): - """麦麦空间日程处理状态模型""" - - __tablename__ = "maizone_schedule_status" - - id = Column(Integer, primary_key=True, autoincrement=True) - datetime_hour = Column( - get_string_field(13), nullable=False, unique=True, index=True - ) # YYYY-MM-DD HH格式,精确到小时 - activity = Column(Text, nullable=False) # 该小时的活动内容 - is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理 - processed_at = Column(DateTime, nullable=True) # 处理时间 - story_content = Column(Text, nullable=True) # 生成的说说内容 - send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功 - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - - __table_args__ = ( - Index("idx_maizone_datetime_hour", "datetime_hour"), - Index("idx_maizone_is_processed", "is_processed"), - ) - - -class BanUser(Base): - """被禁用用户模型 - - 使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型, - 避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。 - """ - - __tablename__ = "ban_users" - - id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - platform: Mapped[str] = mapped_column(Text, nullable=False) - user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) - violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) - reason: Mapped[str] = mapped_column(Text, nullable=False) - created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) - - __table_args__ = ( - Index("idx_violation_num", "violation_num"), - Index("idx_banuser_user_id", "user_id"), - Index("idx_banuser_platform", "platform"), - Index("idx_banuser_platform_user_id", "platform", "user_id"), - ) - - -class AntiInjectionStats(Base): - """反注入系统统计模型""" - - __tablename__ = "anti_injection_stats" - - id = Column(Integer, primary_key=True, autoincrement=True) - total_messages = Column(Integer, nullable=False, default=0) - """总处理消息数""" - - detected_injections = Column(Integer, nullable=False, default=0) - """检测到的注入攻击数""" - - blocked_messages = Column(Integer, nullable=False, default=0) - """被阻止的消息数""" - - shielded_messages = Column(Integer, nullable=False, default=0) - """被加盾的消息数""" - - processing_time_total = Column(Float, nullable=False, default=0.0) - """总处理时间""" - - total_process_time = Column(Float, nullable=False, default=0.0) - """累计总处理时间""" - - last_process_time = Column(Float, nullable=False, default=0.0) - """最近一次处理时间""" - - error_count = Column(Integer, nullable=False, default=0) - """错误计数""" - - start_time = Column(DateTime, nullable=False, default=datetime.datetime.now) - """统计开始时间""" - - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - """记录创建时间""" - - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) - """记录更新时间""" - - __table_args__ = ( - Index("idx_anti_injection_stats_created_at", "created_at"), - Index("idx_anti_injection_stats_updated_at", "updated_at"), - ) - - -class CacheEntries(Base): - """工具缓存条目模型""" - - __tablename__ = "cache_entries" - - id = Column(Integer, primary_key=True, autoincrement=True) - cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True) - """缓存键,包含工具名、参数和代码哈希""" - - cache_value = Column(Text, nullable=False) - """缓存的数据,JSON格式""" - - expires_at = Column(Float, nullable=False, index=True) - """过期时间戳""" - - tool_name = Column(get_string_field(100), nullable=False, index=True) - """工具名称""" - - created_at = Column(Float, nullable=False, default=lambda: time.time()) - """创建时间戳""" - - last_accessed = Column(Float, nullable=False, default=lambda: time.time()) - """最后访问时间戳""" - - access_count = Column(Integer, nullable=False, default=0) - """访问次数""" - - __table_args__ = ( - Index("idx_cache_entries_key", "cache_key"), - Index("idx_cache_entries_expires_at", "expires_at"), - Index("idx_cache_entries_tool_name", "tool_name"), - Index("idx_cache_entries_created_at", "created_at"), - ) - - -class MonthlyPlan(Base): - """月度计划模型""" - - __tablename__ = "monthly_plans" - - id = Column(Integer, primary_key=True, autoincrement=True) - plan_text = Column(Text, nullable=False) - target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM" - status = Column( - get_string_field(20), nullable=False, default="active", index=True - ) # 'active', 'completed', 'archived' - usage_count = Column(Integer, nullable=False, default=0) - last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - - # 保留 is_deleted 字段以兼容现有数据,但标记为已弃用 - is_deleted = Column(Boolean, nullable=False, default=False) - - __table_args__ = ( - Index("idx_monthlyplan_target_month_status", "target_month", "status"), - Index("idx_monthlyplan_last_used_date", "last_used_date"), - Index("idx_monthlyplan_usage_count", "usage_count"), - # 保留旧索引以兼容 - Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"), - ) - - -# 数据库引擎和会话管理 -_engine = None -_SessionLocal = None - - -def get_database_url(): - """获取数据库连接URL""" - from src.config.config import global_config - - config = global_config.database - - if config.database_type == "mysql": - # 对用户名和密码进行URL编码,处理特殊字符 - from urllib.parse import quote_plus - - encoded_user = quote_plus(config.mysql_user) - encoded_password = quote_plus(config.mysql_password) - - # 检查是否配置了Unix socket连接 - if config.mysql_unix_socket: - # 使用Unix socket连接 - encoded_socket = quote_plus(config.mysql_unix_socket) - return ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@/{config.mysql_database}" - f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" - ) - else: - # 使用标准TCP连接 - return ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - f"?charset={config.mysql_charset}" - ) - else: # SQLite - # 如果是相对路径,则相对于项目根目录 - if not os.path.isabs(config.sqlite_path): - ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..")) - db_path = os.path.join(ROOT_PATH, config.sqlite_path) - else: - db_path = config.sqlite_path - - # 确保数据库目录存在 - os.makedirs(os.path.dirname(db_path), exist_ok=True) - - return f"sqlite+aiosqlite:///{db_path}" - - -async def initialize_database(): - """初始化异步数据库引擎和会话""" - global _engine, _SessionLocal - - if _engine is not None: - return _engine, _SessionLocal - - database_url = get_database_url() - from src.config.config import global_config - - config = global_config.database - - # 配置引擎参数 - engine_kwargs: dict[str, Any] = { - "echo": False, # 生产环境关闭SQL日志 - "future": True, - } - - if config.database_type == "mysql": - # MySQL连接池配置 - 异步引擎使用默认连接池 - engine_kwargs.update( - { - "pool_size": config.connection_pool_size, - "max_overflow": config.connection_pool_size * 2, - "pool_timeout": config.connection_timeout, - "pool_recycle": 3600, # 1小时回收连接 - "pool_pre_ping": True, # 连接前ping检查 - "connect_args": { - "autocommit": config.mysql_autocommit, - "charset": config.mysql_charset, - "connect_timeout": config.connection_timeout, - }, - } - ) - else: - # SQLite配置 - aiosqlite不支持连接池参数 - engine_kwargs.update( - { - "connect_args": { - "check_same_thread": False, - "timeout": 60, # 增加超时时间 - }, - } - ) - - _engine = create_async_engine(database_url, **engine_kwargs) - _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False) - - # 调用新的迁移函数,它会处理表的创建和列的添加 - from src.common.database.db_migration import check_and_migrate_database - - await check_and_migrate_database() - - # 如果是 SQLite,启用 WAL 模式以提高并发性能 - if config.database_type == "sqlite": - await enable_sqlite_wal_mode(_engine) - - logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}") - return _engine, _SessionLocal - - -@asynccontextmanager -async def get_db_session() -> AsyncGenerator[AsyncSession]: - """ - 异步数据库会话上下文管理器。 - 在初始化失败时会yield None,调用方需要检查会话是否为None。 - - 现在使用透明的连接池管理器来复用现有连接,提高并发性能。 - """ - SessionLocal = None - try: - _, SessionLocal = await initialize_database() - if not SessionLocal: - raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。") - except Exception as e: - logger.error(f"数据库初始化失败,无法创建会话: {e}") - raise - - # 使用连接池管理器获取会话 - pool_manager = get_connection_pool_manager() - - async with pool_manager.get_session(SessionLocal) as session: - # 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接) - from src.config.config import global_config - - if global_config.database.database_type == "sqlite": - try: - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) - except Exception as e: - logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}") - - yield session - - -async def get_engine(): - """获取异步数据库引擎""" - engine, _ = await initialize_database() - return engine - - -class PermissionNodes(Base): - """权限节点模型""" - - __tablename__ = "permission_nodes" - - id = Column(Integer, primary_key=True, autoincrement=True) - node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称 - description = Column(Text, nullable=False) # 权限描述 - plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件 - default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权 - created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 - - __table_args__ = ( - Index("idx_permission_plugin", "plugin_name"), - Index("idx_permission_node", "node_name"), - ) - - -class UserPermissions(Base): - """用户权限模型""" - - __tablename__ = "user_permissions" - - id = Column(Integer, primary_key=True, autoincrement=True) - platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型 - user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID - permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称 - granted = Column(Boolean, default=True, nullable=False) # 是否授权 - granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间 - granted_by = Column(get_string_field(100), nullable=True) # 授权者信息 - - __table_args__ = ( - Index("idx_user_platform_id", "platform", "user_id"), - Index("idx_user_permission", "platform", "user_id", "permission_node"), - Index("idx_permission_granted", "permission_node", "granted"), - ) - - -class UserRelationships(Base): - """用户关系模型 - 存储用户与bot的关系数据""" - - __tablename__ = "user_relationships" - - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID - user_name = Column(get_string_field(100), nullable=True) # 用户名 - relationship_text = Column(Text, nullable=True) # 关系印象描述 - relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1) - last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间 - created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 - - __table_args__ = ( - Index("idx_user_relationship_id", "user_id"), - Index("idx_relationship_score", "relationship_score"), - Index("idx_relationship_updated", "last_updated"), - )