Merge branch 'new-storage' into plugin

This commit is contained in:
SengokuCola
2025-05-16 21:14:16 +08:00
63 changed files with 2397 additions and 2008 deletions

View File

@@ -1,5 +1,6 @@
import os
from pymongo import MongoClient
from peewee import SqliteDatabase
from pymongo.database import Database
from rich.traceback import install
@@ -57,4 +58,15 @@ class DBWrapper:
# 全局数据库访问点
db: Database = DBWrapper()
memory_db: Database = DBWrapper()
# 定义数据库文件路径
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
_DB_DIR = os.path.join(ROOT_PATH, "data")
_DB_FILE = os.path.join(_DB_DIR, "MaiBot.db")
# 确保数据库目录存在
os.makedirs(_DB_DIR, exist_ok=True)
# 全局 Peewee SQLite 数据库访问点
db = SqliteDatabase(_DB_FILE)

View File

@@ -0,0 +1,358 @@
from peewee import Model, DoubleField, IntegerField, BooleanField, TextField, FloatField, DateTimeField
from .database import db
import datetime
from ..logger_manager import get_logger
logger = get_logger("database_model")
# 请在此处定义您的数据库实例。
# 您需要取消注释并配置适合您的数据库的部分。
# 例如,对于 SQLite:
# db = SqliteDatabase('MaiBot.db')
#
# 对于 PostgreSQL:
# db = PostgresqlDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=5432)
#
# 对于 MySQL:
# db = MySQLDatabase('your_db_name', user='your_user', password='your_password',
# host='localhost', port=3306)
# 定义一个基础模型是一个好习惯,所有其他模型都应继承自它。
# 这允许您在一个地方为所有模型指定数据库。
class BaseModel(Model):
class Meta:
# 将下面的 'db' 替换为您实际的数据库实例变量名。
database = db # 例如: database = my_actual_db_instance
pass # 在用户定义数据库实例之前,此处为占位符
class ChatStreams(BaseModel):
"""
用于存储流式记录数据的模型,类似于提供的 MongoDB 结构。
"""
# stream_id: "a544edeb1a9b73e3e1d77dff36e41264"
# 假设 stream_id 是唯一的,并为其创建索引以提高查询性能。
stream_id = TextField(unique=True, index=True)
# create_time: 1746096761.4490178 (时间戳精确到小数点后7位)
# DoubleField 用于存储浮点数,适合此类时间戳。
create_time = DoubleField()
# group_info 字段:
# platform: "qq"
# group_id: "941657197"
# group_name: "测试"
group_platform = TextField()
group_id = TextField()
group_name = TextField()
# last_active_time: 1746623771.4825106 (时间戳精确到小数点后7位)
last_active_time = DoubleField()
# platform: "qq" (顶层平台字段)
platform = TextField()
# user_info 字段:
# platform: "qq"
# user_id: "1787882683"
# user_nickname: "墨梓柒(IceSakurary)"
# user_cardname: ""
user_platform = TextField()
user_id = TextField()
user_nickname = TextField()
# user_cardname 可能为空字符串或不存在,设置 null=True 更具灵活性。
user_cardname = TextField(null=True)
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# 如果不使用带有数据库实例的 BaseModel或者想覆盖它
# 请取消注释并在下面设置数据库实例:
# database = db
table_name = "chat_streams" # 可选:明确指定数据库中的表名
class LLMUsage(BaseModel):
"""
用于存储 API 使用日志数据的模型。
"""
model_name = TextField(index=True) # 添加索引
user_id = TextField(index=True) # 添加索引
request_type = TextField(index=True) # 添加索引
endpoint = TextField()
prompt_tokens = IntegerField()
completion_tokens = IntegerField()
total_tokens = IntegerField()
cost = DoubleField()
status = TextField()
timestamp = DateTimeField(index=True) # 更改为 DateTimeField 并添加索引
class Meta:
# 如果 BaseModel.Meta.database 已设置,则此模型将继承该数据库配置。
# database = db
table_name = "llm_usage"
class Emoji(BaseModel):
"""表情包"""
full_path = TextField(unique=True, index=True) # 文件的完整路径 (包括文件名)
format = TextField() # 图片格式
emoji_hash = TextField(index=True) # 表情包的哈希值
description = TextField() # 表情包的描述
query_count = IntegerField(default=0) # 查询次数(用于统计表情包被查询描述的次数)
is_registered = BooleanField(default=False) # 是否已注册
is_banned = BooleanField(default=False) # 是否被禁止注册
# emotion: list[str] # 表情包的情感标签 - 存储为文本,应用层处理序列化/反序列化
emotion = TextField(null=True)
record_time = FloatField() # 记录时间(被创建的时间)
register_time = FloatField(null=True) # 注册时间(被注册为可用表情包的时间)
usage_count = IntegerField(default=0) # 使用次数(被使用的次数)
last_used_time = FloatField(null=True) # 上次使用时间
class Meta:
# database = db # 继承自 BaseModel
table_name = "emoji"
class Messages(BaseModel):
"""
用于存储消息数据的模型。
"""
message_id = TextField(index=True) # 消息 ID (更改自 IntegerField)
time = DoubleField() # 消息时间戳
chat_id = TextField(index=True) # 对应的 ChatStreams stream_id
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = TextField()
chat_info_platform = TextField()
chat_info_user_platform = TextField()
chat_info_user_id = TextField()
chat_info_user_nickname = TextField()
chat_info_user_cardname = TextField(null=True)
chat_info_group_platform = TextField(null=True) # 群聊信息可能不存在
chat_info_group_id = TextField(null=True)
chat_info_group_name = TextField(null=True)
chat_info_create_time = DoubleField()
chat_info_last_active_time = DoubleField()
# 从顶层 user_info 扁平化而来的字段 (消息发送者信息)
user_platform = TextField()
user_id = TextField()
user_nickname = TextField()
user_cardname = TextField(null=True)
processed_plain_text = TextField(null=True) # 处理后的纯文本消息
detailed_plain_text = TextField(null=True) # 详细的纯文本消息
memorized_times = IntegerField(default=0) # 被记忆的次数
class Meta:
# database = db # 继承自 BaseModel
table_name = "messages"
class Images(BaseModel):
"""
用于存储图像信息的模型。
"""
emoji_hash = TextField(index=True) # 图像的哈希值
description = TextField(null=True) # 图像的描述
path = TextField(unique=True) # 图像文件的路径
timestamp = FloatField() # 时间戳
type = TextField() # 图像类型,例如 "emoji"
class Meta:
# database = db # 继承自 BaseModel
table_name = "images"
class ImageDescriptions(BaseModel):
"""
用于存储图像描述信息的模型。
"""
type = TextField() # 类型,例如 "emoji"
image_description_hash = TextField(index=True) # 图像的哈希值
description = TextField() # 图像的描述
timestamp = FloatField() # 时间戳
class Meta:
# database = db # 继承自 BaseModel
table_name = "image_descriptions"
class OnlineTime(BaseModel):
"""
用于存储在线时长记录的模型。
"""
# timestamp: "$date": "2025-05-01T18:52:18.191Z" (存储为字符串)
timestamp = TextField(default=datetime.datetime.now) # 时间戳
duration = IntegerField() # 时长,单位分钟
start_timestamp = DateTimeField(default=datetime.datetime.now)
end_timestamp = DateTimeField(index=True)
class Meta:
# database = db # 继承自 BaseModel
table_name = "online_time"
class PersonInfo(BaseModel):
"""
用于存储个人信息数据的模型。
"""
person_id = TextField(unique=True, index=True) # 个人唯一ID
person_name = TextField(null=True) # 个人名称 (允许为空)
name_reason = TextField(null=True) # 名称设定的原因
platform = TextField() # 平台
user_id = TextField(index=True) # 用户ID
nickname = TextField() # 用户昵称
relationship_value = IntegerField(default=0) # 关系值
know_time = FloatField() # 认识时间 (时间戳)
msg_interval = IntegerField() # 消息间隔
# msg_interval_list: 存储为 JSON 字符串的列表
msg_interval_list = TextField(null=True)
class Meta:
# database = db # 继承自 BaseModel
table_name = "person_info"
class Knowledges(BaseModel):
"""
用于存储知识库条目的模型。
"""
content = TextField() # 知识内容的文本
embedding = TextField() # 知识内容的嵌入向量,存储为 JSON 字符串的浮点数列表
# 可以添加其他元数据字段,如 source, create_time 等
class Meta:
# database = db # 继承自 BaseModel
table_name = "knowledges"
class ThinkingLog(BaseModel):
chat_id = TextField(index=True)
trigger_text = TextField(null=True)
response_text = TextField(null=True)
# Store complex dicts/lists as JSON strings
trigger_info_json = TextField(null=True)
response_info_json = TextField(null=True)
timing_results_json = TextField(null=True)
chat_history_json = TextField(null=True)
chat_history_in_thinking_json = TextField(null=True)
chat_history_after_response_json = TextField(null=True)
heartflow_data_json = TextField(null=True)
reasoning_data_json = TextField(null=True)
# Add a timestamp for the log entry itself
# Ensure you have: from peewee import DateTimeField
# And: import datetime
created_at = DateTimeField(default=datetime.datetime.now)
class Meta:
table_name = "thinking_logs"
class RecalledMessages(BaseModel):
"""
用于存储撤回消息记录的模型。
"""
message_id = TextField(index=True) # 被撤回的消息 ID
time = DoubleField() # 撤回操作发生的时间戳
stream_id = TextField() # 对应的 ChatStreams stream_id
class Meta:
table_name = "recalled_messages"
def create_tables():
"""
创建所有在模型中定义的数据库表。
"""
with db:
db.create_tables(
[
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
RecalledMessages, # 添加新模型
]
)
def initialize_database():
"""
检查所有定义的表是否存在,如果不存在则创建它们。
检查所有表的所有字段是否存在,如果缺失则警告用户并退出程序。
"""
import sys
models = [
ChatStreams,
LLMUsage,
Emoji,
Messages,
Images,
ImageDescriptions,
OnlineTime,
PersonInfo,
Knowledges,
ThinkingLog,
RecalledMessages, # 添加新模型
]
needs_creation = False
try:
with db: # 管理 table_exists 检查的连接
for model in models:
table_name = model._meta.table_name
if not db.table_exists(model):
logger.warning(f"'{table_name}' 未找到。")
needs_creation = True
break # 一个表丢失,无需进一步检查。
if not needs_creation:
# 检查字段
for model in models:
table_name = model._meta.table_name
cursor = db.execute_sql(f"PRAGMA table_info('{table_name}')")
existing_columns = {row[1] for row in cursor.fetchall()}
model_fields = model._meta.fields
for field_name in model_fields:
if field_name not in existing_columns:
logger.error(f"'{table_name}' 缺失字段 '{field_name}',请手动迁移数据库结构后重启程序。")
sys.exit(1)
except Exception as e:
logger.exception(f"检查表或字段是否存在时出错: {e}")
# 如果检查失败(例如数据库不可用),则退出
return
if needs_creation:
logger.info("正在初始化数据库:一个或多个表丢失。正在尝试创建所有定义的表...")
try:
create_tables() # 此函数有其自己的 'with db:' 上下文管理。
logger.info("数据库表创建过程完成。")
except Exception as e:
logger.exception(f"创建表期间出错: {e}")
else:
logger.info("所有数据库表及字段均已存在。")
# 模块加载时调用初始化函数
initialize_database()

View File

@@ -1,11 +1,19 @@
from src.common.database import db
from src.common.database.database_model import Messages # 更改导入
from src.common.logger import get_module_logger
import traceback
from typing import List, Any, Optional
from peewee import Model # 添加 Peewee Model 导入
logger = get_module_logger(__name__)
def _model_to_dict(model_instance: Model) -> dict[str, Any]:
"""
将 Peewee 模型实例转换为字典。
"""
return model_instance.__data__
def find_messages(
message_filter: dict[str, Any],
sort: Optional[List[tuple[str, int]]] = None,
@@ -16,39 +24,84 @@ def find_messages(
根据提供的过滤器、排序和限制条件查找消息。
Args:
message_filter: MongoDB 查询过滤器。
sort: MongoDB 排序条件列表,例如 [('time', 1)]。仅在 limit 为 0 时生效。
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
sort: 排序条件列表,例如 [('time', 1)] (1 for asc, -1 for desc)。仅在 limit 为 0 时生效。
limit: 返回的最大文档数0表示不限制。
limit_mode: 当 limit > 0 时生效。 'earliest' 表示获取最早的记录, 'latest' 表示获取最新的记录(结果仍按时间正序排列)。默认为 'latest'
Returns:
消息文档列表,如果出错则返回空列表。
消息字典列表,如果出错则返回空列表。
"""
try:
query = db.messages.find(message_filter)
query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(f"过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。")
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
if limit > 0:
if limit_mode == "earliest":
# 获取时间最早的 limit 条记录,已经是正序
query = query.sort([("time", 1)]).limit(limit)
results = list(query)
query = query.order_by(Messages.time.asc()).limit(limit)
peewee_results = list(query)
else: # 默认为 'latest'
# 获取时间最晚的 limit 条记录
query = query.sort([("time", -1)]).limit(limit)
latest_results = list(query)
query = query.order_by(Messages.time.desc()).limit(limit)
latest_results_peewee = list(query)
# 将结果按时间正序排列
# 假设消息文档中总是有 'time' 字段且可排序
results = sorted(latest_results, key=lambda msg: msg.get("time"))
peewee_results = sorted(latest_results_peewee, key=lambda msg: msg.time)
else:
# limit 为 0 时,应用传入的 sort 参数
if sort:
query = query.sort(sort)
results = list(query)
peewee_sort_terms = []
for field_name, direction in sort:
if hasattr(Messages, field_name):
field = getattr(Messages, field_name)
if direction == 1: # ASC
peewee_sort_terms.append(field.asc())
elif direction == -1: # DESC
peewee_sort_terms.append(field.desc())
else:
logger.warning(f"字段 '{field_name}' 的排序方向 '{direction}' 无效。将跳过此排序条件。")
else:
logger.warning(f"排序字段 '{field_name}' 在 Messages 模型中未找到。将跳过此排序条件。")
if peewee_sort_terms:
query = query.order_by(*peewee_sort_terms)
peewee_results = list(query)
return results
return [_model_to_dict(msg) for msg in peewee_results]
except Exception as e:
log_message = (
f"查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
f"使用 Peewee 查找消息失败 (filter={message_filter}, sort={sort}, limit={limit}, limit_mode={limit_mode}): {e}\n"
+ traceback.format_exc()
)
logger.error(log_message)
@@ -60,18 +113,57 @@ def count_messages(message_filter: dict[str, Any]) -> int:
根据提供的过滤器计算消息数量。
Args:
message_filter: MongoDB 查询过滤器。
message_filter: 查询过滤器字典,键为模型字段名,值为期望值或包含操作符的字典 (例如 {'$gt': value}).
Returns:
符合条件的消息数量,如果出错则返回 0。
"""
try:
count = db.messages.count_documents(message_filter)
query = Messages.select()
# 应用过滤器
if message_filter:
conditions = []
for key, value in message_filter.items():
if hasattr(Messages, key):
field = getattr(Messages, key)
if isinstance(value, dict):
# 处理 MongoDB 风格的操作符
for op, op_value in value.items():
if op == "$gt":
conditions.append(field > op_value)
elif op == "$lt":
conditions.append(field < op_value)
elif op == "$gte":
conditions.append(field >= op_value)
elif op == "$lte":
conditions.append(field <= op_value)
elif op == "$ne":
conditions.append(field != op_value)
elif op == "$in":
conditions.append(field.in_(op_value))
elif op == "$nin":
conditions.append(field.not_in(op_value))
else:
logger.warning(
f"计数时,过滤器中遇到未知操作符 '{op}' (字段: '{key}')。将跳过此操作符。"
)
else:
# 直接相等比较
conditions.append(field == value)
else:
logger.warning(f"计数时,过滤器键 '{key}' 在 Messages 模型中未找到。将跳过此条件。")
if conditions:
query = query.where(*conditions)
count = query.count()
return count
except Exception as e:
log_message = f"计数消息失败 (message_filter={message_filter}): {e}\n" + traceback.format_exc()
log_message = f"使用 Peewee 计数消息失败 (message_filter={message_filter}): {e}\n{traceback.format_exc()}"
logger.error(log_message)
return 0
# 你可以在这里添加更多与 messages 集合相关的数据库操作函数,例如 find_one_message, insert_message 等。
# 注意:对于 Peewee插入操作通常是 Messages.create(...) 或 instance.save()。
# 查找单个消息可以是 Messages.get_or_none(...) 或 query.first()。

View File

@@ -35,7 +35,7 @@ class TelemetryHeartBeatTask(AsyncTask):
info_dict = {
"os_type": "Unknown",
"py_version": platform.python_version(),
"mmc_version": global_config.MAI_VERSION,
"mmc_version": global_config.MMC_VERSION,
}
match platform.system():
@@ -133,10 +133,9 @@ class TelemetryHeartBeatTask(AsyncTask):
async def run(self):
# 发送心跳
if global_config.remote_enable:
if self.client_uuid is None:
if not await self._req_uuid():
logger.error("获取UUID失败跳过此次心跳")
return
if global_config.telemetry.enable:
if self.client_uuid is None and not await self._req_uuid():
logger.error("获取UUID失败跳过此次心跳")
return
await self._send_heartbeat()