chore: 清理旧数据库实现文件
- 删除old/目录下的旧实现文件 - 删除sqlalchemy_models.py.bak备份文件 - 完成数据库重构代码清理工作
This commit is contained in:
@@ -1,109 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
from rich.traceback import install
|
|
||||||
|
|
||||||
from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
|
|
||||||
|
|
||||||
# 数据库批量调度器和连接池
|
|
||||||
from src.common.database.db_batch_scheduler import get_db_batch_scheduler
|
|
||||||
|
|
||||||
# SQLAlchemy相关导入
|
|
||||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
|
||||||
from src.common.database.sqlalchemy_models import get_engine
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
install(extra_lines=3)
|
|
||||||
|
|
||||||
_sql_engine = None
|
|
||||||
|
|
||||||
logger = get_logger("database")
|
|
||||||
|
|
||||||
|
|
||||||
# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy
|
|
||||||
class DatabaseProxy:
|
|
||||||
"""数据库代理类"""
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self._engine = None
|
|
||||||
self._session = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def initialize(*args, **kwargs):
|
|
||||||
"""初始化数据库连接"""
|
|
||||||
result = await initialize_database_compat()
|
|
||||||
|
|
||||||
# 启动数据库优化系统
|
|
||||||
try:
|
|
||||||
# 启动数据库批量调度器
|
|
||||||
batch_scheduler = get_db_batch_scheduler()
|
|
||||||
await batch_scheduler.start()
|
|
||||||
logger.info("🚀 数据库批量调度器启动成功")
|
|
||||||
|
|
||||||
# 启动连接池管理器
|
|
||||||
await start_connection_pool()
|
|
||||||
logger.info("🚀 连接池管理器启动成功")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"启动数据库优化系统失败: {e}")
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# 创建全局数据库代理实例
|
|
||||||
db = DatabaseProxy()
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_sql_database(database_config):
|
|
||||||
"""
|
|
||||||
根据配置初始化SQL数据库连接(SQLAlchemy版本)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
database_config: DatabaseConfig对象
|
|
||||||
"""
|
|
||||||
global _sql_engine
|
|
||||||
|
|
||||||
try:
|
|
||||||
logger.info("使用SQLAlchemy初始化SQL数据库...")
|
|
||||||
|
|
||||||
# 记录数据库配置信息
|
|
||||||
if database_config.database_type == "mysql":
|
|
||||||
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
|
|
||||||
logger.info("MySQL数据库连接配置:")
|
|
||||||
logger.info(f" 连接信息: {connection_info}")
|
|
||||||
logger.info(f" 字符集: {database_config.mysql_charset}")
|
|
||||||
else:
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
if not os.path.isabs(database_config.sqlite_path):
|
|
||||||
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
|
|
||||||
else:
|
|
||||||
db_path = database_config.sqlite_path
|
|
||||||
logger.info("SQLite数据库连接配置:")
|
|
||||||
logger.info(f" 数据库文件: {db_path}")
|
|
||||||
|
|
||||||
# 使用SQLAlchemy初始化
|
|
||||||
success = await initialize_database_compat()
|
|
||||||
if success:
|
|
||||||
_sql_engine = await get_engine()
|
|
||||||
logger.info("SQLAlchemy数据库初始化成功")
|
|
||||||
else:
|
|
||||||
logger.error("SQLAlchemy数据库初始化失败")
|
|
||||||
|
|
||||||
return _sql_engine
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"初始化SQL数据库失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def stop_database():
|
|
||||||
"""停止数据库相关服务"""
|
|
||||||
try:
|
|
||||||
# 停止连接池管理器
|
|
||||||
await stop_connection_pool()
|
|
||||||
logger.info("🛑 连接池管理器已停止")
|
|
||||||
|
|
||||||
# 停止数据库批量调度器
|
|
||||||
batch_scheduler = get_db_batch_scheduler()
|
|
||||||
await batch_scheduler.stop()
|
|
||||||
logger.info("🛑 数据库批量调度器已停止")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"停止数据库优化系统时出错: {e}")
|
|
||||||
@@ -1,462 +0,0 @@
|
|||||||
"""
|
|
||||||
数据库批量调度器
|
|
||||||
实现多个数据库请求的智能合并和批量处理,减少数据库连接竞争
|
|
||||||
"""
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import time
|
|
||||||
from collections import defaultdict, deque
|
|
||||||
from collections.abc import Callable
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import Any, TypeVar
|
|
||||||
|
|
||||||
from sqlalchemy import delete, insert, select, update
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("db_batch_scheduler")
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BatchOperation:
|
|
||||||
"""批量操作基础类"""
|
|
||||||
|
|
||||||
operation_type: str # 'select', 'insert', 'update', 'delete'
|
|
||||||
model_class: Any
|
|
||||||
conditions: dict[str, Any]
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
callback: Callable | None = None
|
|
||||||
future: asyncio.Future | None = None
|
|
||||||
timestamp: float = 0.0
|
|
||||||
|
|
||||||
def __post_init__(self):
|
|
||||||
if self.timestamp == 0.0:
|
|
||||||
self.timestamp = time.time()
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BatchResult:
|
|
||||||
"""批量操作结果"""
|
|
||||||
|
|
||||||
success: bool
|
|
||||||
data: Any = None
|
|
||||||
error: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class DatabaseBatchScheduler:
|
|
||||||
"""数据库批量调度器"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
batch_size: int = 50,
|
|
||||||
max_wait_time: float = 0.1, # 100ms
|
|
||||||
max_queue_size: int = 1000,
|
|
||||||
):
|
|
||||||
self.batch_size = batch_size
|
|
||||||
self.max_wait_time = max_wait_time
|
|
||||||
self.max_queue_size = max_queue_size
|
|
||||||
|
|
||||||
# 操作队列,按操作类型和模型分类
|
|
||||||
self.operation_queues: dict[str, deque] = defaultdict(deque)
|
|
||||||
|
|
||||||
# 调度控制
|
|
||||||
self._scheduler_task: asyncio.Task | None = None
|
|
||||||
self._is_running = False
|
|
||||||
self._lock = asyncio.Lock()
|
|
||||||
|
|
||||||
# 统计信息
|
|
||||||
self.stats = {"total_operations": 0, "batched_operations": 0, "cache_hits": 0, "execution_time": 0.0}
|
|
||||||
|
|
||||||
# 简单的结果缓存(用于频繁的查询)
|
|
||||||
self._result_cache: dict[str, tuple[Any, float]] = {}
|
|
||||||
self._cache_ttl = 5.0 # 5秒缓存
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""启动调度器"""
|
|
||||||
if self._is_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._is_running = True
|
|
||||||
self._scheduler_task = asyncio.create_task(self._scheduler_loop())
|
|
||||||
logger.info("数据库批量调度器已启动")
|
|
||||||
|
|
||||||
async def stop(self):
|
|
||||||
"""停止调度器"""
|
|
||||||
if not self._is_running:
|
|
||||||
return
|
|
||||||
|
|
||||||
self._is_running = False
|
|
||||||
if self._scheduler_task:
|
|
||||||
self._scheduler_task.cancel()
|
|
||||||
try:
|
|
||||||
await self._scheduler_task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# 处理剩余的操作
|
|
||||||
await self._flush_all_queues()
|
|
||||||
logger.info("数据库批量调度器已停止")
|
|
||||||
|
|
||||||
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
|
|
||||||
"""生成缓存键"""
|
|
||||||
# 简单的缓存键生成,实际可以根据需要优化
|
|
||||||
key_parts = [operation_type, model_class.__name__, str(sorted(conditions.items()))]
|
|
||||||
return "|".join(key_parts)
|
|
||||||
|
|
||||||
def _get_from_cache(self, cache_key: str) -> Any | None:
|
|
||||||
"""从缓存获取结果"""
|
|
||||||
if cache_key in self._result_cache:
|
|
||||||
result, timestamp = self._result_cache[cache_key]
|
|
||||||
if time.time() - timestamp < self._cache_ttl:
|
|
||||||
self.stats["cache_hits"] += 1
|
|
||||||
return result
|
|
||||||
else:
|
|
||||||
# 清理过期缓存
|
|
||||||
del self._result_cache[cache_key]
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _set_cache(self, cache_key: str, result: Any):
|
|
||||||
"""设置缓存"""
|
|
||||||
self._result_cache[cache_key] = (result, time.time())
|
|
||||||
|
|
||||||
async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
|
|
||||||
"""添加操作到队列"""
|
|
||||||
# 检查是否可以立即返回缓存结果
|
|
||||||
if operation.operation_type == "select":
|
|
||||||
cache_key = self._generate_cache_key(operation.operation_type, operation.model_class, operation.conditions)
|
|
||||||
cached_result = self._get_from_cache(cache_key)
|
|
||||||
if cached_result is not None:
|
|
||||||
if operation.callback:
|
|
||||||
operation.callback(cached_result)
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
|
||||||
future.set_result(cached_result)
|
|
||||||
return future
|
|
||||||
|
|
||||||
# 创建future用于返回结果
|
|
||||||
future = asyncio.get_event_loop().create_future()
|
|
||||||
operation.future = future
|
|
||||||
|
|
||||||
# 添加到队列
|
|
||||||
queue_key = f"{operation.operation_type}_{operation.model_class.__name__}"
|
|
||||||
|
|
||||||
async with self._lock:
|
|
||||||
if len(self.operation_queues[queue_key]) >= self.max_queue_size:
|
|
||||||
# 队列满了,直接执行
|
|
||||||
await self._execute_operations([operation])
|
|
||||||
else:
|
|
||||||
self.operation_queues[queue_key].append(operation)
|
|
||||||
self.stats["total_operations"] += 1
|
|
||||||
|
|
||||||
return future
|
|
||||||
|
|
||||||
async def _scheduler_loop(self):
|
|
||||||
"""调度器主循环"""
|
|
||||||
while self._is_running:
|
|
||||||
try:
|
|
||||||
await asyncio.sleep(self.max_wait_time)
|
|
||||||
await self._flush_all_queues()
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"调度器循环异常: {e}", exc_info=True)
|
|
||||||
|
|
||||||
async def _flush_all_queues(self):
|
|
||||||
"""刷新所有队列"""
|
|
||||||
async with self._lock:
|
|
||||||
if not any(self.operation_queues.values()):
|
|
||||||
return
|
|
||||||
|
|
||||||
# 复制队列内容,避免长时间占用锁
|
|
||||||
queues_copy = {key: deque(operations) for key, operations in self.operation_queues.items()}
|
|
||||||
# 清空原队列
|
|
||||||
for queue in self.operation_queues.values():
|
|
||||||
queue.clear()
|
|
||||||
|
|
||||||
# 批量执行各队列的操作
|
|
||||||
for operations in queues_copy.values():
|
|
||||||
if operations:
|
|
||||||
await self._execute_operations(list(operations))
|
|
||||||
|
|
||||||
async def _execute_operations(self, operations: list[BatchOperation]):
|
|
||||||
"""执行批量操作"""
|
|
||||||
if not operations:
|
|
||||||
return
|
|
||||||
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 按操作类型分组
|
|
||||||
op_groups = defaultdict(list)
|
|
||||||
for op in operations:
|
|
||||||
op_groups[op.operation_type].append(op)
|
|
||||||
|
|
||||||
# 为每种操作类型创建批量执行任务
|
|
||||||
tasks = []
|
|
||||||
for op_type, ops in op_groups.items():
|
|
||||||
if op_type == "select":
|
|
||||||
tasks.append(self._execute_select_batch(ops))
|
|
||||||
elif op_type == "insert":
|
|
||||||
tasks.append(self._execute_insert_batch(ops))
|
|
||||||
elif op_type == "update":
|
|
||||||
tasks.append(self._execute_update_batch(ops))
|
|
||||||
elif op_type == "delete":
|
|
||||||
tasks.append(self._execute_delete_batch(ops))
|
|
||||||
|
|
||||||
# 并发执行所有操作
|
|
||||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
|
||||||
|
|
||||||
# 处理结果
|
|
||||||
for i, result in enumerate(results):
|
|
||||||
operation = operations[i]
|
|
||||||
if isinstance(result, Exception):
|
|
||||||
if operation.future and not operation.future.done():
|
|
||||||
operation.future.set_exception(result)
|
|
||||||
else:
|
|
||||||
if operation.callback:
|
|
||||||
try:
|
|
||||||
operation.callback(result)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"操作回调执行失败: {e}")
|
|
||||||
|
|
||||||
if operation.future and not operation.future.done():
|
|
||||||
operation.future.set_result(result)
|
|
||||||
|
|
||||||
# 缓存查询结果
|
|
||||||
if operation.operation_type == "select":
|
|
||||||
cache_key = self._generate_cache_key(
|
|
||||||
operation.operation_type, operation.model_class, operation.conditions
|
|
||||||
)
|
|
||||||
self._set_cache(cache_key, result)
|
|
||||||
|
|
||||||
self.stats["batched_operations"] += len(operations)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量操作执行失败: {e}", exc_info="")
|
|
||||||
# 设置所有future的异常状态
|
|
||||||
for operation in operations:
|
|
||||||
if operation.future and not operation.future.done():
|
|
||||||
operation.future.set_exception(e)
|
|
||||||
finally:
|
|
||||||
self.stats["execution_time"] += time.time() - start_time
|
|
||||||
|
|
||||||
async def _execute_select_batch(self, operations: list[BatchOperation]):
|
|
||||||
"""批量执行查询操作"""
|
|
||||||
# 合并相似的查询条件
|
|
||||||
merged_conditions = self._merge_select_conditions(operations)
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
|
||||||
results = []
|
|
||||||
for conditions, ops in merged_conditions.items():
|
|
||||||
try:
|
|
||||||
# 构建查询
|
|
||||||
query = select(ops[0].model_class)
|
|
||||||
for field_name, value in conditions.items():
|
|
||||||
model_attr = getattr(ops[0].model_class, field_name)
|
|
||||||
if isinstance(value, list | tuple | set):
|
|
||||||
query = query.where(model_attr.in_(value))
|
|
||||||
else:
|
|
||||||
query = query.where(model_attr == value)
|
|
||||||
|
|
||||||
# 执行查询
|
|
||||||
result = await session.execute(query)
|
|
||||||
data = result.scalars().all()
|
|
||||||
|
|
||||||
# 分发结果到各个操作
|
|
||||||
for op in ops:
|
|
||||||
if len(conditions) == 1 and len(ops) == 1:
|
|
||||||
# 单个查询,直接返回所有结果
|
|
||||||
op_result = data
|
|
||||||
else:
|
|
||||||
# 需要根据条件过滤结果
|
|
||||||
op_result = [
|
|
||||||
item
|
|
||||||
for item in data
|
|
||||||
if all(getattr(item, k) == v for k, v in op.conditions.items() if hasattr(item, k))
|
|
||||||
]
|
|
||||||
results.append(op_result)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量查询失败: {e}", exc_info=True)
|
|
||||||
results.append([])
|
|
||||||
|
|
||||||
return results if len(results) > 1 else results[0] if results else []
|
|
||||||
|
|
||||||
async def _execute_insert_batch(self, operations: list[BatchOperation]):
|
|
||||||
"""批量执行插入操作"""
|
|
||||||
async with get_db_session() as session:
|
|
||||||
try:
|
|
||||||
# 收集所有要插入的数据
|
|
||||||
all_data = [op.data for op in operations if op.data]
|
|
||||||
if not all_data:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# 批量插入
|
|
||||||
stmt = insert(operations[0].model_class).values(all_data)
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return [result.rowcount] * len(operations)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await session.rollback()
|
|
||||||
logger.error(f"批量插入失败: {e}", exc_info=True)
|
|
||||||
return [0] * len(operations)
|
|
||||||
|
|
||||||
async def _execute_update_batch(self, operations: list[BatchOperation]):
|
|
||||||
"""批量执行更新操作"""
|
|
||||||
async with get_db_session() as session:
|
|
||||||
try:
|
|
||||||
results = []
|
|
||||||
for op in operations:
|
|
||||||
if not op.data or not op.conditions:
|
|
||||||
results.append(0)
|
|
||||||
continue
|
|
||||||
|
|
||||||
stmt = update(op.model_class)
|
|
||||||
for field_name, value in op.conditions.items():
|
|
||||||
model_attr = getattr(op.model_class, field_name)
|
|
||||||
if isinstance(value, list | tuple | set):
|
|
||||||
stmt = stmt.where(model_attr.in_(value))
|
|
||||||
else:
|
|
||||||
stmt = stmt.where(model_attr == value)
|
|
||||||
|
|
||||||
stmt = stmt.values(**op.data)
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
results.append(result.rowcount)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await session.rollback()
|
|
||||||
logger.error(f"批量更新失败: {e}", exc_info=True)
|
|
||||||
return [0] * len(operations)
|
|
||||||
|
|
||||||
async def _execute_delete_batch(self, operations: list[BatchOperation]):
|
|
||||||
"""批量执行删除操作"""
|
|
||||||
async with get_db_session() as session:
|
|
||||||
try:
|
|
||||||
results = []
|
|
||||||
for op in operations:
|
|
||||||
if not op.conditions:
|
|
||||||
results.append(0)
|
|
||||||
continue
|
|
||||||
|
|
||||||
stmt = delete(op.model_class)
|
|
||||||
for field_name, value in op.conditions.items():
|
|
||||||
model_attr = getattr(op.model_class, field_name)
|
|
||||||
if isinstance(value, list | tuple | set):
|
|
||||||
stmt = stmt.where(model_attr.in_(value))
|
|
||||||
else:
|
|
||||||
stmt = stmt.where(model_attr == value)
|
|
||||||
|
|
||||||
result = await session.execute(stmt)
|
|
||||||
results.append(result.rowcount)
|
|
||||||
|
|
||||||
await session.commit()
|
|
||||||
return results
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await session.rollback()
|
|
||||||
logger.error(f"批量删除失败: {e}", exc_info=True)
|
|
||||||
return [0] * len(operations)
|
|
||||||
|
|
||||||
def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]:
|
|
||||||
"""合并相似的查询条件"""
|
|
||||||
merged = {}
|
|
||||||
|
|
||||||
for op in operations:
|
|
||||||
# 生成条件键
|
|
||||||
condition_key = tuple(sorted(op.conditions.keys()))
|
|
||||||
|
|
||||||
if condition_key not in merged:
|
|
||||||
merged[condition_key] = {}
|
|
||||||
|
|
||||||
# 尝试合并相同字段的值
|
|
||||||
for field_name, value in op.conditions.items():
|
|
||||||
if field_name not in merged[condition_key]:
|
|
||||||
merged[condition_key][field_name] = []
|
|
||||||
|
|
||||||
if isinstance(value, list | tuple | set):
|
|
||||||
merged[condition_key][field_name].extend(value)
|
|
||||||
else:
|
|
||||||
merged[condition_key][field_name].append(value)
|
|
||||||
|
|
||||||
# 记录操作
|
|
||||||
if condition_key not in merged:
|
|
||||||
merged[condition_key] = {"_operations": []}
|
|
||||||
if "_operations" not in merged[condition_key]:
|
|
||||||
merged[condition_key]["_operations"] = []
|
|
||||||
merged[condition_key]["_operations"].append(op)
|
|
||||||
|
|
||||||
# 去重并构建最终条件
|
|
||||||
final_merged = {}
|
|
||||||
for condition_key, conditions in merged.items():
|
|
||||||
operations = conditions.pop("_operations")
|
|
||||||
|
|
||||||
# 去重
|
|
||||||
for field_name, values in conditions.items():
|
|
||||||
conditions[field_name] = list(set(values))
|
|
||||||
|
|
||||||
final_merged[condition_key] = operations
|
|
||||||
|
|
||||||
return final_merged
|
|
||||||
|
|
||||||
def get_stats(self) -> dict[str, Any]:
|
|
||||||
"""获取统计信息"""
|
|
||||||
return {
|
|
||||||
**self.stats,
|
|
||||||
"cache_size": len(self._result_cache),
|
|
||||||
"queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
|
|
||||||
"is_running": self._is_running,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# 全局数据库批量调度器实例
|
|
||||||
db_batch_scheduler = DatabaseBatchScheduler()
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_batch_session():
|
|
||||||
"""获取批量会话上下文管理器"""
|
|
||||||
if not db_batch_scheduler._is_running:
|
|
||||||
await db_batch_scheduler.start()
|
|
||||||
|
|
||||||
try:
|
|
||||||
yield db_batch_scheduler
|
|
||||||
finally:
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
# 便捷函数
|
|
||||||
async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
|
|
||||||
"""批量查询"""
|
|
||||||
operation = BatchOperation(operation_type="select", model_class=model_class, conditions=conditions)
|
|
||||||
return await db_batch_scheduler.add_operation(operation)
|
|
||||||
|
|
||||||
|
|
||||||
async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
|
|
||||||
"""批量插入"""
|
|
||||||
operation = BatchOperation(operation_type="insert", model_class=model_class, conditions={}, data=data)
|
|
||||||
return await db_batch_scheduler.add_operation(operation)
|
|
||||||
|
|
||||||
|
|
||||||
async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
|
|
||||||
"""批量更新"""
|
|
||||||
operation = BatchOperation(operation_type="update", model_class=model_class, conditions=conditions, data=data)
|
|
||||||
return await db_batch_scheduler.add_operation(operation)
|
|
||||||
|
|
||||||
|
|
||||||
async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
|
|
||||||
"""批量删除"""
|
|
||||||
operation = BatchOperation(operation_type="delete", model_class=model_class, conditions=conditions)
|
|
||||||
return await db_batch_scheduler.add_operation(operation)
|
|
||||||
|
|
||||||
|
|
||||||
def get_db_batch_scheduler() -> DatabaseBatchScheduler:
|
|
||||||
"""获取数据库批量调度器实例"""
|
|
||||||
return db_batch_scheduler
|
|
||||||
@@ -1,140 +0,0 @@
|
|||||||
# mmc/src/common/database/db_migration.py
|
|
||||||
|
|
||||||
from sqlalchemy import inspect
|
|
||||||
from sqlalchemy.sql import text
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import Base, get_engine
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("db_migration")
|
|
||||||
|
|
||||||
|
|
||||||
async def check_and_migrate_database(existing_engine=None):
|
|
||||||
"""
|
|
||||||
异步检查数据库结构并自动迁移。
|
|
||||||
- 自动创建不存在的表。
|
|
||||||
- 自动为现有表添加缺失的列。
|
|
||||||
- 自动为现有表创建缺失的索引。
|
|
||||||
|
|
||||||
Args:
|
|
||||||
existing_engine: 可选的已存在的数据库引擎。如果提供,将使用该引擎;否则获取全局引擎。
|
|
||||||
"""
|
|
||||||
logger.info("正在检查数据库结构并执行自动迁移...")
|
|
||||||
engine = existing_engine if existing_engine is not None else await get_engine()
|
|
||||||
|
|
||||||
async with engine.connect() as connection:
|
|
||||||
# 在同步上下文中运行inspector操作
|
|
||||||
def get_inspector(sync_conn):
|
|
||||||
return inspect(sync_conn)
|
|
||||||
|
|
||||||
inspector = await connection.run_sync(get_inspector)
|
|
||||||
|
|
||||||
# 在同步lambda中传递inspector
|
|
||||||
db_table_names = await connection.run_sync(lambda conn: set(inspector.get_table_names()))
|
|
||||||
|
|
||||||
# 1. 首先处理表的创建
|
|
||||||
tables_to_create = []
|
|
||||||
for table_name, table in Base.metadata.tables.items():
|
|
||||||
if table_name not in db_table_names:
|
|
||||||
tables_to_create.append(table)
|
|
||||||
|
|
||||||
if tables_to_create:
|
|
||||||
logger.info(f"发现 {len(tables_to_create)} 个不存在的表,正在创建...")
|
|
||||||
try:
|
|
||||||
# 一次性创建所有缺失的表
|
|
||||||
await connection.run_sync(
|
|
||||||
lambda sync_conn: Base.metadata.create_all(sync_conn, tables=tables_to_create)
|
|
||||||
)
|
|
||||||
for table in tables_to_create:
|
|
||||||
logger.info(f"表 '{table.name}' 创建成功。")
|
|
||||||
db_table_names.add(table.name) # 将新创建的表添加到集合中
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建表时失败: {e}", exc_info=True)
|
|
||||||
|
|
||||||
# 2. 然后处理现有表的列和索引的添加
|
|
||||||
for table_name, table in Base.metadata.tables.items():
|
|
||||||
if table_name not in db_table_names:
|
|
||||||
logger.warning(f"跳过检查表 '{table_name}',因为它在创建步骤中可能已失败。")
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.debug(f"正在检查表 '{table_name}' 的列和索引...")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 检查并添加缺失的列
|
|
||||||
db_columns = await connection.run_sync(
|
|
||||||
lambda conn: {col["name"] for col in inspector.get_columns(table_name)}
|
|
||||||
)
|
|
||||||
model_columns = {col.name for col in table.c}
|
|
||||||
missing_columns = model_columns - db_columns
|
|
||||||
|
|
||||||
if missing_columns:
|
|
||||||
logger.info(f"在表 '{table_name}' 中发现缺失的列: {', '.join(missing_columns)}")
|
|
||||||
|
|
||||||
def add_columns_sync(conn):
|
|
||||||
dialect = conn.dialect
|
|
||||||
compiler = dialect.ddl_compiler(dialect, None)
|
|
||||||
|
|
||||||
for column_name in missing_columns:
|
|
||||||
column = table.c[column_name]
|
|
||||||
column_type = compiler.get_column_specification(column)
|
|
||||||
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
|
|
||||||
|
|
||||||
if column.default:
|
|
||||||
# 手动处理不同方言的默认值
|
|
||||||
default_arg = column.default.arg
|
|
||||||
if dialect.name == "sqlite" and isinstance(default_arg, bool):
|
|
||||||
# SQLite 将布尔值存储为 0 或 1
|
|
||||||
default_value = "1" if default_arg else "0"
|
|
||||||
elif hasattr(compiler, "render_literal_value"):
|
|
||||||
try:
|
|
||||||
# 尝试使用 render_literal_value
|
|
||||||
default_value = compiler.render_literal_value(default_arg, column.type)
|
|
||||||
except AttributeError:
|
|
||||||
# 如果失败,则回退到简单的字符串转换
|
|
||||||
default_value = (
|
|
||||||
f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 对于没有 render_literal_value 的旧版或特定方言
|
|
||||||
default_value = (
|
|
||||||
f"'{default_arg}'" if isinstance(default_arg, str) else str(default_arg)
|
|
||||||
)
|
|
||||||
|
|
||||||
sql += f" DEFAULT {default_value}"
|
|
||||||
|
|
||||||
if not column.nullable:
|
|
||||||
sql += " NOT NULL"
|
|
||||||
|
|
||||||
conn.execute(text(sql))
|
|
||||||
logger.info(f"成功向表 '{table_name}' 添加列 '{column_name}'。")
|
|
||||||
|
|
||||||
await connection.run_sync(add_columns_sync)
|
|
||||||
else:
|
|
||||||
logger.info(f"表 '{table_name}' 的列结构一致。")
|
|
||||||
|
|
||||||
# 检查并创建缺失的索引
|
|
||||||
db_indexes = await connection.run_sync(
|
|
||||||
lambda conn: {idx["name"] for idx in inspector.get_indexes(table_name)}
|
|
||||||
)
|
|
||||||
model_indexes = {idx.name for idx in table.indexes}
|
|
||||||
missing_indexes = model_indexes - db_indexes
|
|
||||||
|
|
||||||
if missing_indexes:
|
|
||||||
logger.info(f"在表 '{table_name}' 中发现缺失的索引: {', '.join(missing_indexes)}")
|
|
||||||
|
|
||||||
def add_indexes_sync(conn):
|
|
||||||
for index_name in missing_indexes:
|
|
||||||
index_obj = next((idx for idx in table.indexes if idx.name == index_name), None)
|
|
||||||
if index_obj is not None:
|
|
||||||
index_obj.create(conn)
|
|
||||||
logger.info(f"成功为表 '{table_name}' 创建索引 '{index_name}'。")
|
|
||||||
|
|
||||||
await connection.run_sync(add_indexes_sync)
|
|
||||||
else:
|
|
||||||
logger.debug(f"表 '{table_name}' 的索引一致。")
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"在处理表 '{table_name}' 时发生意外错误: {e}", exc_info=True)
|
|
||||||
continue
|
|
||||||
|
|
||||||
logger.info("数据库结构检查与自动迁移完成。")
|
|
||||||
@@ -1,426 +0,0 @@
|
|||||||
"""SQLAlchemy数据库API模块
|
|
||||||
|
|
||||||
提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题
|
|
||||||
支持自动重连、连接池管理和更好的错误处理
|
|
||||||
"""
|
|
||||||
|
|
||||||
import time
|
|
||||||
import traceback
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import and_, asc, desc, func, select
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import (
|
|
||||||
ActionRecords,
|
|
||||||
CacheEntries,
|
|
||||||
ChatStreams,
|
|
||||||
Emoji,
|
|
||||||
Expression,
|
|
||||||
GraphEdges,
|
|
||||||
GraphNodes,
|
|
||||||
ImageDescriptions,
|
|
||||||
Images,
|
|
||||||
LLMUsage,
|
|
||||||
MaiZoneScheduleStatus,
|
|
||||||
Memory,
|
|
||||||
Messages,
|
|
||||||
OnlineTime,
|
|
||||||
PersonInfo,
|
|
||||||
Schedule,
|
|
||||||
ThinkingLog,
|
|
||||||
UserRelationships,
|
|
||||||
get_db_session,
|
|
||||||
)
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_database_api")
|
|
||||||
|
|
||||||
# 模型映射表,用于通过名称获取模型类
|
|
||||||
MODEL_MAPPING = {
|
|
||||||
"Messages": Messages,
|
|
||||||
"ActionRecords": ActionRecords,
|
|
||||||
"PersonInfo": PersonInfo,
|
|
||||||
"ChatStreams": ChatStreams,
|
|
||||||
"LLMUsage": LLMUsage,
|
|
||||||
"Emoji": Emoji,
|
|
||||||
"Images": Images,
|
|
||||||
"ImageDescriptions": ImageDescriptions,
|
|
||||||
"OnlineTime": OnlineTime,
|
|
||||||
"Memory": Memory,
|
|
||||||
"Expression": Expression,
|
|
||||||
"ThinkingLog": ThinkingLog,
|
|
||||||
"GraphNodes": GraphNodes,
|
|
||||||
"GraphEdges": GraphEdges,
|
|
||||||
"Schedule": Schedule,
|
|
||||||
"MaiZoneScheduleStatus": MaiZoneScheduleStatus,
|
|
||||||
"CacheEntries": CacheEntries,
|
|
||||||
"UserRelationships": UserRelationships,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def build_filters(model_class, filters: dict[str, Any]):
|
|
||||||
"""构建查询过滤条件"""
|
|
||||||
conditions = []
|
|
||||||
|
|
||||||
for field_name, value in filters.items():
|
|
||||||
if not hasattr(model_class, field_name):
|
|
||||||
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
|
|
||||||
continue
|
|
||||||
|
|
||||||
field = getattr(model_class, field_name)
|
|
||||||
|
|
||||||
if isinstance(value, dict):
|
|
||||||
# 处理 MongoDB 风格的操作符
|
|
||||||
for op, op_value in value.items():
|
|
||||||
if op == "$gt":
|
|
||||||
conditions.append(field > op_value)
|
|
||||||
elif op == "$lt":
|
|
||||||
conditions.append(field < op_value)
|
|
||||||
elif op == "$gte":
|
|
||||||
conditions.append(field >= op_value)
|
|
||||||
elif op == "$lte":
|
|
||||||
conditions.append(field <= op_value)
|
|
||||||
elif op == "$ne":
|
|
||||||
conditions.append(field != op_value)
|
|
||||||
elif op == "$in":
|
|
||||||
conditions.append(field.in_(op_value))
|
|
||||||
elif op == "$nin":
|
|
||||||
conditions.append(~field.in_(op_value))
|
|
||||||
else:
|
|
||||||
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
|
|
||||||
else:
|
|
||||||
# 直接相等比较
|
|
||||||
conditions.append(field == value)
|
|
||||||
|
|
||||||
return conditions
|
|
||||||
|
|
||||||
|
|
||||||
async def db_query(
|
|
||||||
model_class,
|
|
||||||
data: dict[str, Any] | None = None,
|
|
||||||
query_type: str | None = "get",
|
|
||||||
filters: dict[str, Any] | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
order_by: list[str] | None = None,
|
|
||||||
single_result: bool | None = False,
|
|
||||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
|
||||||
"""执行异步数据库查询操作
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_class: SQLAlchemy模型类
|
|
||||||
data: 用于创建或更新的数据字典
|
|
||||||
query_type: 查询类型 ("get", "create", "update", "delete", "count")
|
|
||||||
filters: 过滤条件字典
|
|
||||||
limit: 限制结果数量
|
|
||||||
order_by: 排序字段,前缀'-'表示降序
|
|
||||||
single_result: 是否只返回单个结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
根据查询类型返回相应结果
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
|
||||||
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
|
|
||||||
|
|
||||||
async with get_db_session() as session:
|
|
||||||
if not session:
|
|
||||||
logger.error("[SQLAlchemy] 无法获取数据库会话")
|
|
||||||
return None if single_result else []
|
|
||||||
|
|
||||||
if query_type == "get":
|
|
||||||
query = select(model_class)
|
|
||||||
|
|
||||||
# 应用过滤条件
|
|
||||||
if filters:
|
|
||||||
conditions = await build_filters(model_class, filters)
|
|
||||||
if conditions:
|
|
||||||
query = query.where(and_(*conditions))
|
|
||||||
|
|
||||||
# 应用排序
|
|
||||||
if order_by:
|
|
||||||
for field_name in order_by:
|
|
||||||
if field_name.startswith("-"):
|
|
||||||
field_name = field_name[1:]
|
|
||||||
if hasattr(model_class, field_name):
|
|
||||||
query = query.order_by(desc(getattr(model_class, field_name)))
|
|
||||||
else:
|
|
||||||
if hasattr(model_class, field_name):
|
|
||||||
query = query.order_by(asc(getattr(model_class, field_name)))
|
|
||||||
|
|
||||||
# 应用限制
|
|
||||||
if limit and limit > 0:
|
|
||||||
query = query.limit(limit)
|
|
||||||
|
|
||||||
# 执行查询
|
|
||||||
result = await session.execute(query)
|
|
||||||
results = result.scalars().all()
|
|
||||||
|
|
||||||
# 转换为字典格式
|
|
||||||
result_dicts = []
|
|
||||||
for result_obj in results:
|
|
||||||
result_dict = {}
|
|
||||||
for column in result_obj.__table__.columns:
|
|
||||||
result_dict[column.name] = getattr(result_obj, column.name)
|
|
||||||
result_dicts.append(result_dict)
|
|
||||||
|
|
||||||
if single_result:
|
|
||||||
return result_dicts[0] if result_dicts else None
|
|
||||||
return result_dicts
|
|
||||||
|
|
||||||
elif query_type == "create":
|
|
||||||
if not data:
|
|
||||||
raise ValueError("创建记录需要提供data参数")
|
|
||||||
|
|
||||||
# 创建新记录
|
|
||||||
new_record = model_class(**data)
|
|
||||||
session.add(new_record)
|
|
||||||
await session.flush() # 获取自动生成的ID
|
|
||||||
|
|
||||||
# 转换为字典格式返回
|
|
||||||
result_dict = {}
|
|
||||||
for column in new_record.__table__.columns:
|
|
||||||
result_dict[column.name] = getattr(new_record, column.name)
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
elif query_type == "update":
|
|
||||||
if not data:
|
|
||||||
raise ValueError("更新记录需要提供data参数")
|
|
||||||
|
|
||||||
query = select(model_class)
|
|
||||||
|
|
||||||
# 应用过滤条件
|
|
||||||
if filters:
|
|
||||||
conditions = await build_filters(model_class, filters)
|
|
||||||
if conditions:
|
|
||||||
query = query.where(and_(*conditions))
|
|
||||||
|
|
||||||
# 首先获取要更新的记录
|
|
||||||
result = await session.execute(query)
|
|
||||||
records_to_update = result.scalars().all()
|
|
||||||
|
|
||||||
# 更新每个记录
|
|
||||||
affected_rows = 0
|
|
||||||
for record in records_to_update:
|
|
||||||
for field, value in data.items():
|
|
||||||
if hasattr(record, field):
|
|
||||||
setattr(record, field, value)
|
|
||||||
affected_rows += 1
|
|
||||||
|
|
||||||
return affected_rows
|
|
||||||
|
|
||||||
elif query_type == "delete":
|
|
||||||
query = select(model_class)
|
|
||||||
|
|
||||||
# 应用过滤条件
|
|
||||||
if filters:
|
|
||||||
conditions = await build_filters(model_class, filters)
|
|
||||||
if conditions:
|
|
||||||
query = query.where(and_(*conditions))
|
|
||||||
|
|
||||||
# 首先获取要删除的记录
|
|
||||||
result = await session.execute(query)
|
|
||||||
records_to_delete = result.scalars().all()
|
|
||||||
|
|
||||||
# 删除记录
|
|
||||||
affected_rows = 0
|
|
||||||
for record in records_to_delete:
|
|
||||||
await session.delete(record)
|
|
||||||
affected_rows += 1
|
|
||||||
|
|
||||||
return affected_rows
|
|
||||||
|
|
||||||
elif query_type == "count":
|
|
||||||
query = select(func.count(model_class.id))
|
|
||||||
|
|
||||||
# 应用过滤条件
|
|
||||||
if filters:
|
|
||||||
conditions = await build_filters(model_class, filters)
|
|
||||||
if conditions:
|
|
||||||
query = query.where(and_(*conditions))
|
|
||||||
|
|
||||||
result = await session.execute(query)
|
|
||||||
return result.scalar()
|
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
# 根据查询类型返回合适的默认值
|
|
||||||
if query_type == "get":
|
|
||||||
return None if single_result else []
|
|
||||||
elif query_type in ["create", "update", "delete", "count"]:
|
|
||||||
return None
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SQLAlchemy] 意外错误: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
|
|
||||||
if query_type == "get":
|
|
||||||
return None if single_result else []
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def db_save(
|
|
||||||
model_class, data: dict[str, Any], key_field: str | None = None, key_value: Any | None = None
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""异步保存数据到数据库(创建或更新)
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_class: SQLAlchemy模型类
|
|
||||||
data: 要保存的数据字典
|
|
||||||
key_field: 用于查找现有记录的字段名
|
|
||||||
key_value: 用于查找现有记录的字段值
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
保存后的记录数据或None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
async with get_db_session() as session:
|
|
||||||
if not session:
|
|
||||||
logger.error("[SQLAlchemy] 无法获取数据库会话")
|
|
||||||
return None
|
|
||||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
|
||||||
if key_field and key_value is not None:
|
|
||||||
if hasattr(model_class, key_field):
|
|
||||||
query = select(model_class).where(getattr(model_class, key_field) == key_value)
|
|
||||||
result = await session.execute(query)
|
|
||||||
existing_record = result.scalars().first()
|
|
||||||
|
|
||||||
if existing_record:
|
|
||||||
# 更新现有记录
|
|
||||||
for field, value in data.items():
|
|
||||||
if hasattr(existing_record, field):
|
|
||||||
setattr(existing_record, field, value)
|
|
||||||
|
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
# 转换为字典格式返回
|
|
||||||
result_dict = {}
|
|
||||||
for column in existing_record.__table__.columns:
|
|
||||||
result_dict[column.name] = getattr(existing_record, column.name)
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
# 创建新记录
|
|
||||||
new_record = model_class(**data)
|
|
||||||
session.add(new_record)
|
|
||||||
await session.flush()
|
|
||||||
|
|
||||||
# 转换为字典格式返回
|
|
||||||
result_dict = {}
|
|
||||||
for column in new_record.__table__.columns:
|
|
||||||
result_dict[column.name] = getattr(new_record, column.name)
|
|
||||||
return result_dict
|
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def db_get(
|
|
||||||
model_class,
|
|
||||||
filters: dict[str, Any] | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
order_by: str | None = None,
|
|
||||||
single_result: bool | None = False,
|
|
||||||
) -> list[dict[str, Any]] | dict[str, Any] | None:
|
|
||||||
"""异步从数据库获取记录
|
|
||||||
|
|
||||||
Args:
|
|
||||||
model_class: SQLAlchemy模型类
|
|
||||||
filters: 过滤条件
|
|
||||||
limit: 结果数量限制
|
|
||||||
order_by: 排序字段,前缀'-'表示降序
|
|
||||||
single_result: 是否只返回单个结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
记录数据或None
|
|
||||||
"""
|
|
||||||
order_by_list = [order_by] if order_by else None
|
|
||||||
return await db_query(
|
|
||||||
model_class=model_class,
|
|
||||||
query_type="get",
|
|
||||||
filters=filters,
|
|
||||||
limit=limit,
|
|
||||||
order_by=order_by_list,
|
|
||||||
single_result=single_result,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def store_action_info(
|
|
||||||
chat_stream=None,
|
|
||||||
action_build_into_prompt: bool = False,
|
|
||||||
action_prompt_display: str = "",
|
|
||||||
action_done: bool = True,
|
|
||||||
thinking_id: str = "",
|
|
||||||
action_data: dict | None = None,
|
|
||||||
action_name: str = "",
|
|
||||||
) -> dict[str, Any] | None:
|
|
||||||
"""异步存储动作信息到数据库
|
|
||||||
|
|
||||||
Args:
|
|
||||||
chat_stream: 聊天流对象
|
|
||||||
action_build_into_prompt: 是否将此动作构建到提示中
|
|
||||||
action_prompt_display: 动作的提示显示文本
|
|
||||||
action_done: 动作是否完成
|
|
||||||
thinking_id: 关联的思考ID
|
|
||||||
action_data: 动作数据字典
|
|
||||||
action_name: 动作名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
保存的记录数据或None
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
# 构建动作记录数据
|
|
||||||
record_data = {
|
|
||||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
|
||||||
"time": time.time(),
|
|
||||||
"action_name": action_name,
|
|
||||||
"action_data": orjson.dumps(action_data or {}).decode("utf-8"),
|
|
||||||
"action_done": action_done,
|
|
||||||
"action_build_into_prompt": action_build_into_prompt,
|
|
||||||
"action_prompt_display": action_prompt_display,
|
|
||||||
}
|
|
||||||
|
|
||||||
# 从chat_stream获取聊天信息
|
|
||||||
if chat_stream:
|
|
||||||
record_data.update(
|
|
||||||
{
|
|
||||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
|
||||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
|
||||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
record_data.update(
|
|
||||||
{
|
|
||||||
"chat_id": "",
|
|
||||||
"chat_info_stream_id": "",
|
|
||||||
"chat_info_platform": "",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
# 保存记录
|
|
||||||
saved_record = await db_save(
|
|
||||||
ActionRecords, data=record_data, key_field="action_id", key_value=record_data["action_id"]
|
|
||||||
)
|
|
||||||
|
|
||||||
if saved_record:
|
|
||||||
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
|
||||||
else:
|
|
||||||
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
|
|
||||||
|
|
||||||
return saved_record
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
|
||||||
traceback.print_exc()
|
|
||||||
return None
|
|
||||||
@@ -1,124 +0,0 @@
|
|||||||
"""SQLAlchemy数据库初始化模块
|
|
||||||
|
|
||||||
替换Peewee的数据库初始化逻辑
|
|
||||||
提供统一的异步数据库初始化接口
|
|
||||||
"""
|
|
||||||
|
|
||||||
from sqlalchemy.exc import SQLAlchemyError
|
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import Base, get_engine, initialize_database
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_init")
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_sqlalchemy_database() -> bool:
|
|
||||||
"""
|
|
||||||
初始化SQLAlchemy异步数据库
|
|
||||||
创建所有表结构
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 初始化是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("开始初始化SQLAlchemy异步数据库...")
|
|
||||||
|
|
||||||
# 初始化数据库引擎和会话
|
|
||||||
engine, session_local = await initialize_database()
|
|
||||||
|
|
||||||
if engine is None:
|
|
||||||
logger.error("数据库引擎初始化失败")
|
|
||||||
return False
|
|
||||||
|
|
||||||
logger.info("SQLAlchemy异步数据库初始化成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库初始化过程中发生未知错误: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def create_all_tables() -> bool:
|
|
||||||
"""
|
|
||||||
异步创建所有数据库表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 创建是否成功
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("开始创建数据库表...")
|
|
||||||
|
|
||||||
engine = await get_engine()
|
|
||||||
if engine is None:
|
|
||||||
logger.error("无法获取数据库引擎")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 异步创建所有表
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
await conn.run_sync(Base.metadata.create_all)
|
|
||||||
|
|
||||||
logger.info("数据库表创建成功")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except SQLAlchemyError as e:
|
|
||||||
logger.error(f"创建数据库表失败: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建数据库表过程中发生未知错误: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
async def get_database_info() -> dict | None:
|
|
||||||
"""
|
|
||||||
异步获取数据库信息
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: 数据库信息字典,包含引擎信息等
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
engine = await get_engine()
|
|
||||||
if engine is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
info = {
|
|
||||||
"engine_name": engine.name,
|
|
||||||
"driver": engine.driver,
|
|
||||||
"url": str(engine.url).replace(engine.url.password or "", "***"), # 隐藏密码
|
|
||||||
"pool_size": getattr(engine.pool, "size", None),
|
|
||||||
"max_overflow": getattr(engine.pool, "max_overflow", None),
|
|
||||||
}
|
|
||||||
|
|
||||||
return info
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"获取数据库信息失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
_database_initialized = False
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_database_compat() -> bool:
|
|
||||||
"""
|
|
||||||
兼容性异步数据库初始化函数
|
|
||||||
用于替换原有的Peewee初始化代码
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 初始化是否成功
|
|
||||||
"""
|
|
||||||
global _database_initialized
|
|
||||||
|
|
||||||
if _database_initialized:
|
|
||||||
return True
|
|
||||||
|
|
||||||
success = await initialize_sqlalchemy_database()
|
|
||||||
if success:
|
|
||||||
success = await create_all_tables()
|
|
||||||
|
|
||||||
if success:
|
|
||||||
_database_initialized = True
|
|
||||||
|
|
||||||
return success
|
|
||||||
@@ -1,892 +0,0 @@
|
|||||||
"""SQLAlchemy数据库模型定义
|
|
||||||
|
|
||||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
|
||||||
|
|
||||||
说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到
|
|
||||||
SQLAlchemy 2.0 推荐的带类型注解的声明式风格:
|
|
||||||
|
|
||||||
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
|
||||||
|
|
||||||
这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。
|
|
||||||
当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text, text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from src.common.database.connection_pool_manager import get_connection_pool_manager
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_models")
|
|
||||||
|
|
||||||
# 创建基类
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
# 全局异步引擎与会话工厂占位(延迟初始化)
|
|
||||||
_engine: AsyncEngine | None = None
|
|
||||||
_SessionLocal: async_sessionmaker[AsyncSession] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
async def enable_sqlite_wal_mode(engine):
|
|
||||||
"""为 SQLite 启用 WAL 模式以提高并发性能"""
|
|
||||||
try:
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
# 启用 WAL 模式
|
|
||||||
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
|
||||||
# 设置适中的同步级别,平衡性能和安全性
|
|
||||||
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
||||||
# 启用外键约束
|
|
||||||
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
# 设置 busy_timeout,避免锁定错误
|
|
||||||
await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒
|
|
||||||
|
|
||||||
logger.info("[SQLite] WAL 模式已启用,并发性能已优化")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置")
|
|
||||||
|
|
||||||
|
|
||||||
async def maintain_sqlite_database():
|
|
||||||
"""定期维护 SQLite 数据库性能"""
|
|
||||||
try:
|
|
||||||
engine, SessionLocal = await initialize_database()
|
|
||||||
if not engine:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
# 检查并确保 WAL 模式仍然启用
|
|
||||||
result = await conn.execute(text("PRAGMA journal_mode"))
|
|
||||||
journal_mode = result.scalar()
|
|
||||||
|
|
||||||
if journal_mode != "wal":
|
|
||||||
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
|
||||||
logger.info("[SQLite] WAL 模式已重新启用")
|
|
||||||
|
|
||||||
# 优化数据库性能
|
|
||||||
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
||||||
await conn.execute(text("PRAGMA busy_timeout = 60000"))
|
|
||||||
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
|
|
||||||
# 定期清理(可选,根据需要启用)
|
|
||||||
# await conn.execute(text("PRAGMA optimize"))
|
|
||||||
|
|
||||||
logger.info("[SQLite] 数据库维护完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SQLite] 数据库维护失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_sqlite_performance_config():
|
|
||||||
"""获取 SQLite 性能优化配置"""
|
|
||||||
return {
|
|
||||||
"journal_mode": "WAL", # 提高并发性能
|
|
||||||
"synchronous": "NORMAL", # 平衡性能和安全性
|
|
||||||
"busy_timeout": 60000, # 60秒超时
|
|
||||||
"foreign_keys": "ON", # 启用外键约束
|
|
||||||
"cache_size": -10000, # 10MB 缓存
|
|
||||||
"temp_store": "MEMORY", # 临时存储使用内存
|
|
||||||
"mmap_size": 268435456, # 256MB 内存映射
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# MySQL兼容的字段类型辅助函数
|
|
||||||
def get_string_field(max_length=255, **kwargs):
|
|
||||||
"""
|
|
||||||
根据数据库类型返回合适的字符串字段
|
|
||||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
|
||||||
"""
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
if global_config.database.database_type == "mysql":
|
|
||||||
return String(max_length, **kwargs)
|
|
||||||
else:
|
|
||||||
return Text(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatStreams(Base):
|
|
||||||
"""聊天流模型"""
|
|
||||||
|
|
||||||
__tablename__ = "chat_streams"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
stream_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, unique=True, index=True)
|
|
||||||
create_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
group_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
|
|
||||||
group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
energy_value: Mapped[float | None] = mapped_column(Float, nullable=True, default=5.0)
|
|
||||||
sleep_pressure: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
|
|
||||||
focus_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
|
|
||||||
# 动态兴趣度系统字段
|
|
||||||
base_interest_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
|
|
||||||
message_interest_total: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
|
|
||||||
message_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
|
||||||
action_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
|
||||||
reply_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
|
||||||
last_interaction_time: Mapped[float | None] = mapped_column(Float, nullable=True, default=None)
|
|
||||||
consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
|
||||||
# 消息打断系统字段
|
|
||||||
interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
|
||||||
# 聊天流印象字段
|
|
||||||
stream_impression_text: Mapped[str | None] = mapped_column(Text, nullable=True) # 对聊天流的主观印象描述
|
|
||||||
stream_chat_style: Mapped[str | None] = mapped_column(Text, nullable=True) # 聊天流的总体风格
|
|
||||||
stream_topic_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 话题关键词,逗号分隔
|
|
||||||
stream_interest_score: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 对聊天流的兴趣程度(0-1)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
|
||||||
Index("idx_chatstreams_user_id", "user_id"),
|
|
||||||
Index("idx_chatstreams_group_id", "group_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMUsage(Base):
|
|
||||||
"""LLM使用记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "llm_usage"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
model_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
model_assign_name: Mapped[str] = mapped_column(get_string_field(100), index=True)
|
|
||||||
model_api_provider: Mapped[str] = mapped_column(get_string_field(100), index=True)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
request_type: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
endpoint: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
time_cost: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
total_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
cost: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
status: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_llmusage_model_name", "model_name"),
|
|
||||||
Index("idx_llmusage_model_assign_name", "model_assign_name"),
|
|
||||||
Index("idx_llmusage_model_api_provider", "model_api_provider"),
|
|
||||||
Index("idx_llmusage_time_cost", "time_cost"),
|
|
||||||
Index("idx_llmusage_user_id", "user_id"),
|
|
||||||
Index("idx_llmusage_request_type", "request_type"),
|
|
||||||
Index("idx_llmusage_timestamp", "timestamp"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Emoji(Base):
|
|
||||||
"""表情包模型"""
|
|
||||||
|
|
||||||
__tablename__ = "emoji"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
full_path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
|
|
||||||
format: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
query_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
is_registered: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
is_banned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
emotion: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
record_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
register_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
last_used_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_emoji_full_path", "full_path"),
|
|
||||||
Index("idx_emoji_hash", "emoji_hash"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Messages(Base):
|
|
||||||
"""消息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "messages"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
message_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
reply_to: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
interest_value: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
key_words: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
key_words_lite: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
is_mentioned: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
|
||||||
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_user_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_user_id: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_info_group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_info_group_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_info_group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_info_create_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
chat_info_last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
|
|
||||||
# 从顶层 user_info 扁平化而来的字段
|
|
||||||
user_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
user_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
|
|
||||||
user_nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
|
|
||||||
processed_plain_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
display_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
memorized_times: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
priority_mode: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
priority_info: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
additional_config: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
is_emoji: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
is_picid: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
is_command: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
is_notify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
is_public_notice: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
notice_type: Mapped[str | None] = mapped_column(String(50), nullable=True)
|
|
||||||
|
|
||||||
# 兴趣度系统字段
|
|
||||||
actions: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
should_reply: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
|
|
||||||
should_act: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_messages_message_id", "message_id"),
|
|
||||||
Index("idx_messages_chat_id", "chat_id"),
|
|
||||||
Index("idx_messages_time", "time"),
|
|
||||||
Index("idx_messages_user_id", "user_id"),
|
|
||||||
Index("idx_messages_should_reply", "should_reply"),
|
|
||||||
Index("idx_messages_should_act", "should_act"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionRecords(Base):
|
|
||||||
"""动作记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "action_records"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
action_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
action_name: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
action_data: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
action_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
action_build_into_prompt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
action_prompt_display: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_actionrecords_action_id", "action_id"),
|
|
||||||
Index("idx_actionrecords_chat_id", "chat_id"),
|
|
||||||
Index("idx_actionrecords_time", "time"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Images(Base):
|
|
||||||
"""图像信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "images"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
image_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True)
|
|
||||||
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
|
||||||
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_images_emoji_hash", "emoji_hash"),
|
|
||||||
Index("idx_images_path", "path"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDescriptions(Base):
|
|
||||||
"""图像描述信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "image_descriptions"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
image_description_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
|
||||||
|
|
||||||
|
|
||||||
class Videos(Base):
|
|
||||||
"""视频信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "videos"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
video_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
video_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True, unique=True)
|
|
||||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
|
||||||
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
# 视频特有属性
|
|
||||||
duration: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
frame_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
||||||
fps: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
resolution: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
file_size: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_videos_video_hash", "video_hash"),
|
|
||||||
Index("idx_videos_timestamp", "timestamp"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OnlineTime(Base):
|
|
||||||
"""在线时长记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "online_time"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
timestamp: Mapped[str] = mapped_column(Text, nullable=False, default=str(datetime.datetime.now))
|
|
||||||
duration: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
start_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
end_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
|
||||||
|
|
||||||
|
|
||||||
class PersonInfo(Base):
|
|
||||||
"""人物信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "person_info"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
person_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
|
|
||||||
person_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
name_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
impression: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
short_impression: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
points: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
forgotten_points: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
info_list: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
know_times: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
know_since: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
last_know: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
attitude: Mapped[int | None] = mapped_column(Integer, nullable=True, default=50)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_personinfo_person_id", "person_id"),
|
|
||||||
Index("idx_personinfo_user_id", "user_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BotPersonalityInterests(Base):
|
|
||||||
"""机器人人格兴趣标签模型"""
|
|
||||||
|
|
||||||
__tablename__ = "bot_personality_interests"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
personality_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
personality_description: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
interest_tags: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
embedding_model: Mapped[str] = mapped_column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
|
||||||
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
|
||||||
last_updated: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_botpersonality_personality_id", "personality_id"),
|
|
||||||
Index("idx_botpersonality_version", "version"),
|
|
||||||
Index("idx_botpersonality_last_updated", "last_updated"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Memory(Base):
|
|
||||||
"""记忆模型"""
|
|
||||||
|
|
||||||
__tablename__ = "memory"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
memory_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
chat_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
memory_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
keywords: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
create_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
last_view_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class Expression(Base):
|
|
||||||
"""表达风格模型"""
|
|
||||||
|
|
||||||
__tablename__ = "expression"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
situation: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
style: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
count: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class ThinkingLog(Base):
|
|
||||||
"""思考日志模型"""
|
|
||||||
|
|
||||||
__tablename__ = "thinking_logs"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
trigger_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
trigger_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
response_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
timing_results_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_history_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_history_in_thinking_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
chat_history_after_response_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
heartflow_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
reasoning_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphNodes(Base):
|
|
||||||
"""记忆图节点模型"""
|
|
||||||
|
|
||||||
__tablename__ = "graph_nodes"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
concept: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
|
|
||||||
memory_items: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
hash: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
weight: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
|
|
||||||
created_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphEdges(Base):
|
|
||||||
"""记忆图边模型"""
|
|
||||||
|
|
||||||
__tablename__ = "graph_edges"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
source: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
|
||||||
target: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
|
||||||
strength: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
hash: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
created_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_graphedges_source", "source"),
|
|
||||||
Index("idx_graphedges_target", "target"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Schedule(Base):
|
|
||||||
"""日程模型"""
|
|
||||||
|
|
||||||
__tablename__ = "schedule"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
date: Mapped[str] = mapped_column(get_string_field(10), nullable=False, unique=True, index=True)
|
|
||||||
schedule_data: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_schedule_date", "date"),)
|
|
||||||
|
|
||||||
|
|
||||||
class MaiZoneScheduleStatus(Base):
|
|
||||||
"""麦麦空间日程处理状态模型"""
|
|
||||||
|
|
||||||
__tablename__ = "maizone_schedule_status"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
datetime_hour: Mapped[str] = mapped_column(get_string_field(13), nullable=False, unique=True, index=True)
|
|
||||||
activity: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
is_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
processed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True)
|
|
||||||
story_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
send_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
|
||||||
Index("idx_maizone_is_processed", "is_processed"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BanUser(Base):
|
|
||||||
"""被禁用用户模型
|
|
||||||
|
|
||||||
使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型,
|
|
||||||
避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "ban_users"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
|
||||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_violation_num", "violation_num"),
|
|
||||||
Index("idx_banuser_user_id", "user_id"),
|
|
||||||
Index("idx_banuser_platform", "platform"),
|
|
||||||
Index("idx_banuser_platform_user_id", "platform", "user_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AntiInjectionStats(Base):
|
|
||||||
"""反注入系统统计模型"""
|
|
||||||
|
|
||||||
__tablename__ = "anti_injection_stats"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""总处理消息数"""
|
|
||||||
|
|
||||||
detected_injections: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""检测到的注入攻击数"""
|
|
||||||
|
|
||||||
blocked_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""被阻止的消息数"""
|
|
||||||
|
|
||||||
shielded_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""被加盾的消息数"""
|
|
||||||
|
|
||||||
processing_time_total: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
"""总处理时间"""
|
|
||||||
|
|
||||||
total_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
"""累计总处理时间"""
|
|
||||||
|
|
||||||
last_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
"""最近一次处理时间"""
|
|
||||||
|
|
||||||
error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""错误计数"""
|
|
||||||
|
|
||||||
start_time: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
"""统计开始时间"""
|
|
||||||
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
"""记录创建时间"""
|
|
||||||
|
|
||||||
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
"""记录更新时间"""
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_anti_injection_stats_created_at", "created_at"),
|
|
||||||
Index("idx_anti_injection_stats_updated_at", "updated_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CacheEntries(Base):
|
|
||||||
"""工具缓存条目模型"""
|
|
||||||
|
|
||||||
__tablename__ = "cache_entries"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
cache_key: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
|
|
||||||
"""缓存键,包含工具名、参数和代码哈希"""
|
|
||||||
|
|
||||||
cache_value: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
"""缓存的数据,JSON格式"""
|
|
||||||
|
|
||||||
expires_at: Mapped[float] = mapped_column(Float, nullable=False, index=True)
|
|
||||||
"""过期时间戳"""
|
|
||||||
|
|
||||||
tool_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
"""工具名称"""
|
|
||||||
|
|
||||||
created_at: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
|
|
||||||
"""创建时间戳"""
|
|
||||||
|
|
||||||
last_accessed: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
|
|
||||||
"""最后访问时间戳"""
|
|
||||||
|
|
||||||
access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
"""访问次数"""
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_cache_entries_key", "cache_key"),
|
|
||||||
Index("idx_cache_entries_expires_at", "expires_at"),
|
|
||||||
Index("idx_cache_entries_tool_name", "tool_name"),
|
|
||||||
Index("idx_cache_entries_created_at", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MonthlyPlan(Base):
|
|
||||||
"""月度计划模型"""
|
|
||||||
|
|
||||||
__tablename__ = "monthly_plans"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
plan_text: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
target_month: Mapped[str] = mapped_column(String(7), nullable=False, index=True)
|
|
||||||
status: Mapped[str] = mapped_column(get_string_field(20), nullable=False, default="active", index=True)
|
|
||||||
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, index=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
|
||||||
Index("idx_monthlyplan_last_used_date", "last_used_date"),
|
|
||||||
Index("idx_monthlyplan_usage_count", "usage_count"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_database_url():
|
|
||||||
"""获取数据库连接URL"""
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
config = global_config.database
|
|
||||||
|
|
||||||
if config.database_type == "mysql":
|
|
||||||
# 对用户名和密码进行URL编码,处理特殊字符
|
|
||||||
from urllib.parse import quote_plus
|
|
||||||
|
|
||||||
encoded_user = quote_plus(config.mysql_user)
|
|
||||||
encoded_password = quote_plus(config.mysql_password)
|
|
||||||
|
|
||||||
# 检查是否配置了Unix socket连接
|
|
||||||
if config.mysql_unix_socket:
|
|
||||||
# 使用Unix socket连接
|
|
||||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
|
||||||
return (
|
|
||||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
|
||||||
f"@/{config.mysql_database}"
|
|
||||||
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 使用标准TCP连接
|
|
||||||
return (
|
|
||||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
|
||||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
|
||||||
f"?charset={config.mysql_charset}"
|
|
||||||
)
|
|
||||||
else: # SQLite
|
|
||||||
# 如果是相对路径,则相对于项目根目录
|
|
||||||
if not os.path.isabs(config.sqlite_path):
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
|
||||||
else:
|
|
||||||
db_path = config.sqlite_path
|
|
||||||
|
|
||||||
# 确保数据库目录存在
|
|
||||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
|
||||||
|
|
||||||
return f"sqlite+aiosqlite:///{db_path}"
|
|
||||||
|
|
||||||
|
|
||||||
_initializing: bool = False # 防止递归初始化
|
|
||||||
|
|
||||||
async def initialize_database() -> tuple["AsyncEngine", async_sessionmaker[AsyncSession]]:
|
|
||||||
"""初始化异步数据库引擎和会话
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
tuple[AsyncEngine, async_sessionmaker[AsyncSession]]: 创建好的异步引擎与会话工厂。
|
|
||||||
|
|
||||||
说明:
|
|
||||||
显式的返回类型标注有助于 Pyright/Pylance 正确推断调用处的对象,
|
|
||||||
避免后续对返回值再次 `await` 时出现 *"tuple[...] 并非 awaitable"* 的误用。
|
|
||||||
"""
|
|
||||||
global _engine, _SessionLocal, _initializing
|
|
||||||
|
|
||||||
# 已经初始化直接返回
|
|
||||||
if _engine is not None and _SessionLocal is not None:
|
|
||||||
return _engine, _SessionLocal
|
|
||||||
|
|
||||||
# 正在初始化的并发调用等待主初始化完成,避免递归
|
|
||||||
if _initializing:
|
|
||||||
import asyncio
|
|
||||||
for _ in range(1000): # 最多等待约10秒
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
if _engine is not None and _SessionLocal is not None:
|
|
||||||
return _engine, _SessionLocal
|
|
||||||
raise RuntimeError("等待数据库初始化完成超时 (reentrancy guard)")
|
|
||||||
|
|
||||||
_initializing = True
|
|
||||||
try:
|
|
||||||
database_url = get_database_url()
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
config = global_config.database
|
|
||||||
|
|
||||||
# 配置引擎参数
|
|
||||||
engine_kwargs: dict[str, Any] = {
|
|
||||||
"echo": False, # 生产环境关闭SQL日志
|
|
||||||
"future": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.database_type == "mysql":
|
|
||||||
engine_kwargs.update(
|
|
||||||
{
|
|
||||||
"pool_size": config.connection_pool_size,
|
|
||||||
"max_overflow": config.connection_pool_size * 2,
|
|
||||||
"pool_timeout": config.connection_timeout,
|
|
||||||
"pool_recycle": 3600,
|
|
||||||
"pool_pre_ping": True,
|
|
||||||
"connect_args": {
|
|
||||||
"autocommit": config.mysql_autocommit,
|
|
||||||
"charset": config.mysql_charset,
|
|
||||||
"connect_timeout": config.connection_timeout,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
engine_kwargs.update(
|
|
||||||
{
|
|
||||||
"connect_args": {
|
|
||||||
"check_same_thread": False,
|
|
||||||
"timeout": 60,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_engine = create_async_engine(database_url, **engine_kwargs)
|
|
||||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
|
||||||
|
|
||||||
# 迁移
|
|
||||||
from src.common.database.db_migration import check_and_migrate_database
|
|
||||||
await check_and_migrate_database(existing_engine=_engine)
|
|
||||||
|
|
||||||
if config.database_type == "sqlite":
|
|
||||||
await enable_sqlite_wal_mode(_engine)
|
|
||||||
|
|
||||||
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
|
||||||
return _engine, _SessionLocal
|
|
||||||
finally:
|
|
||||||
_initializing = False
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession]:
|
|
||||||
"""
|
|
||||||
异步数据库会话上下文管理器。
|
|
||||||
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
|
||||||
|
|
||||||
现在使用透明的连接池管理器来复用现有连接,提高并发性能。
|
|
||||||
"""
|
|
||||||
SessionLocal = None
|
|
||||||
try:
|
|
||||||
_, SessionLocal = await initialize_database()
|
|
||||||
if not SessionLocal:
|
|
||||||
raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库初始化失败,无法创建会话: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 使用连接池管理器获取会话
|
|
||||||
pool_manager = get_connection_pool_manager()
|
|
||||||
|
|
||||||
async with pool_manager.get_session(SessionLocal) as session:
|
|
||||||
# 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接)
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
if global_config.database.database_type == "sqlite":
|
|
||||||
try:
|
|
||||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
|
||||||
await session.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")
|
|
||||||
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
async def get_engine():
|
|
||||||
"""获取异步数据库引擎"""
|
|
||||||
engine, _ = await initialize_database()
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
class PermissionNodes(Base):
|
|
||||||
"""权限节点模型"""
|
|
||||||
|
|
||||||
__tablename__ = "permission_nodes"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
node_name: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
plugin_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
default_granted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_permission_plugin", "plugin_name"),
|
|
||||||
Index("idx_permission_node", "node_name"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserPermissions(Base):
|
|
||||||
"""用户权限模型"""
|
|
||||||
|
|
||||||
__tablename__ = "user_permissions"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
|
||||||
permission_node: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
|
||||||
granted: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
|
||||||
granted_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
|
||||||
granted_by: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_user_platform_id", "platform", "user_id"),
|
|
||||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
|
||||||
Index("idx_permission_granted", "permission_node", "granted"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserRelationships(Base):
|
|
||||||
"""用户关系模型 - 存储用户与bot的关系数据"""
|
|
||||||
|
|
||||||
__tablename__ = "user_relationships"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
|
|
||||||
user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
|
|
||||||
user_aliases: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户别名,逗号分隔
|
|
||||||
relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
preference_keywords: Mapped[str | None] = mapped_column(Text, nullable=True) # 用户偏好关键词,逗号分隔
|
|
||||||
relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
|
||||||
last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_user_relationship_id", "user_id"),
|
|
||||||
Index("idx_relationship_score", "relationship_score"),
|
|
||||||
Index("idx_relationship_updated", "last_updated"),
|
|
||||||
)
|
|
||||||
@@ -1,872 +0,0 @@
|
|||||||
"""SQLAlchemy数据库模型定义
|
|
||||||
|
|
||||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
|
||||||
|
|
||||||
说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到
|
|
||||||
SQLAlchemy 2.0 推荐的带类型注解的声明式风格:
|
|
||||||
|
|
||||||
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
|
||||||
|
|
||||||
这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。
|
|
||||||
当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。
|
|
||||||
"""
|
|
||||||
|
|
||||||
import datetime
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
||||||
from sqlalchemy.ext.declarative import declarative_base
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column
|
|
||||||
|
|
||||||
from src.common.database.connection_pool_manager import get_connection_pool_manager
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("sqlalchemy_models")
|
|
||||||
|
|
||||||
# 创建基类
|
|
||||||
Base = declarative_base()
|
|
||||||
|
|
||||||
|
|
||||||
async def enable_sqlite_wal_mode(engine):
|
|
||||||
"""为 SQLite 启用 WAL 模式以提高并发性能"""
|
|
||||||
try:
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
# 启用 WAL 模式
|
|
||||||
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
|
||||||
# 设置适中的同步级别,平衡性能和安全性
|
|
||||||
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
||||||
# 启用外键约束
|
|
||||||
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
# 设置 busy_timeout,避免锁定错误
|
|
||||||
await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒
|
|
||||||
|
|
||||||
logger.info("[SQLite] WAL 模式已启用,并发性能已优化")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置")
|
|
||||||
|
|
||||||
|
|
||||||
async def maintain_sqlite_database():
|
|
||||||
"""定期维护 SQLite 数据库性能"""
|
|
||||||
try:
|
|
||||||
engine, SessionLocal = await initialize_database()
|
|
||||||
if not engine:
|
|
||||||
return
|
|
||||||
|
|
||||||
async with engine.begin() as conn:
|
|
||||||
# 检查并确保 WAL 模式仍然启用
|
|
||||||
result = await conn.execute(text("PRAGMA journal_mode"))
|
|
||||||
journal_mode = result.scalar()
|
|
||||||
|
|
||||||
if journal_mode != "wal":
|
|
||||||
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
|
||||||
logger.info("[SQLite] WAL 模式已重新启用")
|
|
||||||
|
|
||||||
# 优化数据库性能
|
|
||||||
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
|
||||||
await conn.execute(text("PRAGMA busy_timeout = 60000"))
|
|
||||||
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
|
|
||||||
# 定期清理(可选,根据需要启用)
|
|
||||||
# await conn.execute(text("PRAGMA optimize"))
|
|
||||||
|
|
||||||
logger.info("[SQLite] 数据库维护完成")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[SQLite] 数据库维护失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def get_sqlite_performance_config():
|
|
||||||
"""获取 SQLite 性能优化配置"""
|
|
||||||
return {
|
|
||||||
"journal_mode": "WAL", # 提高并发性能
|
|
||||||
"synchronous": "NORMAL", # 平衡性能和安全性
|
|
||||||
"busy_timeout": 60000, # 60秒超时
|
|
||||||
"foreign_keys": "ON", # 启用外键约束
|
|
||||||
"cache_size": -10000, # 10MB 缓存
|
|
||||||
"temp_store": "MEMORY", # 临时存储使用内存
|
|
||||||
"mmap_size": 268435456, # 256MB 内存映射
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# MySQL兼容的字段类型辅助函数
|
|
||||||
def get_string_field(max_length=255, **kwargs):
|
|
||||||
"""
|
|
||||||
根据数据库类型返回合适的字符串字段
|
|
||||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
|
||||||
"""
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
if global_config.database.database_type == "mysql":
|
|
||||||
return String(max_length, **kwargs)
|
|
||||||
else:
|
|
||||||
return Text(**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
class ChatStreams(Base):
|
|
||||||
"""聊天流模型"""
|
|
||||||
|
|
||||||
__tablename__ = "chat_streams"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
|
||||||
create_time = Column(Float, nullable=False)
|
|
||||||
group_platform = Column(Text, nullable=True)
|
|
||||||
group_id = Column(get_string_field(100), nullable=True, index=True)
|
|
||||||
group_name = Column(Text, nullable=True)
|
|
||||||
last_active_time = Column(Float, nullable=False)
|
|
||||||
platform = Column(Text, nullable=False)
|
|
||||||
user_platform = Column(Text, nullable=False)
|
|
||||||
user_id = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
user_nickname = Column(Text, nullable=False)
|
|
||||||
user_cardname = Column(Text, nullable=True)
|
|
||||||
energy_value = Column(Float, nullable=True, default=5.0)
|
|
||||||
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
|
||||||
focus_energy = Column(Float, nullable=True, default=0.5)
|
|
||||||
# 动态兴趣度系统字段
|
|
||||||
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
|
||||||
message_interest_total = Column(Float, nullable=True, default=0.0)
|
|
||||||
message_count = Column(Integer, nullable=True, default=0)
|
|
||||||
action_count = Column(Integer, nullable=True, default=0)
|
|
||||||
reply_count = Column(Integer, nullable=True, default=0)
|
|
||||||
last_interaction_time = Column(Float, nullable=True, default=None)
|
|
||||||
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
|
||||||
# 消息打断系统字段
|
|
||||||
interruption_count = Column(Integer, nullable=True, default=0)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
|
||||||
Index("idx_chatstreams_user_id", "user_id"),
|
|
||||||
Index("idx_chatstreams_group_id", "group_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class LLMUsage(Base):
|
|
||||||
"""LLM使用记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "llm_usage"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
model_assign_name = Column(get_string_field(100), index=True) # 添加索引
|
|
||||||
model_api_provider = Column(get_string_field(100), index=True) # 添加索引
|
|
||||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
|
||||||
request_type = Column(get_string_field(50), nullable=False, index=True)
|
|
||||||
endpoint = Column(Text, nullable=False)
|
|
||||||
prompt_tokens = Column(Integer, nullable=False)
|
|
||||||
completion_tokens = Column(Integer, nullable=False)
|
|
||||||
time_cost = Column(Float, nullable=True)
|
|
||||||
total_tokens = Column(Integer, nullable=False)
|
|
||||||
cost = Column(Float, nullable=False)
|
|
||||||
status = Column(Text, nullable=False)
|
|
||||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_llmusage_model_name", "model_name"),
|
|
||||||
Index("idx_llmusage_model_assign_name", "model_assign_name"),
|
|
||||||
Index("idx_llmusage_model_api_provider", "model_api_provider"),
|
|
||||||
Index("idx_llmusage_time_cost", "time_cost"),
|
|
||||||
Index("idx_llmusage_user_id", "user_id"),
|
|
||||||
Index("idx_llmusage_request_type", "request_type"),
|
|
||||||
Index("idx_llmusage_timestamp", "timestamp"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Emoji(Base):
|
|
||||||
"""表情包模型"""
|
|
||||||
|
|
||||||
__tablename__ = "emoji"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
|
||||||
format = Column(Text, nullable=False)
|
|
||||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description = Column(Text, nullable=False)
|
|
||||||
query_count = Column(Integer, nullable=False, default=0)
|
|
||||||
is_registered = Column(Boolean, nullable=False, default=False)
|
|
||||||
is_banned = Column(Boolean, nullable=False, default=False)
|
|
||||||
emotion = Column(Text, nullable=True)
|
|
||||||
record_time = Column(Float, nullable=False)
|
|
||||||
register_time = Column(Float, nullable=True)
|
|
||||||
usage_count = Column(Integer, nullable=False, default=0)
|
|
||||||
last_used_time = Column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_emoji_full_path", "full_path"),
|
|
||||||
Index("idx_emoji_hash", "emoji_hash"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Messages(Base):
|
|
||||||
"""消息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "messages"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
time = Column(Float, nullable=False)
|
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
reply_to = Column(Text, nullable=True)
|
|
||||||
interest_value = Column(Float, nullable=True)
|
|
||||||
key_words = Column(Text, nullable=True)
|
|
||||||
key_words_lite = Column(Text, nullable=True)
|
|
||||||
is_mentioned = Column(Boolean, nullable=True)
|
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
|
||||||
chat_info_stream_id = Column(Text, nullable=False)
|
|
||||||
chat_info_platform = Column(Text, nullable=False)
|
|
||||||
chat_info_user_platform = Column(Text, nullable=False)
|
|
||||||
chat_info_user_id = Column(Text, nullable=False)
|
|
||||||
chat_info_user_nickname = Column(Text, nullable=False)
|
|
||||||
chat_info_user_cardname = Column(Text, nullable=True)
|
|
||||||
chat_info_group_platform = Column(Text, nullable=True)
|
|
||||||
chat_info_group_id = Column(Text, nullable=True)
|
|
||||||
chat_info_group_name = Column(Text, nullable=True)
|
|
||||||
chat_info_create_time = Column(Float, nullable=False)
|
|
||||||
chat_info_last_active_time = Column(Float, nullable=False)
|
|
||||||
|
|
||||||
# 从顶层 user_info 扁平化而来的字段
|
|
||||||
user_platform = Column(Text, nullable=True)
|
|
||||||
user_id = Column(get_string_field(100), nullable=True, index=True)
|
|
||||||
user_nickname = Column(Text, nullable=True)
|
|
||||||
user_cardname = Column(Text, nullable=True)
|
|
||||||
|
|
||||||
processed_plain_text = Column(Text, nullable=True)
|
|
||||||
display_message = Column(Text, nullable=True)
|
|
||||||
memorized_times = Column(Integer, nullable=False, default=0)
|
|
||||||
priority_mode = Column(Text, nullable=True)
|
|
||||||
priority_info = Column(Text, nullable=True)
|
|
||||||
additional_config = Column(Text, nullable=True)
|
|
||||||
is_emoji = Column(Boolean, nullable=False, default=False)
|
|
||||||
is_picid = Column(Boolean, nullable=False, default=False)
|
|
||||||
is_command = Column(Boolean, nullable=False, default=False)
|
|
||||||
is_notify = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
# 兴趣度系统字段
|
|
||||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
|
||||||
should_reply = Column(Boolean, nullable=True, default=False)
|
|
||||||
should_act = Column(Boolean, nullable=True, default=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_messages_message_id", "message_id"),
|
|
||||||
Index("idx_messages_chat_id", "chat_id"),
|
|
||||||
Index("idx_messages_time", "time"),
|
|
||||||
Index("idx_messages_user_id", "user_id"),
|
|
||||||
Index("idx_messages_should_reply", "should_reply"),
|
|
||||||
Index("idx_messages_should_act", "should_act"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ActionRecords(Base):
|
|
||||||
"""动作记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "action_records"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
time = Column(Float, nullable=False)
|
|
||||||
action_name = Column(Text, nullable=False)
|
|
||||||
action_data = Column(Text, nullable=False)
|
|
||||||
action_done = Column(Boolean, nullable=False, default=False)
|
|
||||||
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
|
||||||
action_prompt_display = Column(Text, nullable=False)
|
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
chat_info_stream_id = Column(Text, nullable=False)
|
|
||||||
chat_info_platform = Column(Text, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_actionrecords_action_id", "action_id"),
|
|
||||||
Index("idx_actionrecords_chat_id", "chat_id"),
|
|
||||||
Index("idx_actionrecords_time", "time"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Images(Base):
|
|
||||||
"""图像信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "images"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
image_id = Column(Text, nullable=False, default="")
|
|
||||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description = Column(Text, nullable=True)
|
|
||||||
path = Column(get_string_field(500), nullable=False, unique=True)
|
|
||||||
count = Column(Integer, nullable=False, default=1)
|
|
||||||
timestamp = Column(Float, nullable=False)
|
|
||||||
type = Column(Text, nullable=False)
|
|
||||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_images_emoji_hash", "emoji_hash"),
|
|
||||||
Index("idx_images_path", "path"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ImageDescriptions(Base):
|
|
||||||
"""图像描述信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "image_descriptions"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
type = Column(Text, nullable=False)
|
|
||||||
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
description = Column(Text, nullable=False)
|
|
||||||
timestamp = Column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
|
||||||
|
|
||||||
|
|
||||||
class Videos(Base):
|
|
||||||
"""视频信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "videos"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
video_id = Column(Text, nullable=False, default="")
|
|
||||||
video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
|
|
||||||
description = Column(Text, nullable=True)
|
|
||||||
count = Column(Integer, nullable=False, default=1)
|
|
||||||
timestamp = Column(Float, nullable=False)
|
|
||||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
# 视频特有属性
|
|
||||||
duration = Column(Float, nullable=True) # 视频时长(秒)
|
|
||||||
frame_count = Column(Integer, nullable=True) # 总帧数
|
|
||||||
fps = Column(Float, nullable=True) # 帧率
|
|
||||||
resolution = Column(Text, nullable=True) # 分辨率
|
|
||||||
file_size = Column(Integer, nullable=True) # 文件大小(字节)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_videos_video_hash", "video_hash"),
|
|
||||||
Index("idx_videos_timestamp", "timestamp"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OnlineTime(Base):
|
|
||||||
"""在线时长记录模型"""
|
|
||||||
|
|
||||||
__tablename__ = "online_time"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
|
||||||
duration = Column(Integer, nullable=False)
|
|
||||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
|
||||||
|
|
||||||
|
|
||||||
class PersonInfo(Base):
|
|
||||||
"""人物信息模型"""
|
|
||||||
|
|
||||||
__tablename__ = "person_info"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
|
||||||
person_name = Column(Text, nullable=True)
|
|
||||||
name_reason = Column(Text, nullable=True)
|
|
||||||
platform = Column(Text, nullable=False)
|
|
||||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
|
||||||
nickname = Column(Text, nullable=True)
|
|
||||||
impression = Column(Text, nullable=True)
|
|
||||||
short_impression = Column(Text, nullable=True)
|
|
||||||
points = Column(Text, nullable=True)
|
|
||||||
forgotten_points = Column(Text, nullable=True)
|
|
||||||
info_list = Column(Text, nullable=True)
|
|
||||||
know_times = Column(Float, nullable=True)
|
|
||||||
know_since = Column(Float, nullable=True)
|
|
||||||
last_know = Column(Float, nullable=True)
|
|
||||||
attitude = Column(Integer, nullable=True, default=50)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_personinfo_person_id", "person_id"),
|
|
||||||
Index("idx_personinfo_user_id", "user_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BotPersonalityInterests(Base):
|
|
||||||
"""机器人人格兴趣标签模型"""
|
|
||||||
|
|
||||||
__tablename__ = "bot_personality_interests"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
personality_description = Column(Text, nullable=False)
|
|
||||||
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
|
||||||
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
|
||||||
version = Column(Integer, nullable=False, default=1)
|
|
||||||
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_botpersonality_personality_id", "personality_id"),
|
|
||||||
Index("idx_botpersonality_version", "version"),
|
|
||||||
Index("idx_botpersonality_last_updated", "last_updated"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Memory(Base):
|
|
||||||
"""记忆模型"""
|
|
||||||
|
|
||||||
__tablename__ = "memory"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
chat_id = Column(Text, nullable=True)
|
|
||||||
memory_text = Column(Text, nullable=True)
|
|
||||||
keywords = Column(Text, nullable=True)
|
|
||||||
create_time = Column(Float, nullable=True)
|
|
||||||
last_view_time = Column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class Expression(Base):
|
|
||||||
"""表达风格模型"""
|
|
||||||
|
|
||||||
__tablename__ = "expression"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
situation: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
style: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
count: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
|
||||||
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
|
||||||
type: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class ThinkingLog(Base):
|
|
||||||
"""思考日志模型"""
|
|
||||||
|
|
||||||
__tablename__ = "thinking_logs"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
|
||||||
trigger_text = Column(Text, nullable=True)
|
|
||||||
response_text = Column(Text, nullable=True)
|
|
||||||
trigger_info_json = Column(Text, nullable=True)
|
|
||||||
response_info_json = Column(Text, nullable=True)
|
|
||||||
timing_results_json = Column(Text, nullable=True)
|
|
||||||
chat_history_json = Column(Text, nullable=True)
|
|
||||||
chat_history_in_thinking_json = Column(Text, nullable=True)
|
|
||||||
chat_history_after_response_json = Column(Text, nullable=True)
|
|
||||||
heartflow_data_json = Column(Text, nullable=True)
|
|
||||||
reasoning_data_json = Column(Text, nullable=True)
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphNodes(Base):
|
|
||||||
"""记忆图节点模型"""
|
|
||||||
|
|
||||||
__tablename__ = "graph_nodes"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
|
||||||
memory_items = Column(Text, nullable=False)
|
|
||||||
hash = Column(Text, nullable=False)
|
|
||||||
weight = Column(Float, nullable=False, default=1.0)
|
|
||||||
created_time = Column(Float, nullable=False)
|
|
||||||
last_modified = Column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
|
||||||
|
|
||||||
|
|
||||||
class GraphEdges(Base):
|
|
||||||
"""记忆图边模型"""
|
|
||||||
|
|
||||||
__tablename__ = "graph_edges"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
source = Column(get_string_field(255), nullable=False, index=True)
|
|
||||||
target = Column(get_string_field(255), nullable=False, index=True)
|
|
||||||
strength = Column(Integer, nullable=False)
|
|
||||||
hash = Column(Text, nullable=False)
|
|
||||||
created_time = Column(Float, nullable=False)
|
|
||||||
last_modified = Column(Float, nullable=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_graphedges_source", "source"),
|
|
||||||
Index("idx_graphedges_target", "target"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Schedule(Base):
|
|
||||||
"""日程模型"""
|
|
||||||
|
|
||||||
__tablename__ = "schedule"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式
|
|
||||||
schedule_data = Column(Text, nullable=False) # JSON格式的日程数据
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (Index("idx_schedule_date", "date"),)
|
|
||||||
|
|
||||||
|
|
||||||
class MaiZoneScheduleStatus(Base):
|
|
||||||
"""麦麦空间日程处理状态模型"""
|
|
||||||
|
|
||||||
__tablename__ = "maizone_schedule_status"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
datetime_hour = Column(
|
|
||||||
get_string_field(13), nullable=False, unique=True, index=True
|
|
||||||
) # YYYY-MM-DD HH格式,精确到小时
|
|
||||||
activity = Column(Text, nullable=False) # 该小时的活动内容
|
|
||||||
is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理
|
|
||||||
processed_at = Column(DateTime, nullable=True) # 处理时间
|
|
||||||
story_content = Column(Text, nullable=True) # 生成的说说内容
|
|
||||||
send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
|
||||||
Index("idx_maizone_is_processed", "is_processed"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BanUser(Base):
|
|
||||||
"""被禁用用户模型
|
|
||||||
|
|
||||||
使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型,
|
|
||||||
避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "ban_users"
|
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
|
||||||
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
|
||||||
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_violation_num", "violation_num"),
|
|
||||||
Index("idx_banuser_user_id", "user_id"),
|
|
||||||
Index("idx_banuser_platform", "platform"),
|
|
||||||
Index("idx_banuser_platform_user_id", "platform", "user_id"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AntiInjectionStats(Base):
|
|
||||||
"""反注入系统统计模型"""
|
|
||||||
|
|
||||||
__tablename__ = "anti_injection_stats"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
total_messages = Column(Integer, nullable=False, default=0)
|
|
||||||
"""总处理消息数"""
|
|
||||||
|
|
||||||
detected_injections = Column(Integer, nullable=False, default=0)
|
|
||||||
"""检测到的注入攻击数"""
|
|
||||||
|
|
||||||
blocked_messages = Column(Integer, nullable=False, default=0)
|
|
||||||
"""被阻止的消息数"""
|
|
||||||
|
|
||||||
shielded_messages = Column(Integer, nullable=False, default=0)
|
|
||||||
"""被加盾的消息数"""
|
|
||||||
|
|
||||||
processing_time_total = Column(Float, nullable=False, default=0.0)
|
|
||||||
"""总处理时间"""
|
|
||||||
|
|
||||||
total_process_time = Column(Float, nullable=False, default=0.0)
|
|
||||||
"""累计总处理时间"""
|
|
||||||
|
|
||||||
last_process_time = Column(Float, nullable=False, default=0.0)
|
|
||||||
"""最近一次处理时间"""
|
|
||||||
|
|
||||||
error_count = Column(Integer, nullable=False, default=0)
|
|
||||||
"""错误计数"""
|
|
||||||
|
|
||||||
start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
"""统计开始时间"""
|
|
||||||
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
"""记录创建时间"""
|
|
||||||
|
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
"""记录更新时间"""
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_anti_injection_stats_created_at", "created_at"),
|
|
||||||
Index("idx_anti_injection_stats_updated_at", "updated_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CacheEntries(Base):
|
|
||||||
"""工具缓存条目模型"""
|
|
||||||
|
|
||||||
__tablename__ = "cache_entries"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
|
||||||
"""缓存键,包含工具名、参数和代码哈希"""
|
|
||||||
|
|
||||||
cache_value = Column(Text, nullable=False)
|
|
||||||
"""缓存的数据,JSON格式"""
|
|
||||||
|
|
||||||
expires_at = Column(Float, nullable=False, index=True)
|
|
||||||
"""过期时间戳"""
|
|
||||||
|
|
||||||
tool_name = Column(get_string_field(100), nullable=False, index=True)
|
|
||||||
"""工具名称"""
|
|
||||||
|
|
||||||
created_at = Column(Float, nullable=False, default=lambda: time.time())
|
|
||||||
"""创建时间戳"""
|
|
||||||
|
|
||||||
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
|
|
||||||
"""最后访问时间戳"""
|
|
||||||
|
|
||||||
access_count = Column(Integer, nullable=False, default=0)
|
|
||||||
"""访问次数"""
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_cache_entries_key", "cache_key"),
|
|
||||||
Index("idx_cache_entries_expires_at", "expires_at"),
|
|
||||||
Index("idx_cache_entries_tool_name", "tool_name"),
|
|
||||||
Index("idx_cache_entries_created_at", "created_at"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MonthlyPlan(Base):
|
|
||||||
"""月度计划模型"""
|
|
||||||
|
|
||||||
__tablename__ = "monthly_plans"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
plan_text = Column(Text, nullable=False)
|
|
||||||
target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM"
|
|
||||||
status = Column(
|
|
||||||
get_string_field(20), nullable=False, default="active", index=True
|
|
||||||
) # 'active', 'completed', 'archived'
|
|
||||||
usage_count = Column(Integer, nullable=False, default=0)
|
|
||||||
last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
|
||||||
is_deleted = Column(Boolean, nullable=False, default=False)
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
|
||||||
Index("idx_monthlyplan_last_used_date", "last_used_date"),
|
|
||||||
Index("idx_monthlyplan_usage_count", "usage_count"),
|
|
||||||
# 保留旧索引以兼容
|
|
||||||
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# 数据库引擎和会话管理
|
|
||||||
_engine = None
|
|
||||||
_SessionLocal = None
|
|
||||||
|
|
||||||
|
|
||||||
def get_database_url():
|
|
||||||
"""获取数据库连接URL"""
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
config = global_config.database
|
|
||||||
|
|
||||||
if config.database_type == "mysql":
|
|
||||||
# 对用户名和密码进行URL编码,处理特殊字符
|
|
||||||
from urllib.parse import quote_plus
|
|
||||||
|
|
||||||
encoded_user = quote_plus(config.mysql_user)
|
|
||||||
encoded_password = quote_plus(config.mysql_password)
|
|
||||||
|
|
||||||
# 检查是否配置了Unix socket连接
|
|
||||||
if config.mysql_unix_socket:
|
|
||||||
# 使用Unix socket连接
|
|
||||||
encoded_socket = quote_plus(config.mysql_unix_socket)
|
|
||||||
return (
|
|
||||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
|
||||||
f"@/{config.mysql_database}"
|
|
||||||
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 使用标准TCP连接
|
|
||||||
return (
|
|
||||||
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
|
||||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
|
||||||
f"?charset={config.mysql_charset}"
|
|
||||||
)
|
|
||||||
else: # SQLite
|
|
||||||
# 如果是相对路径,则相对于项目根目录
|
|
||||||
if not os.path.isabs(config.sqlite_path):
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
|
||||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
|
||||||
else:
|
|
||||||
db_path = config.sqlite_path
|
|
||||||
|
|
||||||
# 确保数据库目录存在
|
|
||||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
|
||||||
|
|
||||||
return f"sqlite+aiosqlite:///{db_path}"
|
|
||||||
|
|
||||||
|
|
||||||
async def initialize_database():
|
|
||||||
"""初始化异步数据库引擎和会话"""
|
|
||||||
global _engine, _SessionLocal
|
|
||||||
|
|
||||||
if _engine is not None:
|
|
||||||
return _engine, _SessionLocal
|
|
||||||
|
|
||||||
database_url = get_database_url()
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
config = global_config.database
|
|
||||||
|
|
||||||
# 配置引擎参数
|
|
||||||
engine_kwargs: dict[str, Any] = {
|
|
||||||
"echo": False, # 生产环境关闭SQL日志
|
|
||||||
"future": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.database_type == "mysql":
|
|
||||||
# MySQL连接池配置 - 异步引擎使用默认连接池
|
|
||||||
engine_kwargs.update(
|
|
||||||
{
|
|
||||||
"pool_size": config.connection_pool_size,
|
|
||||||
"max_overflow": config.connection_pool_size * 2,
|
|
||||||
"pool_timeout": config.connection_timeout,
|
|
||||||
"pool_recycle": 3600, # 1小时回收连接
|
|
||||||
"pool_pre_ping": True, # 连接前ping检查
|
|
||||||
"connect_args": {
|
|
||||||
"autocommit": config.mysql_autocommit,
|
|
||||||
"charset": config.mysql_charset,
|
|
||||||
"connect_timeout": config.connection_timeout,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# SQLite配置 - aiosqlite不支持连接池参数
|
|
||||||
engine_kwargs.update(
|
|
||||||
{
|
|
||||||
"connect_args": {
|
|
||||||
"check_same_thread": False,
|
|
||||||
"timeout": 60, # 增加超时时间
|
|
||||||
},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_engine = create_async_engine(database_url, **engine_kwargs)
|
|
||||||
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
|
||||||
|
|
||||||
# 调用新的迁移函数,它会处理表的创建和列的添加
|
|
||||||
from src.common.database.db_migration import check_and_migrate_database
|
|
||||||
|
|
||||||
await check_and_migrate_database()
|
|
||||||
|
|
||||||
# 如果是 SQLite,启用 WAL 模式以提高并发性能
|
|
||||||
if config.database_type == "sqlite":
|
|
||||||
await enable_sqlite_wal_mode(_engine)
|
|
||||||
|
|
||||||
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
|
||||||
return _engine, _SessionLocal
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def get_db_session() -> AsyncGenerator[AsyncSession]:
|
|
||||||
"""
|
|
||||||
异步数据库会话上下文管理器。
|
|
||||||
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
|
||||||
|
|
||||||
现在使用透明的连接池管理器来复用现有连接,提高并发性能。
|
|
||||||
"""
|
|
||||||
SessionLocal = None
|
|
||||||
try:
|
|
||||||
_, SessionLocal = await initialize_database()
|
|
||||||
if not SessionLocal:
|
|
||||||
raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"数据库初始化失败,无法创建会话: {e}")
|
|
||||||
raise
|
|
||||||
|
|
||||||
# 使用连接池管理器获取会话
|
|
||||||
pool_manager = get_connection_pool_manager()
|
|
||||||
|
|
||||||
async with pool_manager.get_session(SessionLocal) as session:
|
|
||||||
# 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接)
|
|
||||||
from src.config.config import global_config
|
|
||||||
|
|
||||||
if global_config.database.database_type == "sqlite":
|
|
||||||
try:
|
|
||||||
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
|
||||||
await session.execute(text("PRAGMA foreign_keys = ON"))
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")
|
|
||||||
|
|
||||||
yield session
|
|
||||||
|
|
||||||
|
|
||||||
async def get_engine():
|
|
||||||
"""获取异步数据库引擎"""
|
|
||||||
engine, _ = await initialize_database()
|
|
||||||
return engine
|
|
||||||
|
|
||||||
|
|
||||||
class PermissionNodes(Base):
|
|
||||||
"""权限节点模型"""
|
|
||||||
|
|
||||||
__tablename__ = "permission_nodes"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
|
|
||||||
description = Column(Text, nullable=False) # 权限描述
|
|
||||||
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
|
|
||||||
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_permission_plugin", "plugin_name"),
|
|
||||||
Index("idx_permission_node", "node_name"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserPermissions(Base):
|
|
||||||
"""用户权限模型"""
|
|
||||||
|
|
||||||
__tablename__ = "user_permissions"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
|
|
||||||
user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID
|
|
||||||
permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称
|
|
||||||
granted = Column(Boolean, default=True, nullable=False) # 是否授权
|
|
||||||
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
|
|
||||||
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_user_platform_id", "platform", "user_id"),
|
|
||||||
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
|
||||||
Index("idx_permission_granted", "permission_node", "granted"),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class UserRelationships(Base):
|
|
||||||
"""用户关系模型 - 存储用户与bot的关系数据"""
|
|
||||||
|
|
||||||
__tablename__ = "user_relationships"
|
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
|
||||||
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
|
||||||
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
|
||||||
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
|
||||||
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
|
||||||
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
|
||||||
|
|
||||||
__table_args__ = (
|
|
||||||
Index("idx_user_relationship_id", "user_id"),
|
|
||||||
Index("idx_relationship_score", "relationship_score"),
|
|
||||||
Index("idx_relationship_updated", "last_updated"),
|
|
||||||
)
|
|
||||||
Reference in New Issue
Block a user