Merge pull request #1166 from MaiM-with-u/revert-1164-dev

Revert "feat(database): 添加MySQL支持并重构数据库配置"
This commit is contained in:
墨梓柒
2025-08-07 13:10:21 +08:00
committed by GitHub
5 changed files with 125 additions and 249 deletions

View File

@@ -1,11 +1,9 @@
import os
from pymongo import MongoClient
from peewee import MySQLDatabase, SqliteDatabase
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
from src.config.config import global_config
install(extra_lines=3)
_client = None
@@ -59,24 +57,19 @@ class DBWrapper:
return get_db()[key] # type: ignore
def create_peewee_database():
data_base_config = global_config.data_base
# 全局数据库访问点
memory_db: Database = DBWrapper() # type: ignore
if data_base_config.db_type == "mysql":
return MySQLDatabase(
data_base_config.database,
user=data_base_config.username,
password=data_base_config.password,
host=data_base_config.host,
port=int(data_base_config.port),
charset='utf8mb4'
)
elif data_base_config.db_type == "sqlite":
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
return SqliteDatabase(
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(
_DB_FILE,
pragmas={
"journal_mode": "wal", # WAL模式提高并发性能
@@ -85,13 +78,5 @@ def create_peewee_database():
"ignore_check_constraints": 0,
"synchronous": 0, # 异步写入提高性能
"busy_timeout": 1000, # 1秒超时而不是3秒
}, )
else:
raise ValueError(f"Unsupported PEEWEE_DB_TYPE: {data_base_config.db_type}")
# 全局数据库访问点
memory_db: Database | DBWrapper = DBWrapper()
# 全局 Peewee SQLite 数据库访问点
db = create_peewee_database()
},
)

View File

@@ -1,16 +1,9 @@
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
from peewee import BooleanField, CharField, DateTimeField, DoubleField, FloatField, IntegerField, Model, TextField
from src.common.database.database import db
from src.common.logger import get_logger
from src.config.config import global_config
table_prefix = global_config.data_base.table_prefix
logger = get_logger("database_model")
logger.info(f"正在加载数据库模型...数据库表前缀为: {table_prefix}")
# 请在此处定义您的数据库实例。
# 您需要取消注释并配置适合您的数据库的部分。
# 例如,对于 SQLite:
@@ -41,7 +34,7 @@ class ChatStreams(BaseModel):
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
stream_id = CharField(max_length=64, unique=True)
stream_id = TextField(unique=True, index=True)
# create_time: 1746096761.4490178 (时间戳精确到小数点后7位)
# DoubleField 用于存储浮点数,适合此类时间戳。
@@ -77,7 +70,7 @@ class ChatStreams(BaseModel):
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
# database = db
table_name = table_prefix + "chat_streams" # 可选:明确指定数据库中的表名
table_name = "chat_streams" # 可选:明确指定数据库中的表名
class LLMUsage(BaseModel):
@@ -85,9 +78,9 @@ class LLMUsage(BaseModel):
用于存储 API 使用日志数据的模型。
"""
model_name = CharField(max_length=64, index=True) # 添加索引
user_id = CharField(max_length=64, index=True) # 添加索引
request_type = CharField(max_length=64, index=True) # 添加索引
model_name = TextField(index=True) # 添加索引
user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引
endpoint = TextField()
prompt_tokens = IntegerField()
completion_tokens = IntegerField()
@@ -99,15 +92,15 @@ class LLMUsage(BaseModel):
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = table_prefix + "llm_usage"
table_name = "llm_usage"
class Emoji(BaseModel):
"""表情包"""
full_path = CharField(max_length=512, unique=True) # 文件的完整路径 (包括文件名)
full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名)
format = TextField() # 图片格式
emoji_hash = CharField(max_length=64, index=True) # 表情包的哈希值
emoji_hash = TextField(index=True) # 表情包的哈希值
description = TextField() # 表情包的描述
query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数)
is_registered = BooleanField(default=False) # 是否已注册
@@ -121,7 +114,7 @@ class Emoji(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "emoji"
table_name = "emoji"
class Messages(BaseModel):
@@ -129,10 +122,10 @@ class Messages(BaseModel):
用于存储消息数据的模型。
"""
message_id = CharField(max_length=128, index=True) # 消息 ID (更改自 IntegerField)
message_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
chat_id = CharField(max_length=128, index=True) # 对应的 ChatStreams stream_id
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
reply_to = TextField(null=True)
@@ -172,7 +165,7 @@ class Messages(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "messages"
table_name = "messages"
class ActionRecords(BaseModel):
@@ -190,13 +183,13 @@ class ActionRecords(BaseModel):
action_build_into_prompt = BooleanField(default=False)
action_prompt_display = TextField()
chat_id = CharField(max_length=128, index=True) # 对应的 ChatStreams stream_id
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
chat_info_stream_id = TextField()
chat_info_platform = TextField()
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "action_records"
table_name = "action_records"
class Images(BaseModel):
@@ -205,9 +198,9 @@ class Images(BaseModel):
"""
image_id = TextField(default="") # 图片唯一ID
emoji_hash = CharField(max_length=64, index=True) # 图像的哈希值
emoji_hash = TextField(index=True) # 图像的哈希值
description = TextField(null=True) # 图像的描述
path = CharField(max_length=512, unique=True) # 图像文件的路径
path = TextField(unique=True) # 图像文件的路径
# base64 = TextField() # 图片的base64编码
count = IntegerField(default=1) # 图片被引用的次数
timestamp = FloatField() # 时间戳
@@ -215,7 +208,7 @@ class Images(BaseModel):
vlm_processed = BooleanField(default=False) # 是否已经过VLM处理
class Meta:
table_name = table_prefix + "images"
table_name = "images"
class ImageDescriptions(BaseModel):
@@ -224,13 +217,13 @@ class ImageDescriptions(BaseModel):
"""
type = TextField() # 类型,例如 "emoji"
image_description_hash = CharField(max_length=64, index=True) # 图像的哈希值
image_description_hash = TextField(index=True) # 图像的哈希值
description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "image_descriptions"
table_name = "image_descriptions"
class OnlineTime(BaseModel):
@@ -239,14 +232,14 @@ class OnlineTime(BaseModel):
"""
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = CharField(max_length=64, default=datetime.datetime.now) # 时间戳
timestamp = TextField(default=datetime.datetime.now) # 时间戳
duration = IntegerField() # 时长,单位分钟
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "online_time"
table_name = "online_time"
class PersonInfo(BaseModel):
@@ -254,11 +247,11 @@ class PersonInfo(BaseModel):
用于存储个人信息数据的模型。
"""
person_id = CharField(max_length=64, unique=True) # 个人唯一ID
person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField(null=True) # 个人名称 (允许为空)
name_reason = TextField(null=True) # 名称设定的原因
platform = TextField() # 平台
user_id = CharField(max_length=64, index=True) # 用户ID
user_id = TextField(index=True) # 用户ID
nickname = TextField() # 用户昵称
impression = TextField(null=True) # 个人印象
short_impression = TextField(null=True) # 个人印象的简短描述
@@ -273,11 +266,11 @@ class PersonInfo(BaseModel):
class Meta:
# database = db # 继承自 BaseModel
table_name = table_prefix + "person_info"
table_name = "person_info"
class Memory(BaseModel):
memory_id = CharField(max_length=128, index=True)
memory_id = TextField(index=True)
chat_id = TextField(null=True)
memory_text = TextField(null=True)
keywords = TextField(null=True)
@@ -285,7 +278,7 @@ class Memory(BaseModel):
last_view_time = FloatField(null=True)
class Meta:
table_name = table_prefix + "memory"
table_name = "memory"
class Expression(BaseModel):
@@ -297,16 +290,16 @@ class Expression(BaseModel):
style = TextField()
count = FloatField()
last_active_time = FloatField()
chat_id = CharField(max_length=128, index=True)
chat_id = TextField(index=True)
type = TextField()
create_date = FloatField(null=True) # 创建日期,允许为空以兼容老数据
class Meta:
table_name = table_prefix + "expression"
table_name = "expression"
class ThinkingLog(BaseModel):
chat_id = CharField(max_length=128, index=True)
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
response_text = TextField(null=True)
@@ -326,7 +319,7 @@ class ThinkingLog(BaseModel):
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
table_name = table_prefix + "thinking_logs"
table_name = "thinking_logs"
class GraphNodes(BaseModel):
@@ -334,14 +327,14 @@ class GraphNodes(BaseModel):
用于存储记忆图节点的模型
"""
concept = CharField(max_length=128, unique=True) # 节点概念
concept = TextField(unique=True, index=True) # 节点概念
memory_items = TextField() # JSON格式存储的记忆列表
hash = TextField() # 节点哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = table_prefix + "graph_nodes"
table_name = "graph_nodes"
class GraphEdges(BaseModel):
@@ -349,15 +342,15 @@ class GraphEdges(BaseModel):
用于存储记忆图边的模型
"""
source = CharField(max_length=128, index=True) # 源节点
target = CharField(max_length=128, index=True) # 目标节点
source = TextField(index=True) # 源节点
target = TextField(index=True) # 目标节点
strength = IntegerField() # 连接强度
hash = TextField() # 边哈希值
created_time = FloatField() # 创建时间戳
last_modified = FloatField() # 最后修改时间戳
class Meta:
table_name = table_prefix + "graph_edges"
table_name = "graph_edges"
def create_tables():
@@ -405,137 +398,73 @@ def initialize_database():
ThinkingLog,
GraphNodes,
GraphEdges,
ActionRecords,
ActionRecords, # 添加 ActionRecords 到初始化列表
]
# 保持 del_extra 为 False以避免在生产环境中意外删除数据。
# 如果需要删除多余字段,请谨慎设置为 True。
del_extra = False
# 辅助函数:根据字段对象和数据库类型获取对应的 SQL 类型字符串
def get_sql_type(field_obj, db_type):
field_type_name = field_obj.__class__.__name__
if db_type == "sqlite":
return {
"TextField": "TEXT",
"IntegerField": "INTEGER",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "INTEGER",
"DateTimeField": "DATETIME",
}.get(field_type_name, "TEXT")
elif db_type == "mysql":
# CharField 的 max_length 将在主循环中单独处理
return {
"TextField": "LONGTEXT", # MySQL TEXT 类型长度有限LONGTEXT 更安全
"IntegerField": "INT",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "TINYINT(1)", # MySQL 布尔值存储为 TINYINT(1)
"DateTimeField": "DATETIME",
}.get(field_type_name, "TEXT")
logger.error(f"不支持的数据库类型: {db_type}")
return "TEXT" # 默认回退类型
# 辅助函数:将 Peewee 字段的默认值转换为 SQL 语句中的 DEFAULT 子句
def get_sql_default_value(field_obj):
if field_obj.default is None:
return "" # 没有定义默认值
# 可调用默认值(如 datetime.datetime.now无法直接转换为 SQL DDL 的 DEFAULT 子句
# 因此,对于这类情况,我们不生成 DEFAULT 子句,并依赖 Peewee 在应用层处理
# 如果字段为 NOT NULL 且无法提供字面默认值,则需要在 ADD COLUMN 时临时设为 NULLABLE
if callable(field_obj.default):
return ""
default_value = field_obj.default
if isinstance(default_value, str):
# 字符串默认值需要用单引号括起来,并对内部的单引号进行转义
escaped_value = str(default_value).replace("'", "''")
return f" DEFAULT '{escaped_value}'"
elif isinstance(default_value, bool):
return f" DEFAULT {int(default_value)}" # 布尔值转换为 0 或 1
elif isinstance(default_value, (int, float)):
return f" DEFAULT {default_value}"
return "" # 其他无法直接转换为 SQL 字面值的类型
try:
with db:
with db: # 管理 table_exists 检查的连接
for model in models:
table_name = model._meta.table_name
if not db.table_exists(model):
logger.warning(f"'{table_name}' 未找到,正在创建...")
db.create_tables([model])
logger.info(f"'{table_name}' 创建成功")
# 表刚创建,无需检查字段
continue
# 获取现有列
db_type = global_config.data_base.db_type
if db_type == "sqlite":
# 检查字段
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
existing_columns = {row[1] for row in cursor.fetchall()}
elif db_type == "mysql":
cursor = db.execute_sql(f"SHOW COLUMNS FROM {table_name}")
existing_columns = {row[0] for row in cursor.fetchall()}
else:
logger.error(f"不支持的数据库类型 '{db_type}',跳过表 '{table_name}' 的字段检查。")
continue
model_fields = set(model._meta.fields.keys())
# 识别并添加缺失字段
missing_fields = model_fields - existing_columns
if missing_fields:
if missing_fields := model_fields - existing_columns:
logger.warning(f"'{table_name}' 缺失字段: {missing_fields}")
for field_name, field_obj in model._meta.fields.items():
if field_name not in existing_columns:
logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在尝试添加...")
sql_type = get_sql_type(field_obj, db_type)
# 特殊处理 MySQL 的 CharField需要 max_length
if isinstance(field_obj, CharField) and db_type == "mysql":
sql_type = f"VARCHAR({field_obj.max_length})"
null_clause = " NULL" if field_obj.null else " NOT NULL"
default_clause = get_sql_default_value(field_obj)
# 如果字段定义为 NOT NULL 且无法在 SQL DDL 中提供字面默认值 (如可调用默认值)
# 为了避免在有数据的表中添加列时失败,暂时将其添加为 NULLABLE。
# 这是一种务实的兼容性处理,后续可能需要手动回填数据并修改为 NOT NULL。
if not field_obj.null and not default_clause:
logger.warning(
f"'{table_name}' 的字段 '{field_name}' 为 NOT NULL 但无法生成SQL默认值"
f"将暂时添加为 NULLABLE 以避免现有数据行错误。"
)
null_clause = " NULL" # 强制设为 NULLABLE
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}{null_clause}{default_clause}"
logger.info(f"'{table_name}' 缺失字段 '{field_name}',正在添加...")
field_type = field_obj.__class__.__name__
sql_type = {
"TextField": "TEXT",
"IntegerField": "INTEGER",
"FloatField": "FLOAT",
"DoubleField": "DOUBLE",
"BooleanField": "INTEGER",
"DateTimeField": "DATETIME",
}.get(field_type, "TEXT")
alter_sql = f"ALTER TABLE {table_name} ADD COLUMN {field_name} {sql_type}"
alter_sql += " NULL" if field_obj.null else " NOT NULL"
if hasattr(field_obj, "default") and field_obj.default is not None:
# 正确处理不同类型的默认值跳过lambda函数
default_value = field_obj.default
if callable(default_value):
# 跳过lambda函数或其他可调用对象这些无法在SQL中表示
pass
elif isinstance(default_value, str):
alter_sql += f" DEFAULT '{default_value}'"
elif isinstance(default_value, bool):
alter_sql += f" DEFAULT {int(default_value)}"
else:
alter_sql += f" DEFAULT {default_value}"
try:
db.execute_sql(alter_sql)
logger.info(f"字段 '{field_name}' 添加成功")
except Exception as e:
logger.error(f"添加字段 '{field_name}' 失败: {e}. SQL 语句: {alter_sql}")
logger.error(f"添加字段 '{field_name}' 失败: {e}")
# 检查并删除多余字段(根据 del_extra 旗标决定
if del_extra:
# 检查并删除多余字段(新增逻辑
extra_fields = existing_columns - model_fields
if extra_fields:
logger.warning(f"'{table_name}' 存在模型中未定义的字段: {extra_fields}")
logger.warning(f"'{table_name}' 存在多余字段: {extra_fields}")
for field_name in extra_fields:
try:
logger.warning(f"'{table_name}' 正在尝试删除多余字段 '{field_name}'...")
logger.warning(f"'{table_name}' 存在多余字段 '{field_name}',正在尝试删除...")
db.execute_sql(f"ALTER TABLE {table_name} DROP COLUMN {field_name}")
logger.info(f"字段 '{field_name}' 删除成功")
except Exception as e:
logger.error(f"删除字段 '{field_name}' 失败: {e}")
except Exception as e:
logger.exception(f"数据库初始化过程中发生异常: {e}")
# 如果初始化失败(例如数据库不可用),则退出
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
return
logger.info("数据库初始化完成")

View File

@@ -14,7 +14,6 @@ from src.common.logger import get_logger
from src.config.config_base import ConfigBase
from src.config.official_configs import (
BotConfig,
DataBaseConfig,
PersonalityConfig,
ExpressionConfig,
ChatConfig,
@@ -349,7 +348,6 @@ class Config(ConfigBase):
debug: DebugConfig
custom_prompt: CustomPromptConfig
voice: VoiceConfig
data_base: DataBaseConfig
@dataclass

View File

@@ -1,4 +1,5 @@
import re
from dataclasses import dataclass, field
from typing import Literal, Optional
@@ -271,7 +272,6 @@ class NormalChatConfig(ConfigBase):
willing_mode: str = "classical"
"""意愿模式"""
@dataclass
class ExpressionConfig(ConfigBase):
"""表达配置类"""
@@ -302,7 +302,6 @@ class ToolConfig(ConfigBase):
enable_tool: bool = False
"""是否在聊天中启用工具"""
@dataclass
class VoiceConfig(ConfigBase):
"""语音识别配置类"""
@@ -450,7 +449,6 @@ class KeywordReactionConfig(ConfigBase):
if not isinstance(rule, KeywordRuleConfig):
raise ValueError(f"规则必须是KeywordRuleConfig类型而不是{type(rule).__name__}")
@dataclass
class CustomPromptConfig(ConfigBase):
"""自定义提示词配置类"""
@@ -600,27 +598,3 @@ class LPMMKnowledgeConfig(ConfigBase):
embedding_dimension: int = 1024
"""嵌入向量维度,应该与模型的输出维度一致"""
class DataBaseConfig(ConfigBase):
"""数据库配置类"""
db_type: Literal["sqlite", "mysql"] = "sqlite"
"""数据库类型支持sqlite、mysql"""
host: str = "127.0.0.1"
"""数据库主机地址"""
port: int = 3306
"""数据库端口号"""
username: str = ""
"""数据库用户名"""
password: str = ""
"""数据库密码"""
database: str = "MaiBot"
"""数据库名称"""
table_prefix: str = ""
"""数据库表前缀"""

View File

@@ -1,5 +1,5 @@
[inner]
version = "6.1.0"
version = "6.0.0"
#----以下是给开发人员阅读的,如果你只是部署了麦麦,不需要阅读----
#如果你想要修改配置文件请在修改后将version的值进行变更
@@ -232,13 +232,3 @@ enable = true
[experimental] #实验性功能
enable_friend_chat = false # 是否启用好友聊天
[data_base] #数据库配置
# 数据库类型可选sqlite, mysql
db_type = "sqlite" # 数据库类型
host = "" # 数据库主机地址,如果是sqlite则不需要填写
port = 3306 # 数据库端口,如果是sqlite则不需要填写
username = "" # 数据库用户名,如果是sqlite则不需要填写
password = "" # 数据库密码,如果是sqlite则不需要填写
database = "MaiBot" # 数据库名称,如果是sqlite则不需要填写
table_prefix = "" # 数据库表前缀,用于支持多实例部署