初始化

This commit is contained in:
雅诺狐
2025-08-11 19:34:18 +08:00
committed by Windpicker-owo
parent ef7a3aee23
commit 23ee3767ef
77 changed files with 10000 additions and 7525 deletions

View File

@@ -1,14 +1,103 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
from src.common.logger import get_logger
# SQLAlchemy相关导入
from src.common.database.sqlalchemy_init import initialize_database_compat
from src.common.database.sqlalchemy_models import get_engine, get_session
install(extra_lines=3)
_client = None
_db = None
_sql_engine = None
logger = get_logger("database")
# 兼容性为了不破坏现有代码保留db变量但指向SQLAlchemy
class DatabaseProxy:
"""数据库代理类提供Peewee到SQLAlchemy的兼容性接口"""
def __init__(self):
self._engine = None
self._session = None
def initialize(self, *args, **kwargs):
"""初始化数据库连接"""
return initialize_database_compat()
def connect(self, reuse_if_open=True):
"""连接数据库(兼容性方法)"""
try:
self._engine = get_engine()
return True
except Exception as e:
logger.error(f"数据库连接失败: {e}")
return False
def is_closed(self):
"""检查数据库是否关闭(兼容性方法)"""
return self._engine is None
def create_tables(self, models, safe=True):
"""创建表(兼容性方法)"""
try:
from src.common.database.sqlalchemy_models import Base
engine = get_engine()
Base.metadata.create_all(bind=engine)
return True
except Exception as e:
logger.error(f"创建表失败: {e}")
return False
def table_exists(self, model):
"""检查表是否存在(兼容性方法)"""
try:
from sqlalchemy import inspect
engine = get_engine()
inspector = inspect(engine)
table_name = getattr(model, '_meta', {}).get('table_name', model.__name__.lower())
return table_name in inspector.get_table_names()
except Exception:
return False
def execute_sql(self, sql):
"""执行SQL兼容性方法"""
try:
from sqlalchemy import text
session = get_session()
result = session.execute(text(sql))
session.close()
return result
except Exception as e:
logger.error(f"执行SQL失败: {e}")
raise
def atomic(self):
"""事务上下文管理器(兼容性方法)"""
return SQLAlchemyTransaction()
class SQLAlchemyTransaction:
"""SQLAlchemy事务上下文管理器"""
def __init__(self):
self.session = None
def __enter__(self):
self.session = get_session()
return self.session
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is None:
self.session.commit()
else:
self.session.rollback()
self.session.close()
# 创建全局数据库代理实例
db = DatabaseProxy()
def __create_database_instance():
uri = os.getenv("MONGODB_URI")
@@ -39,7 +128,7 @@ def __create_database_instance():
def get_db():
"""获取数据库连接实例,延迟初始化。"""
"""获取MongoDB连接实例,延迟初始化。"""
global _client, _db
if _client is None:
_client = __create_database_instance()
@@ -47,6 +136,47 @@ def get_db():
return _db
def initialize_sql_database(database_config):
"""
根据配置初始化SQL数据库连接SQLAlchemy版本
Args:
database_config: DatabaseConfig对象
"""
global _sql_engine
try:
logger.info("使用SQLAlchemy初始化SQL数据库...")
# 记录数据库配置信息
if database_config.database_type == "mysql":
connection_info = f"{database_config.mysql_user}@{database_config.mysql_host}:{database_config.mysql_port}/{database_config.mysql_database}"
logger.info("MySQL数据库连接配置:")
logger.info(f" 连接信息: {connection_info}")
logger.info(f" 字符集: {database_config.mysql_charset}")
else:
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
if not os.path.isabs(database_config.sqlite_path):
db_path = os.path.join(ROOT_PATH, database_config.sqlite_path)
else:
db_path = database_config.sqlite_path
logger.info("SQLite数据库连接配置:")
logger.info(f" 数据库文件: {db_path}")
# 使用SQLAlchemy初始化
success = initialize_database_compat()
if success:
_sql_engine = get_engine()
logger.info("SQLAlchemy数据库初始化成功")
else:
logger.error("SQLAlchemy数据库初始化失败")
return _sql_engine
except Exception as e:
logger.error(f"初始化SQL数据库失败: {e}")
return None
class DBWrapper:
"""数据库代理类,保持接口兼容性同时实现懒加载。"""
@@ -57,26 +187,6 @@ class DBWrapper:
return get_db()[key] # type: ignore
# 全局数据库访问点
# 全局MongoDB数据库访问点
memory_db: Database = DBWrapper() # type: ignore
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(
_DB_FILE,
pragmas={
"journal_mode": "wal", # WAL模式提高并发性能
"cache_size": -64 * 1000, # 64MB缓存
"foreign_keys": 1,
"ignore_check_constraints": 0,
"synchronous": 0, # 异步写入提高性能
"busy_timeout": 1000, # 1秒超时而不是3秒
},
)

View File

@@ -0,0 +1,420 @@
"""SQLAlchemy数据库API模块
提供基于SQLAlchemy的数据库操作替换Peewee以解决MySQL连接问题
支持自动重连、连接池管理和更好的错误处理
"""
import traceback
import time
from typing import Dict, List, Any, Union, Type, Optional
from contextlib import contextmanager
from sqlalchemy.orm import Session
from sqlalchemy.exc import SQLAlchemyError, DisconnectionError, OperationalError
from sqlalchemy import desc, asc, func, and_, or_
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_db_session, Messages, ActionRecords, PersonInfo, ChatStreams,
LLMUsage, Emoji, Images, ImageDescriptions, OnlineTime, Memory,
Expression, ThinkingLog, GraphNodes, GraphEdges,get_session
)
logger = get_logger("sqlalchemy_database_api")
# 模型映射表,用于通过名称获取模型类
MODEL_MAPPING = {
'Messages': Messages,
'ActionRecords': ActionRecords,
'PersonInfo': PersonInfo,
'ChatStreams': ChatStreams,
'LLMUsage': LLMUsage,
'Emoji': Emoji,
'Images': Images,
'ImageDescriptions': ImageDescriptions,
'OnlineTime': OnlineTime,
'Memory': Memory,
'Expression': Expression,
'ThinkingLog': ThinkingLog,
'GraphNodes': GraphNodes,
'GraphEdges': GraphEdges,
}
@contextmanager
def get_db_session():
"""数据库会话上下文管理器,自动处理事务和连接错误"""
session = None
max_retries = 3
retry_delay = 1.0
for attempt in range(max_retries):
try:
session = get_session()
yield session
session.commit()
break
except (DisconnectionError, OperationalError) as e:
logger.warning(f"数据库连接错误 (尝试 {attempt + 1}/{max_retries}): {e}")
if session:
session.rollback()
session.close()
if attempt < max_retries - 1:
time.sleep(retry_delay * (attempt + 1))
else:
raise
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
def build_filters(session: Session, model_class: Type[Base], filters: Dict[str, Any]):
"""构建查询过滤条件"""
conditions = []
for field_name, value in filters.items():
if not hasattr(model_class, field_name):
logger.warning(f"模型 {model_class.__name__} 中不存在字段 '{field_name}'")
continue
field = getattr(model_class, field_name)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(~field.in_(op_value))
else:
logger.warning(f"未知操作符 '{op}' (字段: '{field_name}')")
else:
# 直接相等比较
conditions.append(field == value)
return conditions
async def db_query(
model_class: Type[Base],
data: Optional[Dict[str, Any]] = None,
query_type: Optional[str] = "get",
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""执行数据库查询操作
Args:
model_class: SQLAlchemy模型类
data: 用于创建或更新的数据字典
query_type: 查询类型 ("get", "create", "update", "delete", "count")
filters: 过滤条件字典
limit: 限制结果数量
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
根据查询类型返回相应结果
"""
try:
if query_type not in ["get", "create", "update", "delete", "count"]:
raise ValueError("query_type must be 'get', 'create', 'update', 'delete' or 'count'")
with get_db_session() as session:
if query_type == "get":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 应用排序
if order_by:
for field_name in order_by:
if field_name.startswith("-"):
field_name = field_name[1:]
if hasattr(model_class, field_name):
query = query.order_by(desc(getattr(model_class, field_name)))
else:
if hasattr(model_class, field_name):
query = query.order_by(asc(getattr(model_class, field_name)))
# 应用限制
if limit and limit > 0:
query = query.limit(limit)
# 执行查询
results = query.all()
# 转换为字典格式
result_dicts = []
for result in results:
result_dict = {}
for column in result.__table__.columns:
result_dict[column.name] = getattr(result, column.name)
result_dicts.append(result_dict)
if single_result:
return result_dicts[0] if result_dicts else None
return result_dicts
elif query_type == "create":
if not data:
raise ValueError("创建记录需要提供data参数")
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush() # 获取自动生成的ID
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
elif query_type == "update":
if not data:
raise ValueError("更新记录需要提供data参数")
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行更新
affected_rows = query.update(data)
return affected_rows
elif query_type == "delete":
query = session.query(model_class)
# 应用过滤条件
if filters:
conditions = build_filters(session, model_class, filters)
if conditions:
query = query.filter(and_(*conditions))
# 执行删除
affected_rows = query.delete()
return affected_rows
elif query_type == "count":
query = session.query(func.count(model_class.id))
# 应用过滤条件
if filters:
base_query = session.query(model_class)
conditions = build_filters(session, model_class, filters)
if conditions:
base_query = base_query.filter(and_(*conditions))
query = session.query(func.count()).select_from(base_query.subquery())
return query.scalar()
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 数据库操作出错: {e}")
traceback.print_exc()
# 根据查询类型返回合适的默认值
if query_type == "get":
return None if single_result else []
elif query_type in ["create", "update", "delete", "count"]:
return None
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 意外错误: {e}")
traceback.print_exc()
if query_type == "get":
return None if single_result else []
return None
async def db_save(
model_class: Type[Base],
data: Dict[str, Any],
key_field: Optional[str] = None,
key_value: Optional[Any] = None
) -> Optional[Dict[str, Any]]:
"""保存数据到数据库(创建或更新)
Args:
model_class: SQLAlchemy模型类
data: 要保存的数据字典
key_field: 用于查找现有记录的字段名
key_value: 用于查找现有记录的字段值
Returns:
保存后的记录数据或None
"""
try:
with get_db_session() as session:
# 如果提供了key_field和key_value尝试更新现有记录
if key_field and key_value is not None:
if hasattr(model_class, key_field):
existing_record = session.query(model_class).filter(
getattr(model_class, key_field) == key_value
).first()
if existing_record:
# 更新现有记录
for field, value in data.items():
if hasattr(existing_record, field):
setattr(existing_record, field, value)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in existing_record.__table__.columns:
result_dict[column.name] = getattr(existing_record, column.name)
return result_dict
# 创建新记录
new_record = model_class(**data)
session.add(new_record)
session.flush()
# 转换为字典格式返回
result_dict = {}
for column in new_record.__table__.columns:
result_dict[column.name] = getattr(new_record, column.name)
return result_dict
except SQLAlchemyError as e:
logger.error(f"[SQLAlchemy] 保存数据库记录出错: {e}")
traceback.print_exc()
return None
except Exception as e:
logger.error(f"[SQLAlchemy] 保存时意外错误: {e}")
traceback.print_exc()
return None
async def db_get(
model_class: Type[Base],
filters: Optional[Dict[str, Any]] = None,
limit: Optional[int] = None,
order_by: Optional[str] = None,
single_result: Optional[bool] = False,
) -> Union[List[Dict[str, Any]], Dict[str, Any], None]:
"""从数据库获取记录
Args:
model_class: SQLAlchemy模型类
filters: 过滤条件
limit: 结果数量限制
order_by: 排序字段,前缀'-'表示降序
single_result: 是否只返回单个结果
Returns:
记录数据或None
"""
order_by_list = [order_by] if order_by else None
return await db_query(
model_class=model_class,
query_type="get",
filters=filters,
limit=limit,
order_by=order_by_list,
single_result=single_result
)
async def store_action_info(
chat_stream=None,
action_build_into_prompt: bool = False,
action_prompt_display: str = "",
action_done: bool = True,
thinking_id: str = "",
action_data: Optional[dict] = None,
action_name: str = "",
) -> Optional[Dict[str, Any]]:
"""存储动作信息到数据库
Args:
chat_stream: 聊天流对象
action_build_into_prompt: 是否将此动作构建到提示中
action_prompt_display: 动作的提示显示文本
action_done: 动作是否完成
thinking_id: 关联的思考ID
action_data: 动作数据字典
action_name: 动作名称
Returns:
保存的记录数据或None
"""
try:
import json
# 构建动作记录数据
record_data = {
"action_id": thinking_id or str(int(time.time() * 1000000)),
"time": time.time(),
"action_name": action_name,
"action_data": json.dumps(action_data or {}, ensure_ascii=False),
"action_done": action_done,
"action_build_into_prompt": action_build_into_prompt,
"action_prompt_display": action_prompt_display,
}
# 从chat_stream获取聊天信息
if chat_stream:
record_data.update({
"chat_id": getattr(chat_stream, "stream_id", ""),
"chat_info_stream_id": getattr(chat_stream, "stream_id", ""),
"chat_info_platform": getattr(chat_stream, "platform", ""),
})
else:
record_data.update({
"chat_id": "",
"chat_info_stream_id": "",
"chat_info_platform": "",
})
# 保存记录
saved_record = await db_save(
ActionRecords,
data=record_data,
key_field="action_id",
key_value=record_data["action_id"]
)
if saved_record:
logger.debug(f"[SQLAlchemy] 成功存储动作信息: {action_name} (ID: {record_data['action_id']})")
else:
logger.error(f"[SQLAlchemy] 存储动作信息失败: {action_name}")
return saved_record
except Exception as e:
logger.error(f"[SQLAlchemy] 存储动作信息时发生错误: {e}")
traceback.print_exc()
return None
# 兼容性函数方便从Peewee迁移
def get_model_class(model_name: str) -> Optional[Type[Base]]:
"""根据模型名称获取模型类"""
return MODEL_MAPPING.get(model_name)

View File

@@ -0,0 +1,158 @@
"""SQLAlchemy数据库初始化模块
替换Peewee的数据库初始化逻辑
提供统一的数据库初始化接口
"""
from typing import Optional
from sqlalchemy.exc import SQLAlchemyError
from src.common.logger import get_logger
from src.common.database.sqlalchemy_models import (
Base, get_engine, get_session, initialize_database
)
logger = get_logger("sqlalchemy_init")
def initialize_sqlalchemy_database() -> bool:
"""
初始化SQLAlchemy数据库
创建所有表结构
Returns:
bool: 初始化是否成功
"""
try:
logger.info("开始初始化SQLAlchemy数据库...")
# 初始化数据库引擎和会话
engine, session_local = initialize_database()
if engine is None:
logger.error("数据库引擎初始化失败")
return False
logger.info("SQLAlchemy数据库初始化成功")
return True
except SQLAlchemyError as e:
logger.error(f"SQLAlchemy数据库初始化失败: {e}")
return False
except Exception as e:
logger.error(f"数据库初始化过程中发生未知错误: {e}")
return False
def create_all_tables() -> bool:
"""
创建所有数据库表
Returns:
bool: 创建是否成功
"""
try:
logger.info("开始创建数据库表...")
engine = get_engine()
if engine is None:
logger.error("无法获取数据库引擎")
return False
# 创建所有表
Base.metadata.create_all(bind=engine)
logger.info("数据库表创建成功")
return True
except SQLAlchemyError as e:
logger.error(f"创建数据库表失败: {e}")
return False
except Exception as e:
logger.error(f"创建数据库表过程中发生未知错误: {e}")
return False
def check_database_connection() -> bool:
"""
检查数据库连接是否正常
Returns:
bool: 连接是否正常
"""
try:
session = get_session()
if session is None:
logger.error("无法获取数据库会话")
return False
# 检查会话是否可用(如果能获取到会话说明连接正常)
if session is None:
logger.error("数据库会话无效")
return False
session.close()
logger.info("数据库连接检查通过")
return True
except SQLAlchemyError as e:
logger.error(f"数据库连接检查失败: {e}")
return False
except Exception as e:
logger.error(f"数据库连接检查过程中发生未知错误: {e}")
return False
def get_database_info() -> Optional[dict]:
"""
获取数据库信息
Returns:
dict: 数据库信息字典,包含引擎信息等
"""
try:
engine = get_engine()
if engine is None:
return None
info = {
'engine_name': engine.name,
'driver': engine.driver,
'url': str(engine.url).replace(engine.url.password or '', '***'), # 隐藏密码
'pool_size': getattr(engine.pool, 'size', None),
'max_overflow': getattr(engine.pool, 'max_overflow', None),
}
return info
except Exception as e:
logger.error(f"获取数据库信息失败: {e}")
return None
_database_initialized = False
def initialize_database_compat() -> bool:
"""
兼容性数据库初始化函数
用于替换原有的Peewee初始化代码
Returns:
bool: 初始化是否成功
"""
global _database_initialized
if _database_initialized:
return True
success = initialize_sqlalchemy_database()
if success:
success = create_all_tables()
if success:
success = check_database_connection()
if success:
_database_initialized = True
return success

View File

@@ -0,0 +1,555 @@
"""SQLAlchemy数据库模型定义
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
"""
from sqlalchemy import Column, String, Float, Integer, Boolean, Text, Index, create_engine, DateTime
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import QueuePool
import os
import datetime
from src.config.config import global_config
from src.common.logger import get_logger
import threading
from contextlib import contextmanager
logger = get_logger("sqlalchemy_models")
# 创建基类
Base = declarative_base()
# MySQL兼容的字段类型辅助函数
def get_string_field(max_length=255, **kwargs):
"""
根据数据库类型返回合适的字符串字段
MySQL需要指定长度的VARCHAR用于索引SQLite可以使用Text
"""
if global_config.database.database_type == "mysql":
return String(max_length, **kwargs)
else:
return Text(**kwargs)
class SessionProxy:
"""线程安全的Session代理类自动管理session生命周期"""
def __init__(self):
self._local = threading.local()
def _get_current_session(self):
"""获取当前线程的session如果没有则创建新的"""
if not hasattr(self._local, 'session') or self._local.session is None:
_, SessionLocal = initialize_database()
self._local.session = SessionLocal()
return self._local.session
def _close_current_session(self):
"""关闭当前线程的session"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
self._local.session.close()
except:
pass
finally:
self._local.session = None
def __getattr__(self, name):
"""代理所有session方法"""
session = self._get_current_session()
attr = getattr(session, name)
# 如果是方法,需要特殊处理一些关键方法
if callable(attr):
if name in ['commit', 'rollback']:
def wrapper(*args, **kwargs):
try:
result = attr(*args, **kwargs)
if name == 'commit':
# commit后不要清除session只是刷新状态
pass # 保持session活跃
return result
except Exception as e:
try:
if session and hasattr(session, 'rollback'):
session.rollback()
except:
pass
# 发生错误时重新创建session
self._close_current_session()
raise
return wrapper
elif name == 'close':
def wrapper(*args, **kwargs):
result = attr(*args, **kwargs)
self._close_current_session()
return result
return wrapper
elif name in ['execute', 'query', 'add', 'delete', 'merge']:
def wrapper(*args, **kwargs):
try:
return attr(*args, **kwargs)
except Exception as e:
# 如果是连接相关错误重新创建session再试一次
if "not bound to a Session" in str(e) or "provisioning a new connection" in str(e):
logger.warning(f"Session问题重新创建session: {e}")
self._close_current_session()
new_session = self._get_current_session()
new_attr = getattr(new_session, name)
return new_attr(*args, **kwargs)
raise
return wrapper
return attr
def new_session(self):
"""强制创建新的session关闭当前的创建新的"""
self._close_current_session()
return self._get_current_session()
def ensure_fresh_session(self):
"""确保使用新鲜的session如果当前session有问题则重新创建"""
if hasattr(self._local, 'session') and self._local.session is not None:
try:
# 测试session是否还可用
self._local.session.execute("SELECT 1")
except Exception:
# session有问题重新创建
self._close_current_session()
return self._get_current_session()
# 创建全局session代理实例
_global_session_proxy = SessionProxy()
def get_session():
"""返回线程安全的session代理自动管理生命周期"""
return _global_session_proxy
class ChatStreams(Base):
"""聊天流模型"""
__tablename__ = 'chat_streams'
id = Column(Integer, primary_key=True, autoincrement=True)
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
create_time = Column(Float, nullable=False)
group_platform = Column(Text, nullable=True)
group_id = Column(get_string_field(100), nullable=True, index=True)
group_name = Column(Text, nullable=True)
last_active_time = Column(Float, nullable=False)
platform = Column(Text, nullable=False)
user_platform = Column(Text, nullable=False)
user_id = Column(get_string_field(100), nullable=False, index=True)
user_nickname = Column(Text, nullable=False)
user_cardname = Column(Text, nullable=True)
__table_args__ = (
Index('idx_chatstreams_stream_id', 'stream_id'),
Index('idx_chatstreams_user_id', 'user_id'),
Index('idx_chatstreams_group_id', 'group_id'),
)
class LLMUsage(Base):
"""LLM使用记录模型"""
__tablename__ = 'llm_usage'
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(get_string_field(100), nullable=False, index=True)
user_id = Column(get_string_field(50), nullable=False, index=True)
request_type = Column(get_string_field(50), nullable=False, index=True)
endpoint = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
total_tokens = Column(Integer, nullable=False)
cost = Column(Float, nullable=False)
status = Column(Text, nullable=False)
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
__table_args__ = (
Index('idx_llmusage_model_name', 'model_name'),
Index('idx_llmusage_user_id', 'user_id'),
Index('idx_llmusage_request_type', 'request_type'),
Index('idx_llmusage_timestamp', 'timestamp'),
)
class Emoji(Base):
"""表情包模型"""
__tablename__ = 'emoji'
id = Column(Integer, primary_key=True, autoincrement=True)
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
format = Column(Text, nullable=False)
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
query_count = Column(Integer, nullable=False, default=0)
is_registered = Column(Boolean, nullable=False, default=False)
is_banned = Column(Boolean, nullable=False, default=False)
emotion = Column(Text, nullable=True)
record_time = Column(Float, nullable=False)
register_time = Column(Float, nullable=True)
usage_count = Column(Integer, nullable=False, default=0)
last_used_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_emoji_full_path', 'full_path'),
Index('idx_emoji_hash', 'emoji_hash'),
)
class Messages(Base):
"""消息模型"""
__tablename__ = 'messages'
id = Column(Integer, primary_key=True, autoincrement=True)
message_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
reply_to = Column(Text, nullable=True)
interest_value = Column(Float, nullable=True)
is_mentioned = Column(Boolean, nullable=True)
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
chat_info_user_platform = Column(Text, nullable=False)
chat_info_user_id = Column(Text, nullable=False)
chat_info_user_nickname = Column(Text, nullable=False)
chat_info_user_cardname = Column(Text, nullable=True)
chat_info_group_platform = Column(Text, nullable=True)
chat_info_group_id = Column(Text, nullable=True)
chat_info_group_name = Column(Text, nullable=True)
chat_info_create_time = Column(Float, nullable=False)
chat_info_last_active_time = Column(Float, nullable=False)
# 从顶层 user_info 扁平化而来的字段
user_platform = Column(Text, nullable=True)
user_id = Column(get_string_field(100), nullable=True, index=True)
user_nickname = Column(Text, nullable=True)
user_cardname = Column(Text, nullable=True)
processed_plain_text = Column(Text, nullable=True)
display_message = Column(Text, nullable=True)
memorized_times = Column(Integer, nullable=False, default=0)
priority_mode = Column(Text, nullable=True)
priority_info = Column(Text, nullable=True)
additional_config = Column(Text, nullable=True)
is_emoji = Column(Boolean, nullable=False, default=False)
is_picid = Column(Boolean, nullable=False, default=False)
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_messages_message_id', 'message_id'),
Index('idx_messages_chat_id', 'chat_id'),
Index('idx_messages_time', 'time'),
Index('idx_messages_user_id', 'user_id'),
)
class ActionRecords(Base):
"""动作记录模型"""
__tablename__ = 'action_records'
id = Column(Integer, primary_key=True, autoincrement=True)
action_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
action_name = Column(Text, nullable=False)
action_data = Column(Text, nullable=False)
action_done = Column(Boolean, nullable=False, default=False)
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
action_prompt_display = Column(Text, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
__table_args__ = (
Index('idx_actionrecords_action_id', 'action_id'),
Index('idx_actionrecords_chat_id', 'chat_id'),
Index('idx_actionrecords_time', 'time'),
)
class Images(Base):
"""图像信息模型"""
__tablename__ = 'images'
id = Column(Integer, primary_key=True, autoincrement=True)
image_id = Column(Text, nullable=False, default="")
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=True)
path = Column(get_string_field(500), nullable=False, unique=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
type = Column(Text, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
__table_args__ = (
Index('idx_images_emoji_hash', 'emoji_hash'),
Index('idx_images_path', 'path'),
)
class ImageDescriptions(Base):
"""图像描述信息模型"""
__tablename__ = 'image_descriptions'
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(Text, nullable=False)
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
timestamp = Column(Float, nullable=False)
__table_args__ = (
Index('idx_imagedesc_hash', 'image_description_hash'),
)
class OnlineTime(Base):
"""在线时长记录模型"""
__tablename__ = 'online_time'
id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
duration = Column(Integer, nullable=False)
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp = Column(DateTime, nullable=False, index=True)
__table_args__ = (
Index('idx_onlinetime_end_timestamp', 'end_timestamp'),
)
class PersonInfo(Base):
"""人物信息模型"""
__tablename__ = 'person_info'
id = Column(Integer, primary_key=True, autoincrement=True)
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
person_name = Column(Text, nullable=True)
name_reason = Column(Text, nullable=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
nickname = Column(Text, nullable=True)
impression = Column(Text, nullable=True)
short_impression = Column(Text, nullable=True)
points = Column(Text, nullable=True)
forgotten_points = Column(Text, nullable=True)
info_list = Column(Text, nullable=True)
know_times = Column(Float, nullable=True)
know_since = Column(Float, nullable=True)
last_know = Column(Float, nullable=True)
attitude = Column(Integer, nullable=True, default=50)
__table_args__ = (
Index('idx_personinfo_person_id', 'person_id'),
Index('idx_personinfo_user_id', 'user_id'),
)
class Memory(Base):
"""记忆模型"""
__tablename__ = 'memory'
id = Column(Integer, primary_key=True, autoincrement=True)
memory_id = Column(get_string_field(64), nullable=False, index=True)
chat_id = Column(Text, nullable=True)
memory_text = Column(Text, nullable=True)
keywords = Column(Text, nullable=True)
create_time = Column(Float, nullable=True)
last_view_time = Column(Float, nullable=True)
__table_args__ = (
Index('idx_memory_memory_id', 'memory_id'),
)
class Expression(Base):
"""表达风格模型"""
__tablename__ = 'expression'
id = Column(Integer, primary_key=True, autoincrement=True)
situation = Column(Text, nullable=False)
style = Column(Text, nullable=False)
count = Column(Float, nullable=False)
last_active_time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
type = Column(Text, nullable=False)
create_date = Column(Float, nullable=True)
__table_args__ = (
Index('idx_expression_chat_id', 'chat_id'),
)
class ThinkingLog(Base):
"""思考日志模型"""
__tablename__ = 'thinking_logs'
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(get_string_field(64), nullable=False, index=True)
trigger_text = Column(Text, nullable=True)
response_text = Column(Text, nullable=True)
trigger_info_json = Column(Text, nullable=True)
response_info_json = Column(Text, nullable=True)
timing_results_json = Column(Text, nullable=True)
chat_history_json = Column(Text, nullable=True)
chat_history_in_thinking_json = Column(Text, nullable=True)
chat_history_after_response_json = Column(Text, nullable=True)
heartflow_data_json = Column(Text, nullable=True)
reasoning_data_json = Column(Text, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index('idx_thinkinglog_chat_id', 'chat_id'),
)
class GraphNodes(Base):
"""记忆图节点模型"""
__tablename__ = 'graph_nodes'
id = Column(Integer, primary_key=True, autoincrement=True)
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphnodes_concept', 'concept'),
)
class GraphEdges(Base):
"""记忆图边模型"""
__tablename__ = 'graph_edges'
id = Column(Integer, primary_key=True, autoincrement=True)
source = Column(get_string_field(255), nullable=False, index=True)
target = Column(get_string_field(255), nullable=False, index=True)
strength = Column(Integer, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
__table_args__ = (
Index('idx_graphedges_source', 'source'),
Index('idx_graphedges_target', 'target'),
)
# 数据库引擎和会话管理
_engine = None
_SessionLocal = None
def get_database_url():
"""获取数据库连接URL"""
config = global_config.database
if config.database_type == "mysql":
# 对用户名和密码进行URL编码处理特殊字符
from urllib.parse import quote_plus
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
return (
f"mysql+pymysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
else: # SQLite
# 如果是相对路径,则相对于项目根目录
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
return f"sqlite:///{db_path}"
def initialize_database():
"""初始化数据库引擎和会话"""
global _engine, _SessionLocal
if _engine is not None:
return _engine, _SessionLocal
database_url = get_database_url()
config = global_config.database
# 配置引擎参数
engine_kwargs = {
'echo': False, # 生产环境关闭SQL日志
'future': True,
}
if config.database_type == "mysql":
# MySQL连接池配置
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': config.connection_pool_size,
'max_overflow': config.connection_pool_size * 2,
'pool_timeout': config.connection_timeout,
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'autocommit': config.mysql_autocommit,
'charset': config.mysql_charset,
'connect_timeout': config.connection_timeout,
'read_timeout': 30,
'write_timeout': 30,
}
})
else:
# SQLite配置 - 添加连接池设置以避免连接耗尽
engine_kwargs.update({
'poolclass': QueuePool,
'pool_size': 20, # 增加池大小
'max_overflow': 30, # 增加溢出连接数
'pool_timeout': 60, # 增加超时时间
'pool_recycle': 3600, # 1小时回收连接
'pool_pre_ping': True, # 连接前ping检查
'connect_args': {
'check_same_thread': False,
'timeout': 30,
}
})
_engine = create_engine(database_url, **engine_kwargs)
_SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=_engine)
# 创建所有表
Base.metadata.create_all(bind=_engine)
logger.info(f"SQLAlchemy数据库初始化成功: {config.database_type}")
return _engine, _SessionLocal
@contextmanager
def get_db_session():
"""数据库会话上下文管理器 - 推荐使用这个而不是get_session()"""
session = None
try:
_, SessionLocal = initialize_database()
session = SessionLocal()
yield session
session.commit()
except Exception as e:
if session:
session.rollback()
raise
finally:
if session:
session.close()
def get_engine():
"""获取数据库引擎"""
engine, _ = initialize_database()
return engine

View File

@@ -373,6 +373,7 @@ MODULE_COLORS = {
"base_command": "\033[38;5;208m", # 橙色
"component_registry": "\033[38;5;214m", # 橙黄色
"stream_api": "\033[38;5;220m", # 黄色
"plugin_hot_reload": "\033[38;5;226m", #品红色
"config_api": "\033[38;5;226m", # 亮黄色
"heartflow_api": "\033[38;5;154m", # 黄绿色
"action_apis": "\033[38;5;118m", # 绿色
@@ -406,6 +407,7 @@ MODULE_COLORS = {
"base_action": "\033[38;5;250m", # 浅灰色
# 数据库和消息
"database_model": "\033[38;5;94m", # 橙褐色
"database": "\033[38;5;46m", # 橙褐色
"maim_message": "\033[38;5;140m", # 紫褐色
# 日志系统
"logger": "\033[38;5;8m", # 深灰色
@@ -430,6 +432,8 @@ MODULE_ALIASES = {
"memory_activator": "记忆",
"tool_use": "工具",
"expressor": "表达方式",
"plugin_hot_reload": "热重载",
"database": "数据库",
"database_model": "数据库",
"mood": "情绪",
"memory": "记忆",

View File

@@ -1,20 +1,26 @@
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
from typing import List, Optional, Any, Dict
from sqlalchemy import not_, select, func
from sqlalchemy.orm import DeclarativeBase
from src.config.config import global_config
from src.common.database.database_model import Messages
# from src.common.database.database_model import Messages
from src.common.database.sqlalchemy_models import Messages
from src.common.database.sqlalchemy_database_api import get_session
from src.common.logger import get_logger
logger = get_logger(__name__)
class Base(DeclarativeBase):
pass
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
def _model_to_dict(instance: Base) -> Dict[str, Any]:
"""
Peewee 模型实例转换为字典。
SQLAlchemy 模型实例转换为字典。
"""
return model_instance.__data__
return {col.name: getattr(instance, col.name) for col in instance.__table__.columns}
def find_messages(
@@ -38,7 +44,8 @@ def find_messages(
消息字典列表,如果出错则返回空列表。
"""
try:
query = Messages.select()
session = get_session()
query = select(Messages)
# 应用过滤器
if message_filter:
@@ -77,42 +84,57 @@ def find_messages(
query = query.where(Messages.user_id != global_config.bot.qq_account)
if filter_command:
query = query.where(not Messages.is_command)
query = query.where(not_(Messages.is_command))
if limit > 0:
# 确保limit是正整数
limit = max(1, int(limit))
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.order_by(Messages.time.asc()).limit(limit)
peewee_results = list(query)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行earliest查询失败: {e}")
results = []
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
try:
latest_results = session.execute(query).scalars().all()
# 将结果按时间正序排列
results = sorted(latest_results, key=lambda msg: msg.time)
except Exception as e:
logger.error(f"执行latest查询失败: {e}")
results = []
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
peewee_sort_terms = []
sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
if sort_terms:
query = query.order_by(*sort_terms)
try:
results = session.execute(query).scalars().all()
except Exception as e:
logger.error(f"执行无限制查询失败: {e}")
results = []
return [_model_to_dict(msg) for msg in peewee_results]
return [_model_to_dict(msg) for msg in results]
except Exception as e:
log_message = (
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 SQLAlchemy 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -130,7 +152,8 @@ def count_messages(message_filter: dict[str, Any]) -> int:
符合条件的消息数量,如果出错则返回 0。
"""
try:
query = Messages.select()
session = get_session()
query = select(func.count(Messages.id))
# 应用过滤器
if message_filter:
@@ -167,14 +190,14 @@ def count_messages(message_filter: dict[str, Any]) -> int:
if conditions:
query = query.where(*conditions)
count = query.count()
return count
count = session.execute(query).scalar()
return count or 0
except Exception as e:
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
log_message = f"使用 SQLAlchemy 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。
# 注意:对于 SQLAlchemy插入操作通常是使用 session.add() 和 session.commit()。
# 查找单个消息可以使用 session.execute(select(Messages).where(...)).scalar_one_or_none()。