初始化
This commit is contained in:
@@ -1,14 +1,103 @@
|
||||
import os
|
||||
from pymongo import MongoClient
|
||||
from peewee import SqliteDatabase
|
||||
from pymongo.database import Database
|
||||
from rich.traceback import install
|
||||
from src.common.logger import get_logger
|
||||
|
||||
# SQLAlchemy相关导入
|
||||
from src.common.database.sqlalchemy_init import initialize_database_compat
|
||||
from src.common.database.sqlalchemy_models import get_engine, get_session
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
_client = None
|
||||
_db = None
|
||||
_sql_engine = None
|
||||
|
||||
logger = get_logger("database")
|
||||
|
||||
# 兼容性:为了不破坏现有代码,保留db变量但指向SQLAlchemy
|
||||
class DatabaseProxy:
|
||||
"""数据库代理类,提供Peewee到SQLAlchemy的兼容性接口"""
|
||||
|
||||
def __init__(self):
|
||||
self._engine = None
|
||||
self._session = None
|
||||
|
||||
def initialize(self, *args, **kwargs):
|
||||
"""初始化数据库连接"""
|
||||
return initialize_database_compat()
|
||||
|
||||
def connect(self, reuse_if_open=True):
|
||||
"""连接数据库(兼容性方法)"""
|
||||
try:
|
||||
self._engine = get_engine()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接失败: {e}")
|
||||
return False
|
||||
|
||||
def is_closed(self):
|
||||
"""检查数据库是否关闭(兼容性方法)"""
|
||||
return self._engine is None
|
||||
|
||||
def create_tables(self, models, safe=True):
|
||||
"""创建表(兼容性方法)"""
|
||||
try:
|
||||
from src.common.database.sqlalchemy_models import Base
|
||||
engine = get_engine()
|
||||
Base.metadata.create_all(bind=engine)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"创建表失败: {e}")
|
||||
return False
|
||||
|
||||
def table_exists(self, model):
|
||||
"""检查表是否存在(兼容性方法)"""
|
||||
try:
|
||||
from sqlalchemy import inspect
|
||||
engine = get_engine()
|
||||
inspector = inspect(engine)
|
||||
table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower())
|
||||
return table_name in inspector.get_table_names()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def execute_sql(self, sql):
|
||||
"""执行SQL(兼容性方法)"""
|
||||
try:
|
||||
from sqlalchemy import text
|
||||
session = get_session()
|
||||
result = session.execute(text(sql))
|
||||
session.close()
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"执行SQL失败: {e}")
|
||||
raise
|
||||
|
||||
def atomic(self):
|
||||
"""事务上下文管理器(兼容性方法)"""
|
||||
return SQLAlchemyTransaction()
|
||||
|
||||
class SQLAlchemyTransaction:
|
||||
"""SQLAlchemy事务上下文管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self.session = None
|
||||
|
||||
def __enter__(self):
|
||||
self.session = get_session()
|
||||
return self.session
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
if exc_type is None:
|
||||
self.session.commit()
|
||||
else:
|
||||
self.session.rollback()
|
||||
self.session.close()
|
||||
|
||||
# 创建全局数据库代理实例
|
||||
db = DatabaseProxy()
|
||||
|
||||
def __create_database_instance():
|
||||
uri = os.getenv("MONGODB_URI")
|
||||
@@ -39,7 +128,7 @@ def __create_database_instance():
|
||||
|
||||
|
||||
def get_db():
|
||||
"""获取数据库连接实例,延迟初始化。"""
|
||||
"""获取MongoDB连接实例,延迟初始化。"""
|
||||
global _client, _db
|
||||
if _client is None:
|
||||
_client = __create_database_instance()
|
||||
@@ -47,6 +136,47 @@ def get_db():
|
||||
return _db
|
||||
|
||||
|
||||
def initialize_sql_database(database_config):
|
||||
"""
|
||||
根据配置初始化SQL数据库连接(SQLAlchemy版本)
|
||||
|
||||
Args:
|
||||
database_config: DatabaseConfig对象
|
||||
"""
|
||||
global _sql_engine
|
||||
|
||||
try:
|
||||
logger.info("使用SQLAlchemy初始化SQL数据库...")
|
||||
|
||||
# 记录数据库配置信息
|
||||
if database_config.database_type == "mysql":
|
||||
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
|
||||
logger.info("MySQL数据库连接配置:")
|
||||
logger.info(f" 连接信息: {connection_info}")
|
||||
logger.info(f" 字符集: {database_config.mysql_charset}")
|
||||
else:
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
if not os.path.isabs(database_config.sqlite_path):
|
||||
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
|
||||
else:
|
||||
db_path = database_config.sqlite_path
|
||||
logger.info("SQLite数据库连接配置:")
|
||||
logger.info(f" 数据库文件: {db_path}")
|
||||
|
||||
# 使用SQLAlchemy初始化
|
||||
success = initialize_database_compat()
|
||||
if success:
|
||||
_sql_engine = get_engine()
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
else:
|
||||
logger.error("SQLAlchemy数据库初始化失败")
|
||||
|
||||
return _sql_engine
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"初始化SQL数据库失败: {e}")
|
||||
return None
|
||||
|
||||
class DBWrapper:
|
||||
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
|
||||
|
||||
@@ -57,26 +187,6 @@ class DBWrapper:
|
||||
return get_db()[key] # type: ignore
|
||||
|
||||
|
||||
# 全局数据库访问点
|
||||
# 全局MongoDB数据库访问点
|
||||
memory_db: Database = DBWrapper() # type: ignore
|
||||
|
||||
# 定义数据库文件路径
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
_DB_DIR = os.path.join(ROOT_PATH, "data")
|
||||
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(_DB_DIR, exist_ok=True)
|
||||
|
||||
# 全局 Peewee SQLite 数据库访问点
|
||||
db = SqliteDatabase(
|
||||
_DB_FILE,
|
||||
pragmas={
|
||||
"journal_mode": "wal", # WAL模式提高并发性能
|
||||
"cache_size": -64 * 1000, # 64MB缓存
|
||||
"foreign_keys": 1,
|
||||
"ignore_check_constraints": 0,
|
||||
"synchronous": 0, # 异步写入提高性能
|
||||
"busy_timeout": 1000, # 1秒超时而不是3秒
|
||||
},
|
||||
)
|
||||
|
||||
420
src/common/database/sqlalchemy_database_api.py
Normal file
420
src/common/database/sqlalchemy_database_api.py
Normal file
@@ -0,0 +1,420 @@
|
||||
"""SQLAlchemy数据库API模块
|
||||
|
||||
提供基于SQLAlchemy的数据库操作,替换Peewee以解决MySQL连接问题
|
||||
支持自动重连、连接池管理和更好的错误处理
|
||||
"""
|
||||
|
||||
import traceback
|
||||
import time
|
||||
from typing import Dict, List, Any, Union, Type, Optional
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
|
||||
from sqlalchemy import desc, asc, func, and_, or_
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
|
||||
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
|
||||
Expression, ThinkingLog, GraphNodes, GraphEdges,get_session
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_database_api")
|
||||
|
||||
# 模型映射表,用于通过名称获取模型类
|
||||
MODEL_MAPPING = {
|
||||
'Messages': Messages,
|
||||
'ActionRecords': ActionRecords,
|
||||
'PersonInfo': PersonInfo,
|
||||
'ChatStreams': ChatStreams,
|
||||
'LLMUsage': LLMUsage,
|
||||
'Emoji': Emoji,
|
||||
'Images': Images,
|
||||
'ImageDescriptions': ImageDescriptions,
|
||||
'OnlineTime': OnlineTime,
|
||||
'Memory': Memory,
|
||||
'Expression': Expression,
|
||||
'ThinkingLog': ThinkingLog,
|
||||
'GraphNodes': GraphNodes,
|
||||
'GraphEdges': GraphEdges,
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session():
|
||||
"""数据库会话上下文管理器,自动处理事务和连接错误"""
|
||||
session = None
|
||||
max_retries = 3
|
||||
retry_delay = 1.0
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
session = get_session()
|
||||
yield session
|
||||
session.commit()
|
||||
break
|
||||
except (DisconnectionError, OperationalError) as e:
|
||||
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
|
||||
if session:
|
||||
session.rollback()
|
||||
session.close()
|
||||
if attempt < max_retries - 1:
|
||||
time.sleep(retry_delay * (attempt + 1))
|
||||
else:
|
||||
raise
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
|
||||
"""构建查询过滤条件"""
|
||||
conditions = []
|
||||
|
||||
for field_name, value in filters.items():
|
||||
if not hasattr(model_class, field_name):
|
||||
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
|
||||
continue
|
||||
|
||||
field = getattr(model_class, field_name)
|
||||
|
||||
if isinstance(value, dict):
|
||||
# 处理 MongoDB 风格的操作符
|
||||
for op, op_value in value.items():
|
||||
if op == "$gt":
|
||||
conditions.append(field > op_value)
|
||||
elif op == "$lt":
|
||||
conditions.append(field < op_value)
|
||||
elif op == "$gte":
|
||||
conditions.append(field >= op_value)
|
||||
elif op == "$lte":
|
||||
conditions.append(field <= op_value)
|
||||
elif op == "$ne":
|
||||
conditions.append(field != op_value)
|
||||
elif op == "$in":
|
||||
conditions.append(field.in_(op_value))
|
||||
elif op == "$nin":
|
||||
conditions.append(~field.in_(op_value))
|
||||
else:
|
||||
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
|
||||
else:
|
||||
# 直接相等比较
|
||||
conditions.append(field == value)
|
||||
|
||||
return conditions
|
||||
|
||||
|
||||
async def db_query(
|
||||
model_class: Type[Base],
|
||||
data: Optional[Dict[str, Any]] = None,
|
||||
query_type: Optional[str] = "get",
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[List[str]] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""执行数据库查询操作
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 用于创建或更新的数据字典
|
||||
query_type: 查询类型 ("get", "create", "update", "delete", "count")
|
||||
filters: 过滤条件字典
|
||||
limit: 限制结果数量
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
根据查询类型返回相应结果
|
||||
"""
|
||||
try:
|
||||
if query_type not in ["get", "create", "update", "delete", "count"]:
|
||||
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
|
||||
|
||||
with get_db_session() as session:
|
||||
if query_type == "get":
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 应用排序
|
||||
if order_by:
|
||||
for field_name in order_by:
|
||||
if field_name.startswith("-"):
|
||||
field_name = field_name[1:]
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(desc(getattr(model_class, field_name)))
|
||||
else:
|
||||
if hasattr(model_class, field_name):
|
||||
query = query.order_by(asc(getattr(model_class, field_name)))
|
||||
|
||||
# 应用限制
|
||||
if limit and limit > 0:
|
||||
query = query.limit(limit)
|
||||
|
||||
# 执行查询
|
||||
results = query.all()
|
||||
|
||||
# 转换为字典格式
|
||||
result_dicts = []
|
||||
for result in results:
|
||||
result_dict = {}
|
||||
for column in result.__table__.columns:
|
||||
result_dict[column.name] = getattr(result, column.name)
|
||||
result_dicts.append(result_dict)
|
||||
|
||||
if single_result:
|
||||
return result_dicts[0] if result_dicts else None
|
||||
return result_dicts
|
||||
|
||||
elif query_type == "create":
|
||||
if not data:
|
||||
raise ValueError("创建记录需要提供data参数")
|
||||
|
||||
# 创建新记录
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush() # 获取自动生成的ID
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
elif query_type == "update":
|
||||
if not data:
|
||||
raise ValueError("更新记录需要提供data参数")
|
||||
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 执行更新
|
||||
affected_rows = query.update(data)
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "delete":
|
||||
query = session.query(model_class)
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
query = query.filter(and_(*conditions))
|
||||
|
||||
# 执行删除
|
||||
affected_rows = query.delete()
|
||||
return affected_rows
|
||||
|
||||
elif query_type == "count":
|
||||
query = session.query(func.count(model_class.id))
|
||||
|
||||
# 应用过滤条件
|
||||
if filters:
|
||||
base_query = session.query(model_class)
|
||||
conditions = build_filters(session, model_class, filters)
|
||||
if conditions:
|
||||
base_query = base_query.filter(and_(*conditions))
|
||||
query = session.query(func.count()).select_from(base_query.subquery())
|
||||
|
||||
return query.scalar()
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
# 根据查询类型返回合适的默认值
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
elif query_type in ["create", "update", "delete", "count"]:
|
||||
return None
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
|
||||
if query_type == "get":
|
||||
return None if single_result else []
|
||||
return None
|
||||
|
||||
|
||||
async def db_save(
|
||||
model_class: Type[Base],
|
||||
data: Dict[str, Any],
|
||||
key_field: Optional[str] = None,
|
||||
key_value: Optional[Any] = None
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""保存数据到数据库(创建或更新)
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
data: 要保存的数据字典
|
||||
key_field: 用于查找现有记录的字段名
|
||||
key_value: 用于查找现有记录的字段值
|
||||
|
||||
Returns:
|
||||
保存后的记录数据或None
|
||||
"""
|
||||
try:
|
||||
with get_db_session() as session:
|
||||
# 如果提供了key_field和key_value,尝试更新现有记录
|
||||
if key_field and key_value is not None:
|
||||
if hasattr(model_class, key_field):
|
||||
existing_record = session.query(model_class).filter(
|
||||
getattr(model_class, key_field) == key_value
|
||||
).first()
|
||||
|
||||
if existing_record:
|
||||
# 更新现有记录
|
||||
for field, value in data.items():
|
||||
if hasattr(existing_record, field):
|
||||
setattr(existing_record, field, value)
|
||||
|
||||
session.flush()
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in existing_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(existing_record, column.name)
|
||||
return result_dict
|
||||
|
||||
# 创建新记录
|
||||
new_record = model_class(**data)
|
||||
session.add(new_record)
|
||||
session.flush()
|
||||
|
||||
# 转换为字典格式返回
|
||||
result_dict = {}
|
||||
for column in new_record.__table__.columns:
|
||||
result_dict[column.name] = getattr(new_record, column.name)
|
||||
return result_dict
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
async def db_get(
|
||||
model_class: Type[Base],
|
||||
filters: Optional[Dict[str, Any]] = None,
|
||||
limit: Optional[int] = None,
|
||||
order_by: Optional[str] = None,
|
||||
single_result: Optional[bool] = False,
|
||||
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
|
||||
"""从数据库获取记录
|
||||
|
||||
Args:
|
||||
model_class: SQLAlchemy模型类
|
||||
filters: 过滤条件
|
||||
limit: 结果数量限制
|
||||
order_by: 排序字段,前缀'-'表示降序
|
||||
single_result: 是否只返回单个结果
|
||||
|
||||
Returns:
|
||||
记录数据或None
|
||||
"""
|
||||
order_by_list = [order_by] if order_by else None
|
||||
return await db_query(
|
||||
model_class=model_class,
|
||||
query_type="get",
|
||||
filters=filters,
|
||||
limit=limit,
|
||||
order_by=order_by_list,
|
||||
single_result=single_result
|
||||
)
|
||||
|
||||
|
||||
async def store_action_info(
|
||||
chat_stream=None,
|
||||
action_build_into_prompt: bool = False,
|
||||
action_prompt_display: str = "",
|
||||
action_done: bool = True,
|
||||
thinking_id: str = "",
|
||||
action_data: Optional[dict] = None,
|
||||
action_name: str = "",
|
||||
) -> Optional[Dict[str, Any]]:
|
||||
"""存储动作信息到数据库
|
||||
|
||||
Args:
|
||||
chat_stream: 聊天流对象
|
||||
action_build_into_prompt: 是否将此动作构建到提示中
|
||||
action_prompt_display: 动作的提示显示文本
|
||||
action_done: 动作是否完成
|
||||
thinking_id: 关联的思考ID
|
||||
action_data: 动作数据字典
|
||||
action_name: 动作名称
|
||||
|
||||
Returns:
|
||||
保存的记录数据或None
|
||||
"""
|
||||
try:
|
||||
import json
|
||||
|
||||
# 构建动作记录数据
|
||||
record_data = {
|
||||
"action_id": thinking_id or str(int(time.time() * 1000000)),
|
||||
"time": time.time(),
|
||||
"action_name": action_name,
|
||||
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
|
||||
"action_done": action_done,
|
||||
"action_build_into_prompt": action_build_into_prompt,
|
||||
"action_prompt_display": action_prompt_display,
|
||||
}
|
||||
|
||||
# 从chat_stream获取聊天信息
|
||||
if chat_stream:
|
||||
record_data.update({
|
||||
"chat_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
|
||||
"chat_info_platform": getattr(chat_stream, "platform", ""),
|
||||
})
|
||||
else:
|
||||
record_data.update({
|
||||
"chat_id": "",
|
||||
"chat_info_stream_id": "",
|
||||
"chat_info_platform": "",
|
||||
})
|
||||
|
||||
# 保存记录
|
||||
saved_record = await db_save(
|
||||
ActionRecords,
|
||||
data=record_data,
|
||||
key_field="action_id",
|
||||
key_value=record_data["action_id"]
|
||||
)
|
||||
|
||||
if saved_record:
|
||||
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
|
||||
else:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
|
||||
|
||||
return saved_record
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
|
||||
traceback.print_exc()
|
||||
return None
|
||||
|
||||
|
||||
# 兼容性函数,方便从Peewee迁移
|
||||
def get_model_class(model_name: str) -> Optional[Type[Base]]:
|
||||
"""根据模型名称获取模型类"""
|
||||
return MODEL_MAPPING.get(model_name)
|
||||
158
src/common/database/sqlalchemy_init.py
Normal file
158
src/common/database/sqlalchemy_init.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""SQLAlchemy数据库初始化模块
|
||||
|
||||
替换Peewee的数据库初始化逻辑
|
||||
提供统一的数据库初始化接口
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from src.common.logger import get_logger
|
||||
from src.common.database.sqlalchemy_models import (
|
||||
Base, get_engine, get_session, initialize_database
|
||||
)
|
||||
|
||||
logger = get_logger("sqlalchemy_init")
|
||||
|
||||
|
||||
def initialize_sqlalchemy_database() -> bool:
|
||||
"""
|
||||
初始化SQLAlchemy数据库
|
||||
创建所有表结构
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始初始化SQLAlchemy数据库...")
|
||||
|
||||
# 初始化数据库引擎和会话
|
||||
engine, session_local = initialize_database()
|
||||
|
||||
if engine is None:
|
||||
logger.error("数据库引擎初始化失败")
|
||||
return False
|
||||
|
||||
logger.info("SQLAlchemy数据库初始化成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库初始化过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def create_all_tables() -> bool:
|
||||
"""
|
||||
创建所有数据库表
|
||||
|
||||
Returns:
|
||||
bool: 创建是否成功
|
||||
"""
|
||||
try:
|
||||
logger.info("开始创建数据库表...")
|
||||
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
logger.error("无法获取数据库引擎")
|
||||
return False
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
logger.info("数据库表创建成功")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"创建数据库表失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"创建数据库表过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def check_database_connection() -> bool:
|
||||
"""
|
||||
检查数据库连接是否正常
|
||||
|
||||
Returns:
|
||||
bool: 连接是否正常
|
||||
"""
|
||||
try:
|
||||
session = get_session()
|
||||
if session is None:
|
||||
logger.error("无法获取数据库会话")
|
||||
return False
|
||||
|
||||
# 检查会话是否可用(如果能获取到会话说明连接正常)
|
||||
if session is None:
|
||||
logger.error("数据库会话无效")
|
||||
return False
|
||||
|
||||
session.close()
|
||||
|
||||
logger.info("数据库连接检查通过")
|
||||
return True
|
||||
|
||||
except SQLAlchemyError as e:
|
||||
logger.error(f"数据库连接检查失败: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def get_database_info() -> Optional[dict]:
|
||||
"""
|
||||
获取数据库信息
|
||||
|
||||
Returns:
|
||||
dict: 数据库信息字典,包含引擎信息等
|
||||
"""
|
||||
try:
|
||||
engine = get_engine()
|
||||
if engine is None:
|
||||
return None
|
||||
|
||||
info = {
|
||||
'engine_name': engine.name,
|
||||
'driver': engine.driver,
|
||||
'url': str(engine.url).replace(engine.url.password or '', '***'), # 隐藏密码
|
||||
'pool_size': getattr(engine.pool, 'size', None),
|
||||
'max_overflow': getattr(engine.pool, 'max_overflow', None),
|
||||
}
|
||||
|
||||
return info
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取数据库信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
_database_initialized = False
|
||||
|
||||
def initialize_database_compat() -> bool:
|
||||
"""
|
||||
兼容性数据库初始化函数
|
||||
用于替换原有的Peewee初始化代码
|
||||
|
||||
Returns:
|
||||
bool: 初始化是否成功
|
||||
"""
|
||||
global _database_initialized
|
||||
|
||||
if _database_initialized:
|
||||
return True
|
||||
|
||||
success = initialize_sqlalchemy_database()
|
||||
if success:
|
||||
success = create_all_tables()
|
||||
|
||||
if success:
|
||||
success = check_database_connection()
|
||||
|
||||
if success:
|
||||
_database_initialized = True
|
||||
|
||||
return success
|
||||
555
src/common/database/sqlalchemy_models.py
Normal file
555
src/common/database/sqlalchemy_models.py
Normal file
@@ -0,0 +1,555 @@
|
||||
"""SQLAlchemy数据库模型定义
|
||||
|
||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||
"""
|
||||
|
||||
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import QueuePool
|
||||
import os
|
||||
import datetime
|
||||
from src.config.config import global_config
|
||||
from src.common.logger import get_logger
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
|
||||
logger = get_logger("sqlalchemy_models")
|
||||
|
||||
# 创建基类
|
||||
Base = declarative_base()
|
||||
|
||||
# MySQL兼容的字段类型辅助函数
|
||||
def get_string_field(max_length=255, **kwargs):
|
||||
"""
|
||||
根据数据库类型返回合适的字符串字段
|
||||
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
||||
"""
|
||||
if global_config.database.database_type == "mysql":
|
||||
return String(max_length, **kwargs)
|
||||
else:
|
||||
return Text(**kwargs)
|
||||
|
||||
class SessionProxy:
|
||||
"""线程安全的Session代理类,自动管理session生命周期"""
|
||||
|
||||
def __init__(self):
|
||||
self._local = threading.local()
|
||||
|
||||
def _get_current_session(self):
|
||||
"""获取当前线程的session,如果没有则创建新的"""
|
||||
if not hasattr(self._local, 'session') or self._local.session is None:
|
||||
_, SessionLocal = initialize_database()
|
||||
self._local.session = SessionLocal()
|
||||
return self._local.session
|
||||
|
||||
def _close_current_session(self):
|
||||
"""关闭当前线程的session"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
self._local.session.close()
|
||||
except:
|
||||
pass
|
||||
finally:
|
||||
self._local.session = None
|
||||
|
||||
def __getattr__(self, name):
|
||||
"""代理所有session方法"""
|
||||
session = self._get_current_session()
|
||||
attr = getattr(session, name)
|
||||
|
||||
# 如果是方法,需要特殊处理一些关键方法
|
||||
if callable(attr):
|
||||
if name in ['commit', 'rollback']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
result = attr(*args, **kwargs)
|
||||
if name == 'commit':
|
||||
# commit后不要清除session,只是刷新状态
|
||||
pass # 保持session活跃
|
||||
return result
|
||||
except Exception as e:
|
||||
try:
|
||||
if session and hasattr(session, 'rollback'):
|
||||
session.rollback()
|
||||
except:
|
||||
pass
|
||||
# 发生错误时重新创建session
|
||||
self._close_current_session()
|
||||
raise
|
||||
return wrapper
|
||||
elif name == 'close':
|
||||
def wrapper(*args, **kwargs):
|
||||
result = attr(*args, **kwargs)
|
||||
self._close_current_session()
|
||||
return result
|
||||
return wrapper
|
||||
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
|
||||
def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return attr(*args, **kwargs)
|
||||
except Exception as e:
|
||||
# 如果是连接相关错误,重新创建session再试一次
|
||||
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
|
||||
logger.warning(f"Session问题,重新创建session: {e}")
|
||||
self._close_current_session()
|
||||
new_session = self._get_current_session()
|
||||
new_attr = getattr(new_session, name)
|
||||
return new_attr(*args, **kwargs)
|
||||
raise
|
||||
return wrapper
|
||||
|
||||
return attr
|
||||
|
||||
def new_session(self):
|
||||
"""强制创建新的session(关闭当前的,创建新的)"""
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
def ensure_fresh_session(self):
|
||||
"""确保使用新鲜的session(如果当前session有问题则重新创建)"""
|
||||
if hasattr(self._local, 'session') and self._local.session is not None:
|
||||
try:
|
||||
# 测试session是否还可用
|
||||
self._local.session.execute("SELECT 1")
|
||||
except Exception:
|
||||
# session有问题,重新创建
|
||||
self._close_current_session()
|
||||
return self._get_current_session()
|
||||
|
||||
# 创建全局session代理实例
|
||||
_global_session_proxy = SessionProxy()
|
||||
|
||||
def get_session():
|
||||
"""返回线程安全的session代理,自动管理生命周期"""
|
||||
return _global_session_proxy
|
||||
|
||||
|
||||
class ChatStreams(Base):
|
||||
"""聊天流模型"""
|
||||
__tablename__ = 'chat_streams'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||
create_time = Column(Float, nullable=False)
|
||||
group_platform = Column(Text, nullable=True)
|
||||
group_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
group_name = Column(Text, nullable=True)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_nickname = Column(Text, nullable=False)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_chatstreams_stream_id', 'stream_id'),
|
||||
Index('idx_chatstreams_user_id', 'user_id'),
|
||||
Index('idx_chatstreams_group_id', 'group_id'),
|
||||
)
|
||||
|
||||
|
||||
class LLMUsage(Base):
|
||||
"""LLM使用记录模型"""
|
||||
__tablename__ = 'llm_usage'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
request_type = Column(get_string_field(50), nullable=False, index=True)
|
||||
endpoint = Column(Text, nullable=False)
|
||||
prompt_tokens = Column(Integer, nullable=False)
|
||||
completion_tokens = Column(Integer, nullable=False)
|
||||
total_tokens = Column(Integer, nullable=False)
|
||||
cost = Column(Float, nullable=False)
|
||||
status = Column(Text, nullable=False)
|
||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_llmusage_model_name', 'model_name'),
|
||||
Index('idx_llmusage_user_id', 'user_id'),
|
||||
Index('idx_llmusage_request_type', 'request_type'),
|
||||
Index('idx_llmusage_timestamp', 'timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class Emoji(Base):
|
||||
"""表情包模型"""
|
||||
__tablename__ = 'emoji'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||
format = Column(Text, nullable=False)
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
query_count = Column(Integer, nullable=False, default=0)
|
||||
is_registered = Column(Boolean, nullable=False, default=False)
|
||||
is_banned = Column(Boolean, nullable=False, default=False)
|
||||
emotion = Column(Text, nullable=True)
|
||||
record_time = Column(Float, nullable=False)
|
||||
register_time = Column(Float, nullable=True)
|
||||
usage_count = Column(Integer, nullable=False, default=0)
|
||||
last_used_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_emoji_full_path', 'full_path'),
|
||||
Index('idx_emoji_hash', 'emoji_hash'),
|
||||
)
|
||||
|
||||
|
||||
class Messages(Base):
|
||||
"""消息模型"""
|
||||
__tablename__ = 'messages'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
reply_to = Column(Text, nullable=True)
|
||||
interest_value = Column(Float, nullable=True)
|
||||
is_mentioned = Column(Boolean, nullable=True)
|
||||
|
||||
# 从 chat_info 扁平化而来的字段
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
chat_info_user_platform = Column(Text, nullable=False)
|
||||
chat_info_user_id = Column(Text, nullable=False)
|
||||
chat_info_user_nickname = Column(Text, nullable=False)
|
||||
chat_info_user_cardname = Column(Text, nullable=True)
|
||||
chat_info_group_platform = Column(Text, nullable=True)
|
||||
chat_info_group_id = Column(Text, nullable=True)
|
||||
chat_info_group_name = Column(Text, nullable=True)
|
||||
chat_info_create_time = Column(Float, nullable=False)
|
||||
chat_info_last_active_time = Column(Float, nullable=False)
|
||||
|
||||
# 从顶层 user_info 扁平化而来的字段
|
||||
user_platform = Column(Text, nullable=True)
|
||||
user_id = Column(get_string_field(100), nullable=True, index=True)
|
||||
user_nickname = Column(Text, nullable=True)
|
||||
user_cardname = Column(Text, nullable=True)
|
||||
|
||||
processed_plain_text = Column(Text, nullable=True)
|
||||
display_message = Column(Text, nullable=True)
|
||||
memorized_times = Column(Integer, nullable=False, default=0)
|
||||
priority_mode = Column(Text, nullable=True)
|
||||
priority_info = Column(Text, nullable=True)
|
||||
additional_config = Column(Text, nullable=True)
|
||||
is_emoji = Column(Boolean, nullable=False, default=False)
|
||||
is_picid = Column(Boolean, nullable=False, default=False)
|
||||
is_command = Column(Boolean, nullable=False, default=False)
|
||||
is_notify = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_messages_message_id', 'message_id'),
|
||||
Index('idx_messages_chat_id', 'chat_id'),
|
||||
Index('idx_messages_time', 'time'),
|
||||
Index('idx_messages_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class ActionRecords(Base):
|
||||
"""动作记录模型"""
|
||||
__tablename__ = 'action_records'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
||||
time = Column(Float, nullable=False)
|
||||
action_name = Column(Text, nullable=False)
|
||||
action_data = Column(Text, nullable=False)
|
||||
action_done = Column(Boolean, nullable=False, default=False)
|
||||
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
||||
action_prompt_display = Column(Text, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_info_stream_id = Column(Text, nullable=False)
|
||||
chat_info_platform = Column(Text, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_actionrecords_action_id', 'action_id'),
|
||||
Index('idx_actionrecords_chat_id', 'chat_id'),
|
||||
Index('idx_actionrecords_time', 'time'),
|
||||
)
|
||||
|
||||
|
||||
class Images(Base):
|
||||
"""图像信息模型"""
|
||||
__tablename__ = 'images'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
image_id = Column(Text, nullable=False, default="")
|
||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=True)
|
||||
path = Column(get_string_field(500), nullable=False, unique=True)
|
||||
count = Column(Integer, nullable=False, default=1)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
type = Column(Text, nullable=False)
|
||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_images_emoji_hash', 'emoji_hash'),
|
||||
Index('idx_images_path', 'path'),
|
||||
)
|
||||
|
||||
|
||||
class ImageDescriptions(Base):
|
||||
"""图像描述信息模型"""
|
||||
__tablename__ = 'image_descriptions'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
type = Column(Text, nullable=False)
|
||||
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||
description = Column(Text, nullable=False)
|
||||
timestamp = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_imagedesc_hash', 'image_description_hash'),
|
||||
)
|
||||
|
||||
|
||||
class OnlineTime(Base):
|
||||
"""在线时长记录模型"""
|
||||
__tablename__ = 'online_time'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
||||
duration = Column(Integer, nullable=False)
|
||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
|
||||
)
|
||||
|
||||
|
||||
class PersonInfo(Base):
|
||||
"""人物信息模型"""
|
||||
__tablename__ = 'person_info'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||
person_name = Column(Text, nullable=True)
|
||||
name_reason = Column(Text, nullable=True)
|
||||
platform = Column(Text, nullable=False)
|
||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||
nickname = Column(Text, nullable=True)
|
||||
impression = Column(Text, nullable=True)
|
||||
short_impression = Column(Text, nullable=True)
|
||||
points = Column(Text, nullable=True)
|
||||
forgotten_points = Column(Text, nullable=True)
|
||||
info_list = Column(Text, nullable=True)
|
||||
know_times = Column(Float, nullable=True)
|
||||
know_since = Column(Float, nullable=True)
|
||||
last_know = Column(Float, nullable=True)
|
||||
attitude = Column(Integer, nullable=True, default=50)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_personinfo_person_id', 'person_id'),
|
||||
Index('idx_personinfo_user_id', 'user_id'),
|
||||
)
|
||||
|
||||
|
||||
class Memory(Base):
|
||||
"""记忆模型"""
|
||||
__tablename__ = 'memory'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
chat_id = Column(Text, nullable=True)
|
||||
memory_text = Column(Text, nullable=True)
|
||||
keywords = Column(Text, nullable=True)
|
||||
create_time = Column(Float, nullable=True)
|
||||
last_view_time = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_memory_memory_id', 'memory_id'),
|
||||
)
|
||||
|
||||
|
||||
class Expression(Base):
|
||||
"""表达风格模型"""
|
||||
__tablename__ = 'expression'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
situation = Column(Text, nullable=False)
|
||||
style = Column(Text, nullable=False)
|
||||
count = Column(Float, nullable=False)
|
||||
last_active_time = Column(Float, nullable=False)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
type = Column(Text, nullable=False)
|
||||
create_date = Column(Float, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_expression_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class ThinkingLog(Base):
|
||||
"""思考日志模型"""
|
||||
__tablename__ = 'thinking_logs'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||
trigger_text = Column(Text, nullable=True)
|
||||
response_text = Column(Text, nullable=True)
|
||||
trigger_info_json = Column(Text, nullable=True)
|
||||
response_info_json = Column(Text, nullable=True)
|
||||
timing_results_json = Column(Text, nullable=True)
|
||||
chat_history_json = Column(Text, nullable=True)
|
||||
chat_history_in_thinking_json = Column(Text, nullable=True)
|
||||
chat_history_after_response_json = Column(Text, nullable=True)
|
||||
heartflow_data_json = Column(Text, nullable=True)
|
||||
reasoning_data_json = Column(Text, nullable=True)
|
||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_thinkinglog_chat_id', 'chat_id'),
|
||||
)
|
||||
|
||||
|
||||
class GraphNodes(Base):
|
||||
"""记忆图节点模型"""
|
||||
__tablename__ = 'graph_nodes'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||
memory_items = Column(Text, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphnodes_concept', 'concept'),
|
||||
)
|
||||
|
||||
|
||||
class GraphEdges(Base):
|
||||
"""记忆图边模型"""
|
||||
__tablename__ = 'graph_edges'
|
||||
|
||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||
source = Column(get_string_field(255), nullable=False, index=True)
|
||||
target = Column(get_string_field(255), nullable=False, index=True)
|
||||
strength = Column(Integer, nullable=False)
|
||||
hash = Column(Text, nullable=False)
|
||||
created_time = Column(Float, nullable=False)
|
||||
last_modified = Column(Float, nullable=False)
|
||||
|
||||
__table_args__ = (
|
||||
Index('idx_graphedges_source', 'source'),
|
||||
Index('idx_graphedges_target', 'target'),
|
||||
)
|
||||
|
||||
|
||||
# 数据库引擎和会话管理
|
||||
_engine = None
|
||||
_SessionLocal = None
|
||||
|
||||
|
||||
def get_database_url():
|
||||
"""获取数据库连接URL"""
|
||||
config = global_config.database
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# 对用户名和密码进行URL编码,处理特殊字符
|
||||
from urllib.parse import quote_plus
|
||||
encoded_user = quote_plus(config.mysql_user)
|
||||
encoded_password = quote_plus(config.mysql_password)
|
||||
|
||||
return (
|
||||
f"mysql+pymysql://{encoded_user}:{encoded_password}"
|
||||
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||
f"?charset={config.mysql_charset}"
|
||||
)
|
||||
else: # SQLite
|
||||
# 如果是相对路径,则相对于项目根目录
|
||||
if not os.path.isabs(config.sqlite_path):
|
||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||
else:
|
||||
db_path = config.sqlite_path
|
||||
|
||||
# 确保数据库目录存在
|
||||
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||
|
||||
return f"sqlite:///{db_path}"
|
||||
|
||||
|
||||
def initialize_database():
|
||||
"""初始化数据库引擎和会话"""
|
||||
global _engine, _SessionLocal
|
||||
|
||||
if _engine is not None:
|
||||
return _engine, _SessionLocal
|
||||
|
||||
database_url = get_database_url()
|
||||
config = global_config.database
|
||||
|
||||
# 配置引擎参数
|
||||
engine_kwargs = {
|
||||
'echo': False, # 生产环境关闭SQL日志
|
||||
'future': True,
|
||||
}
|
||||
|
||||
if config.database_type == "mysql":
|
||||
# MySQL连接池配置
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': config.connection_pool_size,
|
||||
'max_overflow': config.connection_pool_size * 2,
|
||||
'pool_timeout': config.connection_timeout,
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'autocommit': config.mysql_autocommit,
|
||||
'charset': config.mysql_charset,
|
||||
'connect_timeout': config.connection_timeout,
|
||||
'read_timeout': 30,
|
||||
'write_timeout': 30,
|
||||
}
|
||||
})
|
||||
else:
|
||||
# SQLite配置 - 添加连接池设置以避免连接耗尽
|
||||
engine_kwargs.update({
|
||||
'poolclass': QueuePool,
|
||||
'pool_size': 20, # 增加池大小
|
||||
'max_overflow': 30, # 增加溢出连接数
|
||||
'pool_timeout': 60, # 增加超时时间
|
||||
'pool_recycle': 3600, # 1小时回收连接
|
||||
'pool_pre_ping': True, # 连接前ping检查
|
||||
'connect_args': {
|
||||
'check_same_thread': False,
|
||||
'timeout': 30,
|
||||
}
|
||||
})
|
||||
|
||||
_engine = create_engine(database_url, **engine_kwargs)
|
||||
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
|
||||
|
||||
# 创建所有表
|
||||
Base.metadata.create_all(bind=_engine)
|
||||
|
||||
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
|
||||
return _engine, _SessionLocal
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_db_session():
|
||||
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
|
||||
session = None
|
||||
try:
|
||||
_, SessionLocal = initialize_database()
|
||||
session = SessionLocal()
|
||||
yield session
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
if session:
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
if session:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_engine():
|
||||
"""获取数据库引擎"""
|
||||
engine, _ = initialize_database()
|
||||
return engine
|
||||
@@ -373,6 +373,7 @@ MODULE_COLORS = {
|
||||
"base_command": "\033[38;5;208m", # 橙色
|
||||
"component_registry": "\033[38;5;214m", # 橙黄色
|
||||
"stream_api": "\033[38;5;220m", # 黄色
|
||||
"plugin_hot_reload": "\033[38;5;226m", #品红色
|
||||
"config_api": "\033[38;5;226m", # 亮黄色
|
||||
"heartflow_api": "\033[38;5;154m", # 黄绿色
|
||||
"action_apis": "\033[38;5;118m", # 绿色
|
||||
@@ -406,6 +407,7 @@ MODULE_COLORS = {
|
||||
"base_action": "\033[38;5;250m", # 浅灰色
|
||||
# 数据库和消息
|
||||
"database_model": "\033[38;5;94m", # 橙褐色
|
||||
"database": "\033[38;5;46m", # 橙褐色
|
||||
"maim_message": "\033[38;5;140m", # 紫褐色
|
||||
# 日志系统
|
||||
"logger": "\033[38;5;8m", # 深灰色
|
||||
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
|
||||
"memory_activator": "记忆",
|
||||
"tool_use": "工具",
|
||||
"expressor": "表达方式",
|
||||
"plugin_hot_reload": "热重载",
|
||||
"database": "数据库",
|
||||
"database_model": "数据库",
|
||||
"mood": "情绪",
|
||||
"memory": "记忆",
|
||||
|
||||
@@ -1,20 +1,26 @@
|
||||
import traceback
|
||||
|
||||
from typing import List, Any, Optional
|
||||
from peewee import Model # 添加 Peewee Model 导入
|
||||
from typing import List, Optional, Any, Dict
|
||||
from sqlalchemy import not_, select, func
|
||||
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
from src.config.config import global_config
|
||||
|
||||
from src.common.database.database_model import Messages
|
||||
# from src.common.database.database_model import Messages
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
from src.common.database.sqlalchemy_database_api import get_session
|
||||
from src.common.logger import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
|
||||
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
|
||||
def _model_to_dict(instance: Base) -> Dict[str, Any]:
|
||||
"""
|
||||
将 Peewee 模型实例转换为字典。
|
||||
将 SQLAlchemy 模型实例转换为字典。
|
||||
"""
|
||||
return model_instance.__data__
|
||||
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
|
||||
|
||||
|
||||
def find_messages(
|
||||
@@ -38,7 +44,8 @@ def find_messages(
|
||||
消息字典列表,如果出错则返回空列表。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(Messages)
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -77,42 +84,57 @@ def find_messages(
|
||||
query = query.where(Messages.user_id != global_config.bot.qq_account)
|
||||
|
||||
if filter_command:
|
||||
query = query.where(not Messages.is_command)
|
||||
query = query.where(not_(Messages.is_command))
|
||||
|
||||
if limit > 0:
|
||||
# 确保limit是正整数
|
||||
limit = max(1, int(limit))
|
||||
|
||||
if limit_mode == "earliest":
|
||||
# 获取时间最早的 limit 条记录,已经是正序
|
||||
query = query.order_by(Messages.time.asc()).limit(limit)
|
||||
peewee_results = list(query)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行earliest查询失败: {e}")
|
||||
results = []
|
||||
else: # 默认为 'latest'
|
||||
# 获取时间最晚的 limit 条记录
|
||||
query = query.order_by(Messages.time.desc()).limit(limit)
|
||||
latest_results_peewee = list(query)
|
||||
# 将结果按时间正序排列
|
||||
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
|
||||
try:
|
||||
latest_results = session.execute(query).scalars().all()
|
||||
# 将结果按时间正序排列
|
||||
results = sorted(latest_results, key=lambda msg: msg.time)
|
||||
except Exception as e:
|
||||
logger.error(f"执行latest查询失败: {e}")
|
||||
results = []
|
||||
else:
|
||||
# limit 为 0 时,应用传入的 sort 参数
|
||||
if sort:
|
||||
peewee_sort_terms = []
|
||||
sort_terms = []
|
||||
for field_name, direction in sort:
|
||||
if hasattr(Messages, field_name):
|
||||
field = getattr(Messages, field_name)
|
||||
if direction == 1: # ASC
|
||||
peewee_sort_terms.append(field.asc())
|
||||
sort_terms.append(field.asc())
|
||||
elif direction == -1: # DESC
|
||||
peewee_sort_terms.append(field.desc())
|
||||
sort_terms.append(field.desc())
|
||||
else:
|
||||
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
|
||||
else:
|
||||
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
|
||||
if peewee_sort_terms:
|
||||
query = query.order_by(*peewee_sort_terms)
|
||||
peewee_results = list(query)
|
||||
if sort_terms:
|
||||
query = query.order_by(*sort_terms)
|
||||
try:
|
||||
results = session.execute(query).scalars().all()
|
||||
except Exception as e:
|
||||
logger.error(f"执行无限制查询失败: {e}")
|
||||
results = []
|
||||
|
||||
return [_model_to_dict(msg) for msg in peewee_results]
|
||||
return [_model_to_dict(msg) for msg in results]
|
||||
except Exception as e:
|
||||
log_message = (
|
||||
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
|
||||
+ traceback.format_exc()
|
||||
)
|
||||
logger.error(log_message)
|
||||
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
符合条件的消息数量,如果出错则返回 0。
|
||||
"""
|
||||
try:
|
||||
query = Messages.select()
|
||||
session = get_session()
|
||||
query = select(func.count(Messages.id))
|
||||
|
||||
# 应用过滤器
|
||||
if message_filter:
|
||||
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
|
||||
if conditions:
|
||||
query = query.where(*conditions)
|
||||
|
||||
count = query.count()
|
||||
return count
|
||||
count = session.execute(query).scalar()
|
||||
return count or 0
|
||||
except Exception as e:
|
||||
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
|
||||
logger.error(log_message)
|
||||
return 0
|
||||
|
||||
|
||||
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
|
||||
# 注意:对于 Peewee,插入操作通常是 Messages.create(...) 或 instance.save()。
|
||||
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
|
||||
# 注意:对于 SQLAlchemy,插入操作通常是使用 session.add() 和 session.commit()。
|
||||
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。
|
||||
|
||||
Reference in New Issue
Block a user