Files
Mofox-Core/src/common/database/database_model.py
cuckoo711 939f17890a refactor(database): 重构数据库初始化和字段检查逻辑
- 添加对 MySQL 数据库的支持
- 优化字段检查和添加逻辑,处理 NOT NULL 字段和默认值
- 改进错误处理和日志记录
- 调整表和字段操作的 SQL 语句以适应不同数据库类型
2025-08-07 11:22:01 +08:00

546 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import datetime
from peewee import BooleanField, CharField, DateTimeField, DoubleField, FloatField, IntegerField, Model, TextField
from src.common.database.database import db
from src.common.logger import get_logger
from src.config.config import global_config
# Prefix read from the global config; every model below prepends it to its table name.
table_prefix = global_config.data_base.table_prefix
# Module logger, registered under the name "database_model".
logger = get_logger("database_model")
# Logged once at import time so the active table prefix is visible in the logs.
logger.info(f"正在加载数据库模型...数据库表前缀为: {table_prefix}")
# 请在此处定义您的数据库实例。
# 您需要取消注释并配置适合您的数据库的部分。
# 例如,对于 SQLite:
# db = SqliteDatabase('MaiBot.db')
#
# 对于 PostgreSQL:
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=5432)
#
# 对于 MySQL:
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=3306)
# Defining a base model that every other model inherits from is good practice.
# It lets you specify the database for all models in one place.
class BaseModel(Model):
    """Base model that sets ``Meta.database`` to the imported ``db`` instance."""

    class Meta:
        # Replace 'db' below with the name of your actual database instance variable.
        database = db  # e.g. database = my_actual_db_instance
        pass  # placeholder until the user defines a database instance
class ChatStreams(BaseModel):
    """
    Model for storing stream record data, similar to the provided MongoDB structure.
    """

    # stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
    # stream_id is assumed to be unique, and is indexed to speed up queries.
    stream_id = CharField(max_length=64, unique=True)
    # create_time: 1746096761.4490178 (timestamp with 7 decimal places)
    # DoubleField stores a floating-point number, which suits this kind of timestamp.
    create_time = DoubleField()
    # group_info fields:
    # platform: "qq"
    # group_id: "941657197"
    # group_name: "测试"
    group_platform = TextField(null=True)  # group info may be absent
    group_id = TextField(null=True)
    group_name = TextField(null=True)
    # last_active_time: 1746623771.4825106 (timestamp with 7 decimal places)
    last_active_time = DoubleField()
    # platform: "qq" (top-level platform field)
    platform = TextField()
    # user_info fields:
    # platform: "qq"
    # user_id: "1787882683"
    # user_nickname: "墨梓柒(IceSakurary)"
    # user_cardname: ""
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    # user_cardname may be an empty string or absent; null=True keeps it flexible.
    user_cardname = TextField(null=True)

    class Meta:
        # If BaseModel.Meta.database is set, this model inherits that database configuration.
        # If you are not using a BaseModel with a database instance, or want to override it,
        # uncomment the line below and set the database instance:
        # database = db
        table_name = table_prefix + "chat_streams"  # optional: explicitly set the table name in the database
class LLMUsage(BaseModel):
    """
    Model for storing API usage log data.
    """

    model_name = CharField(max_length=64, index=True)  # indexed
    user_id = CharField(max_length=64, index=True)  # indexed
    request_type = CharField(max_length=64, index=True)  # indexed
    endpoint = TextField()
    prompt_tokens = IntegerField()
    completion_tokens = IntegerField()
    total_tokens = IntegerField()
    cost = DoubleField()
    status = TextField()
    timestamp = DateTimeField(index=True)  # changed to DateTimeField and indexed

    class Meta:
        # If BaseModel.Meta.database is set, this model inherits that database configuration.
        # database = db
        table_name = table_prefix + "llm_usage"
class Emoji(BaseModel):
    """Emoji (sticker) record."""

    full_path = CharField(max_length=512, unique=True)  # full path of the file (including the file name)
    format = TextField()  # image format
    emoji_hash = CharField(max_length=64, index=True)  # hash of the emoji
    description = TextField()  # description of the emoji
    query_count = IntegerField(default=0)  # query count (how many times the emoji's description was queried)
    is_registered = BooleanField(default=False)  # whether it has been registered
    is_banned = BooleanField(default=False)  # whether it is banned from registration
    # emotion: list[str]  # emotion tags of the emoji - stored as text; the application layer
    # handles serialization/deserialization
    emotion = TextField(null=True)
    record_time = FloatField()  # record time (when the row was created)
    register_time = FloatField(null=True)  # registration time (when it was registered as a usable emoji)
    usage_count = IntegerField(default=0)  # usage count (how many times it has been used)
    last_used_time = FloatField(null=True)  # time of last use

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "emoji"
class Messages(BaseModel):
    """
    Model for storing message data.
    """

    message_id = CharField(max_length=128, index=True)  # message ID (changed from IntegerField)
    time = DoubleField()  # message timestamp
    chat_id = CharField(max_length=128, index=True)  # corresponds to ChatStreams stream_id
    reply_to = TextField(null=True)
    interest_value = DoubleField(null=True)
    is_mentioned = BooleanField(null=True)
    # Fields flattened from chat_info
    chat_info_stream_id = TextField()
    chat_info_platform = TextField()
    chat_info_user_platform = TextField()
    chat_info_user_id = TextField()
    chat_info_user_nickname = TextField()
    chat_info_user_cardname = TextField(null=True)
    chat_info_group_platform = TextField(null=True)  # group info may be absent
    chat_info_group_id = TextField(null=True)
    chat_info_group_name = TextField(null=True)
    chat_info_create_time = DoubleField()
    chat_info_last_active_time = DoubleField()
    # Fields flattened from the top-level user_info (message sender information)
    user_platform = TextField()
    user_id = TextField()
    user_nickname = TextField()
    user_cardname = TextField(null=True)
    processed_plain_text = TextField(null=True)  # processed plain-text message
    display_message = TextField(null=True)  # message as displayed
    memorized_times = IntegerField(default=0)  # number of times the message has been memorized
    priority_mode = TextField(null=True)
    priority_info = TextField(null=True)
    additional_config = TextField(null=True)
    is_emoji = BooleanField(default=False)
    is_picid = BooleanField(default=False)
    is_command = BooleanField(default=False)

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "messages"
class ActionRecords(BaseModel):
    """
    Model for storing action record data.
    """

    # Original comment: "message ID (changed from IntegerField)".
    # NOTE(review): index=True on a TextField -- verify the MySQL backend accepts this;
    # MySQL normally requires a key prefix length to index a TEXT column.
    action_id = TextField(index=True)
    time = DoubleField()  # original comment: "message timestamp"
    action_name = TextField()
    action_data = TextField()
    action_done = BooleanField(default=False)
    action_build_into_prompt = BooleanField(default=False)
    action_prompt_display = TextField()
    chat_id = CharField(max_length=128, index=True)  # corresponds to ChatStreams stream_id
    chat_info_stream_id = TextField()
    chat_info_platform = TextField()

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "action_records"
class Images(BaseModel):
    """
    Model for storing image information.
    """

    image_id = TextField(default="")  # unique ID of the image
    emoji_hash = CharField(max_length=64, index=True)  # hash of the image
    description = TextField(null=True)  # description of the image
    path = CharField(max_length=512, unique=True)  # path of the image file
    # base64 = TextField()  # base64 encoding of the image
    count = IntegerField(default=1)  # number of times the image has been referenced
    timestamp = FloatField()  # timestamp
    type = TextField()  # image type, e.g. "emoji"
    vlm_processed = BooleanField(default=False)  # whether it has already been processed by the VLM

    class Meta:
        table_name = table_prefix + "images"
class ImageDescriptions(BaseModel):
    """
    Model for storing image description information.
    """

    type = TextField()  # type, e.g. "emoji"
    image_description_hash = CharField(max_length=64, index=True)  # hash of the image
    description = TextField()  # description of the image
    timestamp = FloatField()  # timestamp

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "image_descriptions"
class OnlineTime(BaseModel):
    """
    Model for storing online-duration records.
    """

    # timestamp: "$date": "2025-05-01T18:52:18.191Z" (stored as a string)
    # The default is the callable datetime.datetime.now, declared on a CharField.
    timestamp = CharField(max_length=64, default=datetime.datetime.now)  # timestamp
    duration = IntegerField()  # duration, in minutes
    start_timestamp = DateTimeField(default=datetime.datetime.now)
    end_timestamp = DateTimeField(index=True)

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "online_time"
class PersonInfo(BaseModel):
    """
    Model for storing personal information data.
    """

    person_id = CharField(max_length=64, unique=True)  # unique ID of the person
    person_name = TextField(null=True)  # name of the person (may be null)
    name_reason = TextField(null=True)  # reason the name was chosen
    platform = TextField()  # platform
    user_id = CharField(max_length=64, index=True)  # user ID
    nickname = TextField()  # user nickname
    impression = TextField(null=True)  # impression of the person
    short_impression = TextField(null=True)  # short description of the impression
    points = TextField(null=True)  # points of the impression
    forgotten_points = TextField(null=True)  # forgotten points
    info_list = TextField(null=True)  # interactions with the bot
    know_times = FloatField(null=True)  # acquaintance time (timestamp)
    know_since = FloatField(null=True)  # time of the first impression summary
    last_know = FloatField(null=True)  # time of the most recent impression summary
    attitude = IntegerField(null=True, default=50)  # attitude, 0-100, from strong dislike to strong liking

    class Meta:
        # database = db  # inherited from BaseModel
        table_name = table_prefix + "person_info"
class Memory(BaseModel):
    """
    Model whose rows hold a memory ID, chat ID, memory text, keywords and
    create / last-view times.
    """

    memory_id = CharField(max_length=128, index=True)  # indexed, but not declared unique
    chat_id = TextField(null=True)
    memory_text = TextField(null=True)
    keywords = TextField(null=True)
    create_time = FloatField(null=True)
    last_view_time = FloatField(null=True)

    class Meta:
        table_name = table_prefix + "memory"
class Expression(BaseModel):
    """
    Model for storing expression styles.
    """

    situation = TextField()
    style = TextField()
    count = FloatField()
    last_active_time = FloatField()
    chat_id = CharField(max_length=128, index=True)
    type = TextField()
    create_date = FloatField(null=True)  # creation date; nullable for compatibility with old data

    class Meta:
        table_name = table_prefix + "expression"
class ThinkingLog(BaseModel):
    """
    Model for a thinking log entry: trigger/response text plus several
    JSON-string fields, indexed by chat_id.
    """

    chat_id = CharField(max_length=128, index=True)
    trigger_text = TextField(null=True)
    response_text = TextField(null=True)
    # Store complex dicts/lists as JSON strings
    trigger_info_json = TextField(null=True)
    response_info_json = TextField(null=True)
    timing_results_json = TextField(null=True)
    chat_history_json = TextField(null=True)
    chat_history_in_thinking_json = TextField(null=True)
    chat_history_after_response_json = TextField(null=True)
    heartflow_data_json = TextField(null=True)
    reasoning_data_json = TextField(null=True)
    # Add a timestamp for the log entry itself
    # Ensure you have: from peewee import DateTimeField
    # And: import datetime
    created_at = DateTimeField(default=datetime.datetime.now)

    class Meta:
        table_name = table_prefix + "thinking_logs"
class GraphNodes(BaseModel):
    """
    Model for storing memory graph nodes.
    """

    concept = CharField(max_length=128, unique=True)  # node concept
    memory_items = TextField()  # list of memories, stored as JSON
    hash = TextField()  # node hash
    created_time = FloatField()  # creation timestamp
    last_modified = FloatField()  # last-modified timestamp

    class Meta:
        table_name = table_prefix + "graph_nodes"
class GraphEdges(BaseModel):
    """
    Model for storing memory graph edges.
    """

    source = CharField(max_length=128, index=True)  # source node
    target = CharField(max_length=128, index=True)  # target node
    strength = IntegerField()  # connection strength
    hash = TextField()  # edge hash
    created_time = FloatField()  # creation timestamp
    last_modified = FloatField()  # last-modified timestamp

    class Meta:
        table_name = table_prefix + "graph_edges"
def create_tables():
    """
    Create the database tables for all models defined in this module.

    The models are handed to ``db.create_tables`` in a single call made
    inside a ``with db:`` block.
    """
    all_models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        ThinkingLog,
        GraphNodes,  # memory graph node table
        GraphEdges,  # memory graph edge table
        Memory,
        ActionRecords,  # ActionRecords is part of the initial table set
    ]
    with db:
        db.create_tables(all_models)
def initialize_database():
    """
    Make sure every model's table exists and has a column for every model field.

    For each model in the list below:

    * if the table does not exist it is created, and no column check is done;
    * otherwise the table's existing columns are read (``PRAGMA table_info``
      for SQLite, ``SHOW COLUMNS`` for MySQL) and every model column that is
      missing is added with ``ALTER TABLE ... ADD COLUMN``;
    * columns present in the table but not in the model are dropped only when
      the local ``del_extra`` flag is True (it is False here).

    Columns are compared and altered by their database column name
    (``field.column_name``), and table/column identifiers are quoted for the
    configured dialect.

    A failure to add or drop a single column is logged and skipped. Any other
    exception is logged and ends the initialization early. Returns None.
    """
    models = [
        ChatStreams,
        LLMUsage,
        Emoji,
        Messages,
        Images,
        ImageDescriptions,
        OnlineTime,
        PersonInfo,
        Expression,
        Memory,
        ThinkingLog,
        GraphNodes,
        GraphEdges,
        ActionRecords,
    ]
    # Keep del_extra False to avoid accidentally deleting data in production.
    # Only set it to True, with care, when extra columns really must be dropped.
    del_extra = False

    def quote_identifier(name, db_type):
        """Quote a table/column name for the dialect, doubling any embedded quote character."""
        # MySQL quotes identifiers with backticks; SQLite accepts double quotes.
        quote_char = "`" if db_type == "mysql" else '"'
        escaped = str(name).replace(quote_char, quote_char * 2)
        return f"{quote_char}{escaped}{quote_char}"

    def get_sql_type(field_obj, db_type):
        """Map a field object's class name to a SQL column type string for the given database type."""
        field_type_name = field_obj.__class__.__name__
        if db_type == "sqlite":
            return {
                "TextField": "TEXT",
                "IntegerField": "INTEGER",
                "FloatField": "FLOAT",
                "DoubleField": "DOUBLE",
                "BooleanField": "INTEGER",
                "DateTimeField": "DATETIME",
            }.get(field_type_name, "TEXT")
        elif db_type == "mysql":
            # CharField's max_length is handled separately in the main loop.
            return {
                "TextField": "LONGTEXT",  # MySQL TEXT has a limited length; LONGTEXT is safer
                "IntegerField": "INT",
                "FloatField": "FLOAT",
                "DoubleField": "DOUBLE",
                "BooleanField": "TINYINT(1)",  # MySQL stores booleans as TINYINT(1)
                "DateTimeField": "DATETIME",
            }.get(field_type_name, "TEXT")
        logger.error(f"不支持的数据库类型: {db_type}")
        return "TEXT"  # default fallback type

    def get_sql_default_value(field_obj):
        """Return the `` DEFAULT ...`` clause for a field's default, or "" when none can be written."""
        if field_obj.default is None:
            return ""  # no default defined
        # A callable default (such as datetime.datetime.now) cannot be written as a
        # literal DEFAULT clause in DDL, so none is generated and the default stays
        # an application-level concern. The caller adds such NOT NULL columns as
        # NULLABLE instead.
        if callable(field_obj.default):
            return ""
        default_value = field_obj.default
        if isinstance(default_value, str):
            # String defaults are single-quoted, with embedded single quotes doubled.
            escaped_value = str(default_value).replace("'", "''")
            return f" DEFAULT '{escaped_value}'"
        elif isinstance(default_value, bool):
            # bool is tested before int because bool is a subclass of int.
            return f" DEFAULT {int(default_value)}"  # booleans become 0 or 1
        elif isinstance(default_value, (int, float)):
            return f" DEFAULT {default_value}"
        return ""  # other types cannot be written as a SQL literal

    try:
        # The configured backend is the same for every table, so read it once.
        db_type = global_config.data_base.db_type
        with db:
            for model in models:
                table_name = model._meta.table_name
                if not db.table_exists(model):
                    logger.warning(f"'{table_name}' 未找到,正在创建...")
                    db.create_tables([model])
                    logger.info(f"'{table_name}' 创建成功")
                    # The table was just created, so there is nothing to check.
                    continue

                # Read the columns that currently exist in the table.
                quoted_table = quote_identifier(table_name, db_type)
                if db_type == "sqlite":
                    cursor = db.execute_sql(f"PRAGMA table_info({quoted_table})")
                    existing_columns = {row[1] for row in cursor.fetchall()}
                elif db_type == "mysql":
                    cursor = db.execute_sql(f"SHOW COLUMNS FROM {quoted_table}")
                    existing_columns = {row[0] for row in cursor.fetchall()}
                else:
                    logger.error(f"不支持的数据库类型 '{db_type}',跳过表 '{table_name}' 的字段检查。")
                    continue

                # Compare by database column name, which is what the queries above return.
                model_fields = {field_obj.column_name for field_obj in model._meta.fields.values()}

                # Find and add the missing columns.
                missing_fields = model_fields - existing_columns
                if missing_fields:
                    logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
                    for field_name, field_obj in model._meta.fields.items():
                        column_name = field_obj.column_name
                        if column_name in existing_columns:
                            continue
                        logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在尝试添加...")
                        sql_type = get_sql_type(field_obj, db_type)
                        # MySQL CharField needs an explicit VARCHAR length.
                        if isinstance(field_obj, CharField) and db_type == "mysql":
                            sql_type = f"VARCHAR({field_obj.max_length})"
                        null_clause = " NULL" if field_obj.null else " NOT NULL"
                        default_clause = get_sql_default_value(field_obj)
                        # A NOT NULL column with no literal SQL default (e.g. a callable
                        # default) is added as NULLABLE so the ALTER does not fail on a
                        # table that already has rows. Backfilling and switching it to
                        # NOT NULL may need to be done by hand later.
                        if not field_obj.null and not default_clause:
                            logger.warning(
                                f"'{table_name}' 的字段 '{field_name}' 为 NOT NULL 但无法生成SQL默认值"
                                f"将暂时添加为 NULLABLE 以避免现有数据行错误。"
                            )
                            null_clause = " NULL"  # force NULLABLE
                        quoted_column = quote_identifier(column_name, db_type)
                        alter_sql = (
                            f"ALTER TABLE {quoted_table} ADD COLUMN "
                            f"{quoted_column} {sql_type}{null_clause}{default_clause}"
                        )
                        try:
                            db.execute_sql(alter_sql)
                            logger.info(f"字段 '{field_name}' 添加成功")
                        except Exception as e:
                            logger.error(f"添加字段 '{field_name}' 失败: {e}. SQL 语句: {alter_sql}")

                # Check for and drop extra columns (only when del_extra is set).
                if del_extra:
                    extra_fields = existing_columns - model_fields
                    if extra_fields:
                        logger.warning(f"'{table_name}' 存在模型中未定义的字段: {extra_fields}")
                        for field_name in extra_fields:
                            try:
                                logger.warning(f"'{table_name}' 正在尝试删除多余字段 '{field_name}'...")
                                quoted_column = quote_identifier(field_name, db_type)
                                db.execute_sql(f"ALTER TABLE {quoted_table} DROP COLUMN {quoted_column}")
                                logger.info(f"字段 '{field_name}' 删除成功")
                            except Exception as e:
                                logger.error(f"删除字段 '{field_name}' 失败: {e}")
    except Exception as e:
        logger.exception(f"数据库初始化过程中发生异常: {e}")
        # If initialization fails (for example the database is unavailable), stop here.
        return
    logger.info("数据库初始化完成")
# Call the initialization function when this module is loaded.
initialize_database()