refactor: 清理项目结构并修复类型注解问题

修复 SQLAlchemy 模型的类型注解,使用 Mapped 类型避免类型检查器错误
- 修正异步数据库操作中缺少 await 的问题
- 优化反注入统计系统的数值字段处理逻辑
- 添加缺失的导入语句修复模块依赖问题
This commit is contained in:
雅诺狐
2025-10-07 11:35:12 +08:00
parent 167e4d2520
commit 875ee4813c
19 changed files with 1466 additions and 3997 deletions

2
bot.py
View File

@@ -543,7 +543,7 @@ class MaiBotMain:
"""设置时区"""
try:
if platform.system().lower() != "windows":
time.tzset()
time.tzset() # type: ignore
logger.info("时区设置完成")
else:
logger.info("Windows系统跳过时区设置")

View File

@@ -1,10 +0,0 @@
from src.plugin_system.base.plugin_metadata import PluginMetadata
# Static metadata record for the Echo example plugin: display name,
# description, usage hint, version, author and license.
__plugin_meta__ = PluginMetadata(
name="Echo Example Plugin",
description="An example plugin that echoes messages.",
usage="!echo [message]",
version="1.0.0",
author="Your Name",  # NOTE(review): looks like a template placeholder — confirm
license="MIT",
)

View File

@@ -1,53 +0,0 @@
{
"manifest_version": 1,
"format_version": "1.0.0",
"name": "Echo 示例插件",
"description": "展示增强命令系统的Echo命令示例插件",
"version": "1.0.0",
"author": {
"name": "MoFox"
},
"license": "MIT",
"keywords": ["echo", "example", "command"],
"categories": ["utility", "example"],
"host_application": {
"name": "MaiBot",
"min_version": "0.10.0"
},
"entry_points": {
"main": "plugin.py"
},
"plugin_info": {
"is_built_in": false,
"plugin_type": "example",
"components": [
{
"type": "command",
"name": "echo",
"description": "回显命令,支持别名 say, repeat"
},
{
"type": "command",
"name": "hello",
"description": "问候命令,支持别名 hi, greet"
},
{
"type": "command",
"name": "info",
"description": "显示插件信息,支持别名 about"
},
{
"type": "command",
"name": "test",
"description": "测试命令,展示参数解析功能"
}
],
"features": [
"增强命令系统示例",
"无需正则表达式的命令定义",
"命令别名支持",
"参数解析功能",
"聊天类型限制"
]
}
}

View File

@@ -1,204 +0,0 @@
"""
Echo 示例插件
展示增强命令系统的使用方法
"""
from typing import Union
from src.plugin_system import (
BasePlugin,
ChatType,
CommandArgs,
ConfigField,
PlusCommand,
PlusCommandInfo,
register_plugin,
)
from src.plugin_system.base.component_types import PythonDependency
class EchoCommand(PlusCommand):
    """Example command named ``echo`` (aliases ``say`` / ``repeat``) that sends the input text back."""

    command_name = "echo"
    command_description = "回显命令"
    command_aliases = ["say", "repeat"]
    priority = 5
    chat_type_allow = ChatType.ALL
    intercept_message = True

    async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
        """Send back the raw argument text, rejecting empty or over-long input.

        Args:
            args: Parsed command arguments.

        Returns:
            ``(True, <status message>, True)`` in every branch.
        """
        if args.is_empty():
            reply, status = "❓ 请提供要回显的内容\n用法: /echo <内容>", "参数不足"
        else:
            text = args.get_raw()
            # The length limit comes from plugin config, defaulting to 500 characters.
            limit = self.get_config("commands.max_content_length", 500)
            if len(text) > limit:
                reply, status = f"❌ 内容过长,最大允许 {limit} 字符", "内容过长"
            else:
                reply, status = f"🔊 {text}", "Echo命令执行成功"

        await self.send_text(reply)
        return True, status, True
class HelloCommand(PlusCommand):
    """Example command named ``hello`` (aliases ``hi`` / ``greet``) that sends a greeting."""

    command_name = "hello"
    command_description = "问候命令"
    command_aliases = ["hi", "greet"]
    priority = 3
    chat_type_allow = ChatType.ALL
    intercept_message = True

    async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
        """Send a generic greeting, or one addressed to the first argument when given.

        Args:
            args: Parsed command arguments; the first one is used as the name.

        Returns:
            Always ``(True, "Hello命令执行成功", True)``.
        """
        greeting = (
            "👋 Hello! 很高兴见到你!"
            if args.is_empty()
            else f"👋 Hello, {args.get_first()}! 很高兴见到你!"
        )
        await self.send_text(greeting)
        return True, "Hello命令执行成功", True
class InfoCommand(PlusCommand):
    """Example command named ``info`` (alias ``about``) that sends a fixed description of the plugin."""

    command_name = "info"
    command_description = "显示插件信息"
    command_aliases = ["about"]
    priority = 1
    chat_type_allow = ChatType.ALL
    intercept_message = True

    async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
        """Send the static plugin information text.

        Args:
            args: Parsed command arguments (not used by this command).

        Returns:
            Always ``(True, "Info命令执行成功", True)``.
        """
        lines = [
            "📋 Echo 示例插件信息",
            "版本: 1.0.0",
            "作者: MaiBot Team",
            "描述: 展示增强命令系统的使用方法",
            "",  # blank line between the header and the command list
            "🎯 可用命令:",
            "• /echo|/say|/repeat <内容> - 回显内容",
            "• /hello|/hi|/greet [名字] - 问候",
            "• /info|/about - 显示此信息",
            "• /test <子命令> [参数] - 测试各种功能",
        ]
        await self.send_text("\n".join(lines))
        return True, "Info命令执行成功", True
class TestCommand(PlusCommand):
    """Example command named ``test`` (alias ``t``) that demonstrates argument parsing."""

    command_name = "test"
    command_description = "测试命令,展示参数解析功能"
    command_aliases = ["t"]
    priority = 2
    chat_type_allow = ChatType.ALL
    intercept_message = True

    async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
        """Dispatch on the first argument (the sub-command) and send the result.

        With no arguments a help text is sent. Recognised sub-commands are
        ``args``, ``flags``, ``count`` and ``join``; anything else produces an
        "unknown sub-command" reply.

        Args:
            args: Parsed command arguments.

        Returns:
            ``(True, "显示帮助", True)`` for the help path, otherwise
            ``(True, "Test命令执行成功", True)``.
        """
        if args.is_empty():
            usage = "\n".join(
                [
                    "🧪 测试命令帮助",
                    "用法: /test <子命令> [参数]",
                    "",
                    "可用子命令:",
                    "• args - 显示参数解析结果",
                    "• flags - 测试标志参数",
                    "• count - 计算参数数量",
                    "• join - 连接所有参数",
                ]
            )
            await self.send_text(usage)
            return True, "显示帮助", True

        action = args.get_first().lower()

        if action == "args":
            reply = (
                f"🔍 参数解析结果:\n"
                f"原始字符串: '{args.get_raw()}'\n"
                f"解析后参数: {args.get_args()}\n"
                f"参数数量: {args.count()}\n"
                f"第一个参数: '{args.get_first()}'\n"
                f"剩余参数: '{args.get_remaining()}'"
            )
        elif action == "flags":
            reply = (
                f"🏴 标志测试结果:\n"
                f"包含 --verbose: {args.has_flag('--verbose')}\n"
                f"包含 -v: {args.has_flag('-v')}\n"
                f"--output 的值: '{args.get_flag_value('--output', '未设置')}'\n"
                f"--name 的值: '{args.get_flag_value('--name', '未设置')}'"
            )
        elif action == "count":
            # The sub-command itself is one of the counted arguments.
            extra = args.count() - 1
            reply = f"📊 除子命令外的参数数量: {extra}"
        elif action == "join":
            rest = args.get_remaining()
            reply = f"🔗 连接结果: {rest}" if rest else "❌ 没有可连接的参数"
        else:
            reply = f"❓ 未知的子命令: {action}"

        await self.send_text(reply)
        return True, "Test命令执行成功", True
@register_plugin
class EchoExamplePlugin(BasePlugin):
    """Example plugin that bundles the echo / hello / info / test commands."""

    plugin_name: str = "echo_example_plugin"
    enable_plugin: bool = True
    dependencies: list[str] = []
    python_dependencies: list[Union[str, "PythonDependency"]] = []
    config_file_name: str = "config.toml"

    # Configuration schema, grouped by section.
    config_schema = {
        "plugin": {
            "enabled": ConfigField(bool, default=True, description="是否启用插件"),
            "config_version": ConfigField(str, default="1.0.0", description="配置文件版本"),
        },
        "commands": {
            "echo_enabled": ConfigField(bool, default=True, description="是否启用 Echo 命令"),
            "cooldown": ConfigField(int, default=0, description="命令冷却时间(秒)"),
            "max_content_length": ConfigField(int, default=500, description="最大回显内容长度"),
        },
    }

    config_section_descriptions = {
        "plugin": "插件基本配置",
        "commands": "命令相关配置",
    }

    def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]:
        """Return ``(command info, command class)`` pairs for the enabled commands.

        Nothing is returned when ``plugin.enabled`` is false. The echo command
        is additionally gated by ``commands.echo_enabled``; the other three
        commands are always included.
        """
        if not self.get_config("plugin.enabled", True):
            return []

        registered = []
        if self.get_config("commands.echo_enabled", True):
            registered.append((EchoCommand.get_plus_command_info(), EchoCommand))
        for command_cls in (HelloCommand, InfoCommand, TestCommand):
            registered.append((command_cls.get_plus_command_info(), command_cls))
        return registered

View File

@@ -107,7 +107,7 @@ class HelloWorldPlugin(BasePlugin):
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))
if self.get_config("components.hello_command_enabled", True):
components.append((HelloCommand.get_command_info(), HelloCommand))
components.append((HelloCommand.get_plus_command_info(), HelloCommand))
if self.get_config("components.random_emoji_action_enabled", True):
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))

31
pyrightconfig.json Normal file
View File

@@ -0,0 +1,31 @@
{
"$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json",
"include": [
"src",
"bot.py",
"__main__.py"
],
"exclude": [
"**/__pycache__",
"data",
"logs",
"tests",
"target",
"*.egg-info"
],
"typeCheckingMode": "standard",
"reportMissingImports": false,
"reportMissingTypeStubs": false,
"reportMissingModuleSource": false,
"diagnosticSeverityOverrides": {
"reportMissingImports": "none",
"reportMissingTypeStubs": "none",
"reportMissingModuleSource": "none"
},
"pythonVersion": "3.12",
"venvPath": ".",
"venv": ".venv",
"executionEnvironments": [
{"root": "src"}
]
}

View File

@@ -0,0 +1,220 @@
"""批量将经典 SQLAlchemy 模型字段写法
field = Column(Integer, nullable=False, default=0)
转换为 2.0 推荐的带类型注解写法:
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
脚本特点:
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
3. 已经是 Mapped 写法的行会跳过。
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
5. nullable=True 时自动添加 "| None"
6. 保留 Column(...) 内的原始参数顺序与内容。
7. 生成 .bak 备份文件,确保可回滚。
8. 支持 --dry-run 查看差异,不写回文件。
局限/注意:
- 简单基于正则/括号计数,不解析完整 AST非常规写法(例如变量中构造 Column 再赋值)不会处理。
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
- 不自动添加 from __future__ import annotations如需 Python 3.10 以下更先进类型表达式,请自行处理。
使用方式: (在项目根目录执行)
python scripts/convert_sqlalchemy_models.py \
--file src/common/database/sqlalchemy_models.py --dry-run
确认无误后将 --dry-run 换成 --write 才会真实写入 (不带 --write 时始终只预览, 不写回)。
"""
from __future__ import annotations
import argparse
import re
import shutil
from pathlib import Path
from typing import Any
TYPE_MAP = {
"Integer": "int",
"Float": "float",
"Boolean": "bool",
"Text": "str",
"String": "str",
"DateTime": "datetime.datetime",
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
"get_string_field": "str",
}
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
    """Find the line span of the ``Column(...)`` statement starting at ``start_index``.

    Parentheses are counted line by line, so a multi-line definition extends
    until the count returns to zero or the input runs out.

    Args:
        lines: All lines of the file.
        start_index: Index of the line expected to contain ``Column(``.

    Returns:
        ``(start_index, end_index)`` with an inclusive end, or ``None`` when
        the start line does not contain ``Column(``.
    """
    first = lines[start_index]
    if "Column(" not in first:
        return None

    depth = first.count("(") - first.count(")")
    end_index = start_index
    for position in range(start_index + 1, len(lines)):
        if depth <= 0:
            break
        text = lines[position]
        depth += text.count("(") - text.count(")")
        end_index = position
    return (start_index, end_index)
def extract_column_body(block_lines: list[str]) -> str:
    """Return the argument text inside the first ``Column(...)`` call of a block.

    The closing parenthesis is located by balanced-parenthesis scanning that
    ignores parentheses inside string literals and ``#`` comments. Taking the
    last ``)`` of the block would pull text from a trailing comment such as
    ``# key (auto)`` into the body.

    Args:
        block_lines: The lines that make up one ``Column(...)`` assignment.

    Returns:
        The stripped argument text, or ``""`` when no ``Column(`` is present.
        If the call is never closed, everything up to the last ``)`` (or the
        whole remainder when there is none) is returned.
    """
    joined = "\n".join(block_lines)
    start_pos = joined.find("Column(")
    if start_pos == -1:
        return ""
    inner = joined[start_pos + len("Column("):]

    depth = 1  # we are already inside the opening "Column("
    quote = None  # the active quote character while inside a string literal
    in_comment = False
    skip_next = False  # set after a backslash inside a string literal
    for pos, ch in enumerate(inner):
        if skip_next:
            skip_next = False
            continue
        if in_comment:
            if ch == "\n":
                in_comment = False
            continue
        if quote is not None:
            if ch == "\\":
                skip_next = True
            elif ch == quote:
                quote = None
            continue
        if ch in "\"'":
            quote = ch
        elif ch == "#":
            in_comment = True
        elif ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
            if depth == 0:
                return inner[:pos].strip()

    # Unbalanced input: fall back to cutting at the last ")" if there is one.
    last_paren = inner.rfind(")")
    if last_paren != -1:
        inner = inner[:last_paren]
    return inner.strip()
def guess_python_type(column_body: str) -> str:
    """Guess the Python annotation for a column from its ``Column(...)`` arguments.

    The first identifier in the body is looked up in ``TYPE_MAP``; unknown
    identifiers map to ``"Any"``. When the body sets ``nullable=True`` (with
    any spacing around ``=``), ``" | None"`` is appended.

    Args:
        column_body: The argument text of a ``Column(...)`` call.

    Returns:
        A type expression such as ``"int"``, ``"str | None"`` or ``"Any"``.
    """
    # The first identifier is taken to be the column type,
    # e.g. Integer, Text, get_string_field(50), DateTime, Boolean.
    match = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
    if not match:
        return "Any"
    py_type = TYPE_MAP.get(match.group(1), "Any")

    # Tolerate spacing variants such as "nullable= True" or "nullable =True".
    # NOTE(review): columns with no explicit nullable= are annotated as
    # non-optional here — confirm that matches the intended schema.
    is_nullable = re.search(r"\bnullable\s*=\s*True\b", column_body) is not None
    if is_nullable and not py_type.endswith(" | None"):
        py_type = f"{py_type} | None"
    return py_type
def convert_block(block_lines: list[str]) -> list[str]:
    """Rewrite one ``name = Column(...)`` block as ``name: Mapped[T] = mapped_column(...)``.

    Args:
        block_lines: The lines of one indented ``Column(...)`` assignment.

    Returns:
        The replacement lines, each ending in a newline. The input is returned
        unchanged when its first line is not an indented assignment.
    """
    first_line = block_lines[0]
    m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
    if not m_name:
        return block_lines
    indent = m_name.group("indent")
    name = m_name.group("name")

    body = extract_column_body(block_lines)
    py_type = guess_python_type(body)
    head = f"{indent}{name}: Mapped[{py_type}] = mapped_column("

    inner_lines = body.split("\n")
    if len(inner_lines) == 1:
        return [f"{head}{inner_lines[0].strip()})\n"]

    # Multi-line form: strip each argument line completely before re-indenting
    # it one level (4 spaces) deeper than the attribute, so the old leading
    # whitespace is not stacked on top of the new indent.
    arg_indent = indent + "    "
    rebuilt = [f"{head}\n"]
    rebuilt.extend(f"{arg_indent}{text.strip()}\n" for text in inner_lines if text.strip())
    rebuilt.append(f"{indent})\n")
    return rebuilt
def ensure_imports(content: str) -> str:
    """Add ``from sqlalchemy.orm import Mapped, mapped_column`` when it is needed.

    The import is inserted directly before the first top-level
    ``from sqlalchemy ...`` line. Inserting after that line would land inside
    a parenthesised multi-line import and produce invalid Python. Line endings
    and the trailing newline of ``content`` are preserved.

    Args:
        content: Full text of the module being converted.

    Returns:
        The text with the import added. ``content`` is returned unchanged when
        it does not use ``Mapped``, already has the exact import line, or has
        no ``from sqlalchemy`` line to anchor on.
    """
    import_line = "from sqlalchemy.orm import Mapped, mapped_column"
    uses_mapped = "Mapped," in content or "Mapped[" in content
    if not uses_mapped or import_line in content:
        return content

    lines = content.splitlines(keepends=True)
    for index, line in enumerate(lines):
        if line.startswith("from sqlalchemy"):
            # Reuse the anchor line's newline style.
            newline = "\r\n" if line.endswith("\r\n") else "\n"
            lines.insert(index, import_line + newline)
            return "".join(lines)
    return content
def process_file(path: Path) -> tuple[str, str]:
    """Convert every classic ``Column(...)`` assignment found in ``path``.

    Args:
        path: The model file to read (UTF-8).

    Returns:
        ``(original_text, converted_text)``. When no block was changed the
        second element is the original text itself.
    """
    original = path.read_text(encoding="utf-8")
    source_lines = original.splitlines(keepends=True)

    result: list[str] = []
    changed_blocks = 0
    cursor = 0
    total = len(source_lines)
    while cursor < total:
        current = source_lines[cursor]
        # Lines already in Mapped style, and anything that is not an indented
        # "name = Column(" assignment, are copied through untouched.
        is_column_assignment = (
            not ALREADY_MAPPED_RE.match(current)
            and "= Column(" in current
            and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", current)
        )
        if not is_column_assignment:
            result.append(current)
            cursor += 1
            continue

        span = detect_column_block(source_lines, cursor)
        first, last = span if span else (cursor, cursor)
        block = source_lines[first : last + 1]
        converted = convert_block(block)
        result.extend(converted)
        cursor = last + 1
        # Count only the blocks whose text actually changed.
        if "".join(converted) != "".join(block):
            changed_blocks += 1

    new_content = ensure_imports("".join(result))
    return original, (new_content if changed_blocks else original)
def main():
    """Command-line entry point: preview or apply the Column -> Mapped conversion.

    Without ``--write`` (with or without ``--dry-run``) a unified diff is
    printed and nothing is written. With ``--write`` the target is backed up
    to ``<file>.bak`` and overwritten. The two flags are mutually exclusive, so
    passing both is rejected by argparse instead of writing the file.

    Raises:
        SystemExit: When the target file does not exist or the arguments are
            invalid.
    """
    parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
    parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
    # Enforce the exclusivity that the --write help text promises.
    mode = parser.add_mutually_exclusive_group()
    mode.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
    mode.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
    args = parser.parse_args()

    target = Path(args.file)
    if not target.exists():
        raise SystemExit(f"文件不存在: {target}")

    original, new_content = process_file(target)
    if original == new_content:
        print("[INFO] 没有需要转换的内容或转换后无差异。")
        return

    if not args.write:
        # Preview mode: show a line-based diff and leave the file untouched.
        print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
        import difflib

        diff = difflib.unified_diff(
            original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
        )
        count = 0
        for d in diff:
            print(d)
            count += 1
        if count == 0:
            print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
        print("\n未写回。若确认无误,添加 --write 执行替换。")
        return

    # Keep a copy beside the original so the change can be rolled back.
    backup = target.with_suffix(target.suffix + ".bak")
    shutil.copyfile(target, backup)
    target.write_text(new_content, encoding="utf-8")
    print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
if __name__ == "__main__": # pragma: no cover
main()

View File

@@ -1,284 +0,0 @@
import os
import sys
import time
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str:
    """Build a display name for ``chat_id`` from the ChatStreams table.

    Args:
        chat_id: Stream id to look up.

    Returns:
        The group name, or the user nickname marked as a private chat, each
        followed by the id in parentheses. A fallback label is returned when
        no row or no name exists, and a "query failed" label on any exception.
    """
    try:
        stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if stream is not None:
            if stream.group_name:
                return f"{stream.group_name} ({chat_id})"
            if stream.user_nickname:
                return f"{stream.user_nickname}的私聊 ({chat_id})"
        return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
    """Format a Unix timestamp as ``YYYY-MM-DD HH:MM:SS`` in local time.

    Args:
        timestamp: Seconds since the epoch.

    Returns:
        The formatted string, or ``"未知时间"`` when the value cannot be
        converted (out of range, NaN, infinite).
    """
    try:
        return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
    except (ValueError, OSError, OverflowError):
        # fromtimestamp raises OverflowError, not ValueError, for values
        # outside the platform's time range or for infinities.
        return "未知时间"
def calculate_interest_value_distribution(messages) -> dict[str, int]:
    """Count messages per ``interest_value`` range.

    Messages whose ``interest_value`` is ``None`` or ``0.0`` are ignored. Each
    remaining value is counted in the first range whose upper bound it is
    strictly below; values matching no range go to ``"10.000+"``.

    Args:
        messages: Iterable of objects exposing an ``interest_value`` attribute.

    Returns:
        A dict from range label to count, in ascending range order.
    """
    # (exclusive upper bound, label), in ascending order.
    buckets = [
        (0.010, "0.000-0.010"),
        (0.050, "0.010-0.050"),
        (0.100, "0.050-0.100"),
        (0.500, "0.100-0.500"),
        (1.000, "0.500-1.000"),
        (2.000, "1.000-2.000"),
        (5.000, "2.000-5.000"),
        (10.000, "5.000-10.000"),
    ]
    overflow_label = "10.000+"

    distribution = {label: 0 for _, label in buckets}
    distribution[overflow_label] = 0

    for msg in messages:
        raw = msg.interest_value
        if raw is None or raw == 0.0:
            continue
        value = float(raw)
        label = next((name for upper, name in buckets if value < upper), overflow_label)
        distribution[label] += 1
    return distribution
def get_interest_value_stats(messages) -> dict[str, float]:
    """Compute count, min, max, mean and median of the non-zero interest values.

    Args:
        messages: Iterable of objects exposing an ``interest_value`` attribute;
            entries that are ``None`` or ``0.0`` are excluded.

    Returns:
        A dict with keys ``count``, ``min``, ``max``, ``avg`` and ``median``;
        all zero when no usable value exists.
    """
    ordered = sorted(
        float(msg.interest_value)
        for msg in messages
        if msg.interest_value is not None and msg.interest_value != 0.0
    )
    count = len(ordered)
    if count == 0:
        return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}

    middle = count // 2
    if count % 2 == 1:
        median = ordered[middle]
    else:
        # Even count: average the two central values.
        median = (ordered[middle - 1] + ordered[middle]) / 2

    return {
        "count": count,
        "min": min(ordered),
        "max": max(ordered),
        "avg": sum(ordered) / count,
        "median": median,
    }
def get_available_chats() -> list[tuple[str, str, int]]:
    """List the chats that have at least one message with a non-zero interest value.

    Returns:
        ``(chat_id, chat_name, message_count)`` tuples sorted by count,
        largest first. An empty list is returned (after printing the error)
        when anything fails.
    """
    try:
        # One count query per distinct chat_id; chats with no usable messages
        # are dropped.
        counts_by_chat = {}
        for row in Messages.select(Messages.chat_id).distinct():
            cid = row.chat_id
            usable = (
                Messages.select()
                .where(
                    (Messages.chat_id == cid)
                    & (Messages.interest_value.is_null(False))
                    & (Messages.interest_value != 0.0)
                )
                .count()
            )
            if usable > 0:
                counts_by_chat[cid] = usable

        # Attach display names, then order by message count.
        chats = [(cid, get_chat_name(cid), total) for cid, total in counts_by_chat.items()]
        chats.sort(key=lambda entry: entry[2], reverse=True)
        return chats
    except Exception as e:
        print(f"获取聊天列表失败: {e}")
        return []
def get_time_range_input() -> tuple[float | None, float | None]:
    """Prompt the user for a time window.

    Returns:
        ``(start, end)`` as Unix timestamps, or ``(None, None)`` when the user
        picks "no limit", enters an unknown option, or types a custom range in
        the wrong format.
    """
    for menu_line in (
        "\n时间范围选择:",
        "1. 最近1天",
        "2. 最近3天",
        "3. 最近7天",
        "4. 最近30天",
        "5. 自定义时间范围",
        "6. 不限制时间",
    ):
        print(menu_line)

    choice = input("请选择时间范围 (1-6): ").strip()
    now = time.time()

    # Preset options map straight to a number of days back from now.
    preset_days = {"1": 1, "2": 3, "3": 7, "4": 30}
    if choice in preset_days:
        return now - preset_days[choice] * 24 * 3600, now

    if choice != "5":
        return None, None

    print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
    start_str = input().strip()
    print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
    end_str = input().strip()
    time_format = "%Y-%m-%d %H:%M:%S"
    try:
        start_time = datetime.strptime(start_str, time_format).timestamp()
        end_time = datetime.strptime(end_str, time_format).timestamp()
    except ValueError:
        print("时间格式错误,将不限制时间范围")
        return None, None
    return start_time, end_time
def analyze_interest_values(
    chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
    """Query messages with a non-zero interest value and print summary statistics.

    Args:
        chat_id: Restrict the analysis to this chat; falsy means all chats.
        start_time: Inclusive lower bound on the message time, or ``None``.
        end_time: Inclusive upper bound on the message time, or ``None``.
    """
    # Build the query: only rows with a non-null, non-zero interest_value.
    query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
    if chat_id:
        query = query.where(Messages.chat_id == chat_id)
    # Compare against None so that a bound of 0.0 is still applied.
    if start_time is not None:
        query = query.where(Messages.time >= start_time)
    if end_time is not None:
        query = query.where(Messages.time <= end_time)

    messages = list(query)
    if not messages:
        print("没有找到符合条件的消息")
        return

    # Compute the statistics.
    distribution = calculate_interest_value_distribution(messages)
    stats = get_interest_value_stats(messages)

    # Print the report.
    print("\n=== Interest Value 分析结果 ===")
    if chat_id:
        print(f"聊天: {get_chat_name(chat_id)}")
    else:
        print("聊天: 全部聊天")

    if start_time is not None and end_time is not None:
        # The separator keeps the two timestamps from running together.
        print(f"时间范围: {format_timestamp(start_time)} ~ {format_timestamp(end_time)}")
    elif start_time is not None:
        print(f"时间范围: {format_timestamp(start_time)} 之后")
    elif end_time is not None:
        print(f"时间范围: {format_timestamp(end_time)} 之前")
    else:
        print("时间范围: 不限制")

    print("\n基本统计:")
    print(f"有效消息数量: {stats['count']} (排除null和0值)")
    print(f"最小值: {stats['min']:.3f}")
    print(f"最大值: {stats['max']:.3f}")
    print(f"平均值: {stats['avg']:.3f}")
    print(f"中位数: {stats['median']:.3f}")

    print("\nInterest Value 分布:")
    total = stats["count"]
    for range_name, count in distribution.items():
        if count > 0:
            percentage = count / total * 100
            print(f"{range_name}: {count} ({percentage:.2f}%)")
def interactive_menu() -> None:
    """Run the interactive loop: pick a scope, pick a time range, print the analysis.

    The loop repeats until the user enters ``q``.
    """
    divider = "=" * 50
    while True:
        print("\n" + divider)
        print("Interest Value 分析工具")
        print(divider)
        print("1. 分析全部聊天")
        print("2. 选择特定聊天分析")
        print("q. 退出")

        choice = input("\n请选择分析模式 (1-2, q): ").strip()
        if choice.lower() == "q":
            print("再见!")
            break
        if choice not in ("1", "2"):
            print("无效选择")
            continue

        chat_id = None
        if choice == "2":
            # Show the chats that have usable data and let the user pick one.
            chats = get_available_chats()
            if not chats:
                print("没有找到有interest_value数据的聊天")
                continue

            print(f"\n可用的聊天 (共{len(chats)}个):")
            for number, (_cid, name, count) in enumerate(chats, 1):
                print(f"{number}. {name} ({count}条有效消息)")

            try:
                picked = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
            except ValueError:
                print("请输入有效数字")
                continue
            if not 1 <= picked <= len(chats):
                print("无效选择")
                continue
            chat_id = chats[picked - 1][0]

        # Ask for the time range, then run the analysis.
        start_time, end_time = get_time_range_input()
        analyze_interest_values(chat_id, start_time, end_time)
        input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

File diff suppressed because it is too large Load Diff

View File

@@ -1,239 +0,0 @@
"""
插件Manifest管理命令行工具
提供插件manifest文件的创建、验证和管理功能
"""
import argparse
import os
import sys
from pathlib import Path

import orjson

# Put the project root on sys.path BEFORE importing anything from ``src``.
# Doing it after those imports has no effect on them, so running this file
# directly would fail on the first ``src.*`` import.
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))

from src.common.logger import get_logger  # noqa: E402
from src.plugin_system.utils.manifest_utils import (  # noqa: E402
    ManifestValidator,
)

logger = get_logger("manifest_tool")
def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool:
    """Write a minimal ``_manifest.json`` into ``plugin_dir``.

    Args:
        plugin_dir: Plugin directory.
        plugin_name: Plugin name.
        description: Plugin description; defaults to ``"<plugin_name>插件"``.
        author: Plugin author; defaults to ``"Unknown"``.

    Returns:
        ``True`` when the file was created, ``False`` when it already exists
        or writing failed.
    """
    target = os.path.join(plugin_dir, "_manifest.json")
    if os.path.exists(target):
        # Never overwrite an existing manifest.
        print(f"❌ Manifest文件已存在: {target}")
        return False

    manifest = {
        "manifest_version": 1,
        "name": plugin_name,
        "version": "1.0.0",
        "description": description if description else f"{plugin_name}插件",
        "author": {"name": author if author else "Unknown"},
    }

    try:
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(orjson.dumps(manifest, option=orjson.OPT_INDENT_2).decode("utf-8"))
        print(f"✅ 已创建最小化manifest文件: {target}")
        return True
    except Exception as e:
        print(f"❌ 创建manifest文件失败: {e}")
        return False
def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool:
    """Write a fully populated ``_manifest.json`` template into ``plugin_dir``.

    Args:
        plugin_dir: Plugin directory.
        plugin_name: Plugin name.

    Returns:
        ``True`` when the file was created, ``False`` when it already exists
        or writing failed.
    """
    target = os.path.join(plugin_dir, "_manifest.json")
    if os.path.exists(target):
        # Never overwrite an existing manifest.
        print(f"❌ Manifest文件已存在: {target}")
        return False

    # Template values are placeholders for the plugin author to edit.
    plugin_info = {
        "is_built_in": False,
        "plugin_type": "general",
        "components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}],
    }
    template = {
        "manifest_version": 1,
        "name": plugin_name,
        "version": "1.0.0",
        "description": f"{plugin_name}插件描述",
        "author": {"name": "插件作者", "url": "https://github.com/your-username"},
        "license": "MIT",
        "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
        "homepage_url": "https://github.com/your-repo",
        "repository_url": "https://github.com/your-repo",
        "keywords": ["keyword1", "keyword2"],
        "categories": ["Category1"],
        "default_locale": "zh-CN",
        "locales_path": "_locales",
        "plugin_info": plugin_info,
    }

    try:
        with open(target, "w", encoding="utf-8") as handle:
            handle.write(orjson.dumps(template, option=orjson.OPT_INDENT_2).decode("utf-8"))
        print(f"✅ 已创建完整manifest模板: {target}")
        print("💡 请根据实际情况修改manifest文件中的内容")
        return True
    except Exception as e:
        print(f"❌ 创建manifest文件失败: {e}")
        return False
def validate_manifest_file(plugin_dir: str) -> bool:
    """Validate the ``_manifest.json`` in ``plugin_dir`` and print the report.

    Args:
        plugin_dir: Plugin directory.

    Returns:
        The validator's verdict; ``False`` when the file is missing, is not
        valid JSON, or any other error occurs.
    """
    target = os.path.join(plugin_dir, "_manifest.json")
    if not os.path.exists(target):
        print(f"❌ 未找到manifest文件: {target}")
        return False

    try:
        with open(target, encoding="utf-8") as handle:
            manifest_data = orjson.loads(handle.read())

        checker = ManifestValidator()
        passed = checker.validate_manifest(manifest_data)

        # Print the validation report, then the one-line verdict.
        print("📋 Manifest验证结果:")
        print(checker.get_validation_report())
        print("✅ Manifest文件验证通过" if passed else "❌ Manifest文件验证失败")
        return passed
    except orjson.JSONDecodeError as e:
        print(f"❌ Manifest文件格式错误: {e}")
        return False
    except Exception as e:
        print(f"❌ 验证过程中发生错误: {e}")
        return False
def scan_plugins_without_manifest(root_dir: str) -> None:
    """Print every plugin directory under ``root_dir`` that lacks a manifest.

    A directory counts as a plugin when it contains ``plugin.py``. Hidden
    directories and ``__pycache__`` are not descended into.

    Args:
        root_dir: Root directory to scan.
    """
    print(f"🔍 扫描目录: {root_dir}")

    missing = []
    for current, subdirs, filenames in os.walk(root_dir):
        # Prune in place so os.walk skips hidden directories and __pycache__.
        subdirs[:] = [name for name in subdirs if not (name.startswith(".") or name == "__pycache__")]
        if "plugin.py" not in filenames:
            continue
        if not os.path.exists(os.path.join(current, "_manifest.json")):
            missing.append(current)

    if not missing:
        print("✅ 所有插件都有manifest文件")
        return

    print(f"❌ 发现 {len(missing)} 个插件缺少manifest文件:")
    for plugin_dir in missing:
        print(f" - {os.path.basename(plugin_dir)}: {plugin_dir}")
    print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件")
def main():
    """Parse the command line and run the selected manifest sub-command.

    ``create-minimal``, ``create-complete`` and ``validate`` exit with status
    0 on success and 1 on failure. ``scan`` only prints its findings. With no
    sub-command the help text is printed.
    """
    parser = argparse.ArgumentParser(description="插件Manifest管理工具")
    subparsers = parser.add_subparsers(dest="command", help="可用命令")

    # create-minimal: write a minimal manifest
    minimal_cmd = subparsers.add_parser("create-minimal", help="创建最小化manifest文件")
    minimal_cmd.add_argument("plugin_dir", help="插件目录路径")
    minimal_cmd.add_argument("--name", help="插件名称")
    minimal_cmd.add_argument("--description", help="插件描述")
    minimal_cmd.add_argument("--author", help="插件作者")

    # create-complete: write a full manifest template
    complete_cmd = subparsers.add_parser("create-complete", help="创建完整manifest模板")
    complete_cmd.add_argument("plugin_dir", help="插件目录路径")
    complete_cmd.add_argument("--name", help="插件名称")

    # validate: check an existing manifest
    validate_cmd = subparsers.add_parser("validate", help="验证manifest文件")
    validate_cmd.add_argument("plugin_dir", help="插件目录路径")

    # scan: find plugins without a manifest
    scan_cmd = subparsers.add_parser("scan", help="扫描缺少manifest的插件")
    scan_cmd.add_argument("root_dir", help="扫描的根目录路径")

    args = parser.parse_args()
    command = args.command
    if not command:
        parser.print_help()
        return

    try:
        if command == "scan":
            scan_plugins_without_manifest(args.root_dir)
            return

        if command == "create-minimal":
            # The plugin name defaults to the directory's base name.
            plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
            ok = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "")
        elif command == "create-complete":
            plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
            ok = create_complete_manifest(args.plugin_dir, plugin_name)
        elif command == "validate":
            ok = validate_manifest_file(args.plugin_dir)
        else:
            return
        sys.exit(0 if ok else 1)
    except Exception as e:
        print(f"❌ 执行命令时发生错误: {e}")
        sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,922 +0,0 @@
import os
# import time
import pickle
import sys # 新增系统模块导入
from pathlib import Path
import orjson
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from peewee import Field, IntegrityError, Model
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
# Rich 进度条和显示组件
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
# from rich.text import Text
from src.common.database.database import db
from src.common.database.sqlalchemy_models import (
ChatStreams,
Emoji,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
Knowledges,
Messages,
PersonInfo,
ThinkingLog,
)
from src.common.logger import get_logger
logger = get_logger("mongodb_to_sqlite")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class MigrationConfig:
"""Settings describing how one MongoDB collection is migrated to a target model."""
# Name of the source MongoDB collection.
mongo_collection: str
# Model class that receives the migrated rows.
target_model: type[Model]
# Maps a source document key (dotted for nested values) to a target field name.
field_mapping: dict[str, str]
# NOTE(review): presumably the number of documents handled per batch — confirm against the migration loop.
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
unique_fields: list[str] = field(default_factory=list) # fields used for the duplicate check
# 数据验证相关类已移除 - 用户要求不要数据验证
@dataclass
class MigrationCheckpoint:
"""Checkpoint data for a migration."""
# NOTE(review): the field meanings below are inferred from names only — confirm against the save/load code.
collection_name: str  # collection the checkpoint belongs to
processed_count: int  # documents processed so far
last_processed_id: Any  # id of the last processed document
timestamp: datetime
batch_errors: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class MigrationStats:
"""迁移统计信息"""
total_documents: int = 0
processed_count: int = 0
success_count: int = 0
error_count: int = 0
skipped_count: int = 0
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
errors: list[dict[str, Any]] = field(default_factory=list)
start_time: datetime | None = None
end_time: datetime | None = None
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
)
self.error_count += 1
def add_validation_error(self, doc_id: Any, field: str, error: str):
"""添加验证错误"""
self.add_error(doc_id, f"验证失败 - {field}: {error}")
self.validation_errors += 1
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
    """Set up connection settings, migration configs and the checkpoint directory.

    Args:
        mongo_uri: Explicit MongoDB URI; built from environment variables
            when omitted.
        database_name: MongoDB database name; falls back to the
            ``DATABASE_NAME`` environment variable, then ``"MegBot"``.
    """
    self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
    self.mongo_uri = mongo_uri or self._build_mongo_uri()
    self.mongo_client: MongoClient | None = None
    self.mongo_db = None

    # Migration configuration
    self.migration_configs = self._initialize_migration_configs()

    # Console used for progress display
    self.console = Console()

    # Checkpoint directory. parents=True also creates a missing "data"
    # directory instead of raising FileNotFoundError.
    self.checkpoint_dir = Path(ROOT_PATH) / "data" / "checkpoints"
    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Validation rules (the original comment marks validation as disabled)
    self.validation_rules = self._initialize_validation_rules()
def _build_mongo_uri(self) -> str:
"""构建MongoDB连接URI"""
if mongo_uri := os.getenv("MONGODB_URI"):
return mongo_uri
user = os.getenv("MONGODB_USER")
password = os.getenv("MONGODB_PASS")
host = os.getenv("MONGODB_HOST", "localhost")
port = os.getenv("MONGODB_PORT", "27017")
auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
if user and password:
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
else:
return f"mongodb://{host}:{port}/{self.database_name}"
def _initialize_migration_configs(self) -> list[MigrationConfig]:
    """Return the per-collection migration settings, in migration order.

    Every ``field_mapping`` maps a MongoDB field (dotted paths reach into
    nested documents) to the column name on the target model.
    """

    def same(*names: str) -> dict[str, str]:
        # Fields whose name is identical on both sides.
        return {name: name for name in names}

    # Sender columns shared by the chat-stream and message tables.
    user_info_columns = {
        "user_info.platform": "user_platform",
        "user_info.user_id": "user_id",
        "user_info.user_nickname": "user_nickname",
        "user_info.user_cardname": "user_cardname",
    }

    emoji = MigrationConfig(
        mongo_collection="emoji",
        target_model=Emoji,
        field_mapping={
            **same("full_path", "format"),
            "hash": "emoji_hash",
            # record_time is filled with the current time during conversion.
            **same("description", "emotion", "usage_count", "last_used_time"),
        },
        enable_validation=False,  # validation disabled
        unique_fields=["full_path", "emoji_hash"],
    )

    chat_streams = MigrationConfig(
        mongo_collection="chat_streams",
        target_model=ChatStreams,
        field_mapping={
            **same("stream_id", "create_time"),
            # MongoDB stores group_info as null for private chats while the
            # new schema does not allow null here, so private chat streams
            # cannot be migrated for now (applies to all three group columns).
            "group_info.platform": "group_platform",
            "group_info.group_id": "group_id",
            "group_info.group_name": "group_name",
            **same("last_active_time", "platform"),
            **user_info_columns,
        },
        enable_validation=False,  # validation disabled
        unique_fields=["stream_id"],
    )

    messages = MigrationConfig(
        mongo_collection="messages",
        target_model=Messages,
        field_mapping={
            **same("message_id", "time", "chat_id"),
            "chat_info.stream_id": "chat_info_stream_id",
            "chat_info.platform": "chat_info_platform",
            "chat_info.user_info.platform": "chat_info_user_platform",
            "chat_info.user_info.user_id": "chat_info_user_id",
            "chat_info.user_info.user_nickname": "chat_info_user_nickname",
            "chat_info.user_info.user_cardname": "chat_info_user_cardname",
            "chat_info.group_info.platform": "chat_info_group_platform",
            "chat_info.group_info.group_id": "chat_info_group_id",
            "chat_info.group_info.group_name": "chat_info_group_name",
            "chat_info.create_time": "chat_info_create_time",
            "chat_info.last_active_time": "chat_info_last_active_time",
            **user_info_columns,
            **same("processed_plain_text", "memorized_times"),
        },
        enable_validation=False,  # validation disabled
        unique_fields=["message_id"],
    )

    images = MigrationConfig(
        mongo_collection="images",
        target_model=Images,
        field_mapping={
            "hash": "emoji_hash",
            **same("description", "path", "timestamp", "type"),
        },
        unique_fields=["path"],
    )

    image_descriptions = MigrationConfig(
        mongo_collection="image_descriptions",
        target_model=ImageDescriptions,
        field_mapping={
            **same("type"),
            "hash": "image_description_hash",
            **same("description", "timestamp"),
        },
        unique_fields=["image_description_hash", "type"],
    )

    person_info = MigrationConfig(
        mongo_collection="person_info",
        target_model=PersonInfo,
        field_mapping={
            **same(
                "person_id",
                "person_name",
                "name_reason",
                "platform",
                "user_id",
                "nickname",
                "relationship_value",
            ),
            # NOTE(review): source key is spelled "konw_time"; presumably this
            # mirrors the spelling stored in MongoDB -- confirm before changing.
            "konw_time": "know_time",
        },
        unique_fields=["person_id"],
    )

    knowledges = MigrationConfig(
        mongo_collection="knowledges",
        target_model=Knowledges,
        field_mapping=same("content", "embedding"),
        unique_fields=["content"],  # assumes content is unique
    )

    thinking_log = MigrationConfig(
        mongo_collection="thinking_log",
        target_model=ThinkingLog,
        field_mapping={
            **same("chat_id", "trigger_text", "response_text"),
            # Structured fields land in "<name>_json" columns.
            **{
                name: f"{name}_json"
                for name in (
                    "trigger_info",
                    "response_info",
                    "timing_results",
                    "chat_history",
                    "chat_history_in_thinking",
                    "chat_history_after_response",
                    "heartflow_data",
                    "reasoning_data",
                )
            },
        },
        unique_fields=["chat_id", "trigger_text"],
    )

    graph_nodes = MigrationConfig(
        mongo_collection="graph_data.nodes",
        target_model=GraphNodes,
        field_mapping=same("concept", "memory_items", "hash", "created_time", "last_modified"),
        unique_fields=["concept"],
    )

    graph_edges = MigrationConfig(
        mongo_collection="graph_data.edges",
        target_model=GraphEdges,
        field_mapping=same("source", "target", "strength", "hash", "created_time", "last_modified"),
        unique_fields=["source", "target"],  # unique as a pair
    )

    return [
        emoji,
        chat_streams,
        messages,
        images,
        image_descriptions,
        person_info,
        knowledges,
        thinking_log,
        graph_nodes,
        graph_edges,
    ]
def _initialize_validation_rules(self) -> dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
def connect_mongodb(self) -> bool:
    """Open the MongoDB connection and verify it with a ``ping``.

    On success ``self.mongo_client`` and ``self.mongo_db`` are set. On
    failure the half-open client is closed so it is not leaked, and both
    attributes are left untouched.

    Returns:
        ``True`` when the server answered the ping, ``False`` otherwise.
    """
    client = None
    try:
        client = MongoClient(
            self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
        )
        # Connection test: raises if no server can be selected.
        client.admin.command("ping")
    except ConnectionFailure as e:
        logger.error(f"MongoDB连接失败: {e}")
    except Exception as e:
        logger.error(f"MongoDB连接异常: {e}")
    else:
        self.mongo_client = client
        self.mongo_db = client[self.database_name]
        logger.info(f"成功连接到MongoDB: {self.database_name}")
        return True

    # Failure path: release the client that never became usable.
    if client is not None:
        client.close()
    return False
def disconnect_mongodb(self):
"""断开MongoDB连接"""
if self.mongo_client:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
parts = field_path.split(".")
value = document
for part in parts:
if isinstance(value, dict):
value = value.get(part)
else:
return None
if value is None:
break
return value
def _convert_field_value(self, value: Any, target_field: Field) -> Any:
"""根据目标字段类型转换值"""
if value is None:
return None
field_type = target_field.__class__.__name__
try:
if target_field.name == "record_time" and field_type == "DateTimeField":
return datetime.now()
if field_type in ["CharField", "TextField"]:
if isinstance(value, list | dict):
return orjson.dumps(value, ensure_ascii=False)
return str(value) if value is not None else ""
elif field_type == "IntegerField":
if isinstance(value, str):
# 处理字符串数字
clean_value = value.strip()
if clean_value.replace(".", "").replace("-", "").isdigit():
return int(float(clean_value))
return 0
return int(value) if value is not None else 0
elif field_type in ["FloatField", "DoubleField"]:
return float(value) if value is not None else 0.0
elif field_type == "BooleanField":
if isinstance(value, str):
return value.lower() in ("true", "1", "yes", "on")
return bool(value)
elif field_type == "DateTimeField":
if isinstance(value, int | float):
return datetime.fromtimestamp(value)
elif isinstance(value, str):
try:
# 尝试解析ISO格式日期
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
try:
# 尝试解析时间戳字符串
return datetime.fromtimestamp(float(value))
except ValueError:
return datetime.now()
return datetime.now()
return value
except (ValueError, TypeError) as e:
logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
return self._get_default_value_for_field(target_field)
def _get_default_value_for_field(self, field: Field) -> Any:
"""获取字段的默认值"""
field_type = field.__class__.__name__
if hasattr(field, "default") and field.default is not None:
return field.default
if field.null:
return None
# 根据字段类型返回默认值
if field_type in ["CharField", "TextField"]:
return ""
elif field_type == "IntegerField":
return 0
elif field_type in ["FloatField", "DoubleField"]:
return 0.0
elif field_type == "BooleanField":
return False
elif field_type == "DateTimeField":
return datetime.now()
return None
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
    """Validation hook; validation is disabled, so every record is accepted.

    All arguments are accepted for interface compatibility and ignored.

    Returns:
        Always ``True``.
    """
    return True
def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
    """Persist the migration position of ``collection_name`` to disk.

    The checkpoint is pickled to ``<collection_name>_checkpoint.pkl`` inside
    the checkpoint directory. A write failure is logged as a warning and
    otherwise ignored.

    Args:
        collection_name: MongoDB collection being migrated.
        processed_count: Number of records processed so far.
        last_id: ``_id`` of the last document that was handled.
    """
    target = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
    snapshot = MigrationCheckpoint(
        collection_name=collection_name,
        processed_count=processed_count,
        last_processed_id=last_id,
        timestamp=datetime.now(),
    )
    try:
        with target.open("wb") as handle:
            pickle.dump(snapshot, handle)
    except Exception as e:
        logger.warning(f"保存断点失败: {e}")
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
return None
try:
with open(checkpoint_file, "rb") as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载断点失败: {e}")
return None
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
    """Insert ``data_list`` into ``model`` and return how many rows were stored.

    All rows are first inserted in one transaction, in chunks of 100 to keep
    each SQL statement short. If that fails the transaction is rolled back
    and the rows are retried one at a time, so a single bad row does not
    discard the whole batch.

    Args:
        model: Target model class.
        data_list: Row dictionaries to insert.

    Returns:
        Number of rows actually inserted.
    """
    if not data_list:
        return 0

    chunk_size = 100
    try:
        with db.atomic():
            for start in range(0, len(data_list), chunk_size):
                model.insert_many(data_list[start : start + chunk_size]).execute()
        return len(data_list)
    except Exception as e:
        logger.error(f"批量插入失败: {e}")

    # The atomic block rolled back, so nothing from the bulk attempt was
    # stored: count from zero instead of keeping the partial bulk tally.
    success_count = 0
    failed_count = 0
    for data in data_list:
        try:
            model.create(**data)
            success_count += 1
        except Exception as row_error:
            # Best effort: keep going, but leave a trace of what was dropped.
            failed_count += 1
            logger.debug(f"逐条插入失败: {row_error}")
    if failed_count:
        logger.warning(f"逐条插入完成: 成功 {success_count}, 失败 {failed_count}")
    return success_count
def _check_duplicate_by_unique_fields(
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
return False
try:
query = model.select()
for field_name in unique_fields:
if field_name in data and data[field_name] is not None:
field_obj = getattr(model, field_name)
query = query.where(field_obj == data[field_name])
return query.exists()
except Exception as e:
logger.debug(f"重复检查失败: {e}")
return False
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
    """Create one row of ``model`` from ``data`` through the ORM.

    Keys that are not attributes of ``model`` are dropped (and logged at
    debug level) before the row is created.

    Returns:
        The created instance, or ``None`` when creation failed. Integrity
        errors such as unique-constraint conflicts are logged at debug
        level, anything else at error level.
    """
    try:
        accepted = {}
        for name in data:
            if not hasattr(model, name):
                logger.debug(f"跳过未知字段: {name}")
                continue
            accepted[name] = data[name]
        return model.create(**accepted)
    except IntegrityError as e:
        logger.debug(f"完整性约束冲突: {e}")
    except Exception as e:
        logger.error(f"创建模型实例失败: {e}")
    return None
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
    """Migrate one MongoDB collection into its target table.

    Documents are read in ``_id`` order, converted field by field according
    to ``config.field_mapping``, optionally de-duplicated, and written in
    batches of ``config.batch_size``. A checkpoint is saved after every
    flushed batch and removed once the collection completes, so an
    interrupted run resumes after the last saved ``_id``.

    Args:
        config: Migration settings for the collection.

    Returns:
        Statistics for this run. Failures are recorded in the stats rather
        than raised.
    """
    stats = MigrationStats()
    stats.start_time = datetime.now()

    # Resume from a previous run when a checkpoint exists.
    checkpoint = self._load_checkpoint(config.mongo_collection)
    start_from_id = checkpoint.last_processed_id if checkpoint else None
    if checkpoint:
        stats.processed_count = checkpoint.processed_count
        logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")

    logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")

    try:
        mongo_collection = self.mongo_db[config.mongo_collection]

        # Restrict to documents after the checkpoint when resuming.
        query = {}
        if start_from_id:
            query = {"_id": {"$gt": start_from_id}}

        stats.total_documents = mongo_collection.count_documents(query)
        if stats.total_documents == 0:
            logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
            return stats

        logger.info(f"待迁移文档数量: {stats.total_documents}")

        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            BarColumn(),
            TaskProgressColumn(),
            TimeElapsedColumn(),
            TimeRemainingColumn(),
            console=self.console,
            refresh_per_second=10,
        ) as progress:
            task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)

            batch_data = []
            last_processed_id = None

            # Sort by _id: the checkpoint resumes with "_id > last id", which
            # is only correct if documents are visited in _id order.
            cursor = mongo_collection.find(query).sort("_id", 1).batch_size(config.batch_size)
            for mongo_doc in cursor:
                try:
                    doc_id = mongo_doc.get("_id", "unknown")
                    last_processed_id = doc_id

                    # Build the target row, converting each mapped field.
                    target_data = {}
                    for mongo_field, sqlite_field in config.field_mapping.items():
                        value = self._get_nested_value(mongo_doc, mongo_field)
                        if hasattr(config.target_model, sqlite_field):
                            field_obj = getattr(config.target_model, sqlite_field)
                            target_data[sqlite_field] = self._convert_field_value(value, field_obj)

                    # Data validation is disabled; records go straight to
                    # the duplicate check.
                    if config.skip_duplicates and self._check_duplicate_by_unique_fields(
                        config.target_model, target_data, config.unique_fields
                    ):
                        stats.duplicate_count += 1
                        stats.skipped_count += 1
                        logger.debug(f"跳过重复记录: {doc_id}")
                        continue

                    batch_data.append(target_data)
                    stats.processed_count += 1

                    # Flush a full batch and remember how far we got.
                    if len(batch_data) >= config.batch_size:
                        stats.success_count += self._batch_insert(config.target_model, batch_data)
                        stats.batch_insert_count += 1
                        self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
                        batch_data.clear()
                except Exception as e:
                    doc_id = mongo_doc.get("_id", "unknown")
                    stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
                    logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
                finally:
                    # One step per document read, so skipped and failed
                    # documents move the bar too.
                    progress.update(task, advance=1)

            # Flush whatever is left over.
            if batch_data:
                stats.success_count += self._batch_insert(config.target_model, batch_data)
                stats.batch_insert_count += 1

            progress.update(task, completed=stats.total_documents)

        stats.end_time = datetime.now()
        duration = stats.end_time - stats.start_time
        logger.info(
            f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
            f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
            f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
            f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
        )

        # The collection finished: the checkpoint is no longer needed.
        checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
        if checkpoint_file.exists():
            checkpoint_file.unlink()
    except Exception as e:
        logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
        stats.add_error("collection_error", str(e))
    finally:
        # Always record an end time so the summary shows a duration even for
        # empty or failed collections.
        if stats.end_time is None:
            stats.end_time = datetime.now()
    return stats
def migrate_all(self) -> dict[str, MigrationStats]:
    """Run every configured migration, one collection after another.

    Connects to MongoDB, migrates each configured collection, prints a short
    per-collection status line, then disconnects and prints the summary.

    Returns:
        Statistics keyed by MongoDB collection name; an empty dict when the
        MongoDB connection could not be established.
    """
    logger.info("开始执行数据库迁移...")
    if not self.connect_mongodb():
        logger.error("无法连接到MongoDB迁移终止")
        return {}

    all_stats = {}
    try:
        total_collections = len(self.migration_configs)
        banner = Panel(
            f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
            f"[yellow]总集合数: {total_collections}[/yellow]",
            title="迁移开始",
            expand=False,
        )
        self.console.print(banner)

        for idx, config in enumerate(self.migration_configs, 1):
            self.console.print(
                f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
            )
            stats = self.migrate_collection(config)
            all_stats[config.mongo_collection] = stats

            # Nothing processed means there are no rates to report.
            if stats.processed_count <= 0:
                continue

            # Quick per-collection result line.
            success_rate = stats.success_count / stats.processed_count * 100
            if success_rate >= 95:
                status_emoji, status_color = "", "bright_green"
            elif success_rate >= 80:
                status_emoji, status_color = "⚠️", "yellow"
            else:
                status_emoji, status_color = "", "red"
            self.console.print(
                f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
                f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
            )

            # Warn when more than 10% of the processed records failed.
            error_rate = stats.error_count / stats.processed_count
            if error_rate > 0.1:
                self.console.print(
                    f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
                    f"({stats.error_count}/{stats.processed_count})[/red]"
                )
    finally:
        self.disconnect_mongodb()

    self._print_migration_summary(all_stats)
    return all_stats
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
    """Print the migration report to the console.

    The report is a per-collection table with a totals row, a status panel
    (errors, validation failures, duplicates, overall success rate) and a
    performance panel (total time, average speed, batch count).

    Args:
        all_stats: Statistics keyed by collection name.
    """

    def rate_tag(rate):
        # Opening colour tag for a success rate given in percent.
        if rate >= 95:
            return "[bright_green]"
        if rate >= 80:
            return "[yellow]"
        return "[red]"

    def elapsed_seconds(item):
        # 0 when the run has no complete start/end pair.
        if item.start_time and item.end_time:
            return (item.end_time - item.start_time).total_seconds()
        return 0

    entries = list(all_stats.values())

    # Overall totals.
    total_documents = sum(item.total_documents for item in entries)
    total_processed = sum(item.processed_count for item in entries)
    total_success = sum(item.success_count for item in entries)
    total_errors = sum(item.error_count for item in entries)
    total_skipped = sum(item.skipped_count for item in entries)
    total_duplicates = sum(item.duplicate_count for item in entries)
    total_validation_errors = sum(item.validation_errors for item in entries)
    total_batch_inserts = sum(item.batch_insert_count for item in entries)
    total_duration_seconds = sum(elapsed_seconds(item) for item in entries)

    # Detail table: one name column followed by right-aligned number columns.
    table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
    table.add_column("集合名称", style="cyan", width=20)
    numeric_columns = (
        ("文档总数", "blue"),
        ("处理数量", "green"),
        ("成功数量", "green"),
        ("错误数量", "red"),
        ("跳过数量", "yellow"),
        ("重复数量", "bright_yellow"),
        ("验证错误", "red"),
        ("批次数", "purple"),
        ("成功率", "bright_green"),
        ("耗时(秒)", "blue"),
    )
    for heading, colour in numeric_columns:
        table.add_column(heading, justify="right", style=colour)

    for collection_name, stats in all_stats.items():
        success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
        duration = elapsed_seconds(stats)
        success_rate_style = rate_tag(success_rate)
        error_cell = f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0"
        skipped_cell = f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0"
        duplicate_cell = (
            f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0"
        )
        validation_cell = f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0"
        table.add_row(
            collection_name,
            str(stats.total_documents),
            str(stats.processed_count),
            str(stats.success_count),
            error_cell,
            skipped_cell,
            duplicate_cell,
            validation_cell,
            str(stats.batch_insert_count),
            # Slicing off the leading "[" turns the opening tag into its closing tag.
            f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
            f"{duration:.2f}",
        )

    # Totals row.
    total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
    total_rate_style = rate_tag(total_success_rate)
    zero_cell = "[bold]0[/bold]"
    table.add_section()
    table.add_row(
        "[bold]总计[/bold]",
        f"[bold]{total_documents}[/bold]",
        f"[bold]{total_processed}[/bold]",
        f"[bold]{total_success}[/bold]",
        f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else zero_cell,
        f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else zero_cell,
        f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" if total_duplicates > 0 else zero_cell,
        f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else zero_cell,
        f"[bold]{total_batch_inserts}[/bold]",
        f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
        f"[bold]{total_duration_seconds:.2f}[/bold]",
    )
    self.console.print(table)

    # Status panel.
    status_items = []
    if total_errors > 0:
        status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
    if total_validation_errors > 0:
        status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
    if total_duplicates > 0:
        status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
    if total_success_rate >= 95:
        status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
    elif total_success_rate >= 80:
        status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
    else:
        status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
    if status_items:
        self.console.print(
            Panel("\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow")
        )

    # Performance panel.
    avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
    performance_info = (
        f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f}\n"
        f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
        f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
    )
    self.console.print(Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green"))
def add_migration_config(self, config: MigrationConfig):
"""添加新的迁移配置"""
self.migration_configs.append(config)
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
    """Migrate only the collection called ``collection_name``.

    Returns:
        The run statistics, or ``None`` when no configuration exists for the
        collection or MongoDB cannot be reached.
    """
    # The first configuration with a matching collection name wins.
    config = None
    for candidate in self.migration_configs:
        if candidate.mongo_collection == collection_name:
            config = candidate
            break
    if not config:
        logger.error(f"未找到集合 {collection_name} 的迁移配置")
        return None

    if not self.connect_mongodb():
        logger.error("无法连接到MongoDB")
        return None

    try:
        stats = self.migrate_collection(config)
        self._print_migration_summary({collection_name: stats})
        return stats
    finally:
        self.disconnect_mongodb()
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
    """Write a JSON error report for a finished migration run.

    The report holds a timestamp, per-collection counters and, for every
    collection that had errors, the recorded error entries.

    Args:
        all_stats: Statistics keyed by collection name.
        filepath: Destination file; overwritten if it exists.

    A failure to serialise or write is logged and not raised.
    """
    error_report = {
        "timestamp": datetime.now().isoformat(),
        "summary": {
            collection: {
                "total": stats.total_documents,
                "processed": stats.processed_count,
                "success": stats.success_count,
                "errors": stats.error_count,
                "skipped": stats.skipped_count,
                "duplicates": stats.duplicate_count,
            }
            for collection, stats in all_stats.items()
        },
        "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
    }
    try:
        # orjson.dumps takes (obj, default, option) and returns UTF-8 bytes;
        # it has no file argument and no ensure_ascii/indent keywords.
        # default=str covers values it cannot encode natively inside the
        # captured documents. Serialise first so a failure does not leave an
        # empty file behind.
        payload = orjson.dumps(error_report, default=str, option=orjson.OPT_INDENT_2)
        with open(filepath, "wb") as f:
            f.write(payload)
        logger.info(f"错误报告已导出到: {filepath}")
    except Exception as e:
        logger.error(f"导出错误报告失败: {e}")
def main():
    """Script entry point: migrate everything, then export an error report if needed."""
    migrator = MongoToSQLiteMigrator()

    # Run the migration.
    migration_results = migrator.migrate_all()

    # Export a timestamped report only when at least one collection had errors.
    has_errors = any(stats.error_count > 0 for stats in migration_results.values())
    if has_errors:
        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        migrator.export_error_report(migration_results, f"migration_errors_{stamp}.json")

    logger.info("数据迁移完成!")
if __name__ == "__main__":
main()

View File

@@ -1,556 +0,0 @@
#!/bin/bash
# MaiCore & NapCat Adapter one-click install script, by Cookie_987
# Targets Arch / Ubuntu 24.10 / Debian 12 / CentOS 9
# Be careful with any one-click script!
INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8
# If GitHub is not reachable, change the mirror address here
GITHUB_REPO="https://ghfast.top/https://github.com"
# Colored output (escape sequences, printed with `echo -e`)
GREEN="\e[32m"
RED="\e[31m"
RESET="\e[0m"
# Required base packages: "common" applies to every distribution, the other
# keys are looked up by the distribution ID
declare -A REQUIRED_PACKAGES=(
["common"]="git sudo python3 curl gnupg"
["debian"]="python3-venv python3-pip build-essential"
["ubuntu"]="python3-venv python3-pip build-essential"
["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make"
["arch"]="python-virtualenv python-pip base-devel"
)
# Default project directory
DEFAULT_INSTALL_DIR="/opt/maicore"
# Service names
SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
# Installation choices, switched to true by the interactive prompts
IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false
# Succeeds (exit status 0) when the MaiCore systemd unit file exists.
check_installed() {
    local unit_file="/etc/systemd/system/${SERVICE_NAME}.service"
    test -f "$unit_file"
}
# Load INSTALL_DIR/BRANCH from the saved install info, or fall back to the
# defaults when /etc/maicore_install.conf does not exist.
load_install_info() {
    if [[ ! -f /etc/maicore_install.conf ]]; then
        INSTALL_DIR="$DEFAULT_INSTALL_DIR"
        BRANCH="refactor"
        return
    fi
    source /etc/maicore_install.conf
}
# Show the management menu in a loop until the user exits or cancels.
show_menu() {
    local verb unit notice
    while :; do
        # whiptail draws on stderr; swap the streams so the selection is captured.
        choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
            "1" "启动MaiCore" \
            "2" "停止MaiCore" \
            "3" "重启MaiCore" \
            "4" "启动NapCat Adapter" \
            "5" "停止NapCat Adapter" \
            "6" "重启NapCat Adapter" \
            "7" "拉取最新MaiCore仓库" \
            "8" "切换分支" \
            "9" "退出" 3>&1 1>&2 2>&3) || exit 0

        # Entries 1-6 only pick a systemctl verb, a unit and the confirmation
        # text; the shared execution step follows the case statement.
        case "$choice" in
            1) verb=start;   unit="${SERVICE_NAME}";           notice="✅MaiCore已启动" ;;
            2) verb=stop;    unit="${SERVICE_NAME}";           notice="🛑MaiCore已停止" ;;
            3) verb=restart; unit="${SERVICE_NAME}";           notice="🔄MaiCore已重启" ;;
            4) verb=start;   unit="${SERVICE_NAME_NBADAPTER}"; notice="✅NapCat Adapter已启动" ;;
            5) verb=stop;    unit="${SERVICE_NAME_NBADAPTER}"; notice="🛑NapCat Adapter已停止" ;;
            6) verb=restart; unit="${SERVICE_NAME_NBADAPTER}"; notice="🔄NapCat Adapter已重启" ;;
            7)
                update_dependencies
                continue
                ;;
            8)
                switch_branch
                continue
                ;;
            9)
                exit 0
                ;;
            *)
                whiptail --msgbox "无效选项!" 10 60
                continue
                ;;
        esac

        systemctl "$verb" "$unit"
        whiptail --msgbox "$notice" 10 60
    done
}
# Stop the service, pull the current branch and reinstall the Python
# requirements inside the virtual environment.
update_dependencies() {
    whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
    systemctl stop ${SERVICE_NAME}

    if ! cd "${INSTALL_DIR}/MaiBot"; then
        whiptail --msgbox "🚫 无法进入安装目录!" 10 60
        return 1
    fi

    git pull origin "${BRANCH}" || {
        whiptail --msgbox "🚫 代码更新失败!" 10 60
        return 1
    }

    source "${INSTALL_DIR}/venv/bin/activate"
    pip install -r requirements.txt || {
        whiptail --msgbox "🚫 依赖安装失败!" 10 60
        deactivate
        return 1
    }
    deactivate

    whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
}
# Switch the MaiBot checkout to another branch.
# Asks for the branch name, verifies it exists on origin, checks it out and
# pulls it, stops the service, reinstalls the requirements, records the new
# branch in /etc/maicore_install.conf and re-checks the EULA.
# Returns 1 (after showing a message box) when any git step fails.
switch_branch() {
    local escaped_branch

    new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
    [[ -z "$new_branch" ]] && {
        whiptail --msgbox "🚫 分支名称不能为空!" 10 60
        return 1
    }
    cd "${INSTALL_DIR}/MaiBot" || {
        whiptail --msgbox "🚫 无法进入安装目录!" 10 60
        return 1
    }
    # Make sure the branch exists on the remote before touching the checkout.
    if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
        whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
        return 1
    fi
    if ! git checkout "${new_branch}"; then
        whiptail --msgbox "🚫 分支切换失败!" 10 60
        return 1
    fi
    if ! git pull origin "${new_branch}"; then
        whiptail --msgbox "🚫 代码拉取失败!" 10 60
        return 1
    fi
    systemctl stop "${SERVICE_NAME}"
    source "${INSTALL_DIR}/venv/bin/activate"
    pip install -r requirements.txt
    deactivate
    # Branch names may contain "/" (e.g. feature/x) or "&", both of which are
    # special in a sed replacement; escape them before substituting.
    escaped_branch=$(printf '%s\n' "$new_branch" | sed -e 's/[\/&]/\\&/g')
    sed -i "s/^BRANCH=.*/BRANCH=${escaped_branch}/" /etc/maicore_install.conf
    BRANCH="${new_branch}"
    check_eula
    whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} " 10 60
}
# Ask the user to re-confirm the EULA / privacy policy when either file changed.
# The MD5 of EULA.md and PRIVACY.md is compared with the hashes stored in
# eula.confirmed / privacy.confirmed. On agreement the stored hashes are
# updated; on refusal the script exits with status 1.
# Returns 1 when a hash cannot be computed (agreement file missing).
check_eula() {
    local maibot_dir="${INSTALL_DIR}/MaiBot"

    # Hash of the current EULA
    current_md5=$(md5sum "${maibot_dir}/EULA.md" | awk '{print $1}')
    # Hash of the current privacy policy
    current_md5_privacy=$(md5sum "${maibot_dir}/PRIVACY.md" | awk '{print $1}')

    # Without both hashes there is nothing to compare: report and return
    # instead of going on to store an empty hash as "confirmed".
    if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
        whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
        return 1
    fi

    # Previously confirmed EULA hash (empty when never confirmed)
    if [[ -f "${maibot_dir}/eula.confirmed" ]]; then
        confirmed_md5=$(cat "${maibot_dir}/eula.confirmed")
    else
        confirmed_md5=""
    fi

    # Previously confirmed privacy-policy hash (empty when never confirmed)
    if [[ -f "${maibot_dir}/privacy.confirmed" ]]; then
        confirmed_md5_privacy=$(cat "${maibot_dir}/privacy.confirmed")
    else
        confirmed_md5_privacy=""
    fi

    # Either document changed: ask for confirmation again. The right-hand
    # sides are quoted so they are compared as strings, not as patterns.
    if [[ "$current_md5" != "$confirmed_md5" || "$current_md5_privacy" != "$confirmed_md5_privacy" ]]; then
        whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70
        if [[ $? -eq 0 ]]; then
            echo -n "$current_md5" > "${maibot_dir}/eula.confirmed"
            echo -n "$current_md5_privacy" > "${maibot_dir}/privacy.confirmed"
        else
            exit 1
        fi
    fi
}
# ----------- Main installation flow -----------
# Interactive first-time installation. In order: make sure whiptail is
# available, ask for agreement to the EULA/privacy policy, check the system,
# collect the choices (packages, NapCat, branch, install directory), then
# install packages, clone the repositories, install Python dependencies,
# write the systemd units and save the install info.
run_installation() {
# 1/6: make sure whiptail is installed
if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
if command -v apt-get &>/dev/null; then
apt-get update && apt-get install -y whiptail
elif command -v pacman &>/dev/null; then
pacman -Syu --noconfirm whiptail
elif command -v yum &>/dev/null; then
yum install -y whiptail
else
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
exit 1
fi
fi
whiptail --title " 提示" --msgbox "如果您没有特殊需求请优先使用docker方式部署。" 10 60
# Agreement confirmation: refusing aborts the installation
if ! (whiptail --title " [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议" 12 70); then
exit 1
fi
# Welcome message
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore将自动进入安装流程安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段代码可能随时更改\n文档未完善有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险请自行了解谨慎使用\n由于持续迭代可能存在一些已知或未知的bug\n由于开发中可能消耗较多token\n\n本脚本可能更新不及时如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
# System check: requires root and one of the supported distributions.
# Sourcing /etc/os-release defines ID / VERSION_ID / PRETTY_NAME, which the
# rest of this function keeps using.
check_system() {
if [[ "$(id -u)" -ne 0 ]]; then
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
exit 1
fi
if [[ -f /etc/os-release ]]; then
source /etc/os-release
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
return
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
return
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
return
elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
return
else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
exit 1
fi
else
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
exit 1
fi
}
check_system
# Choose the package manager from the distribution ID
case "$ID" in
debian|ubuntu)
PKG_MANAGER="apt"
;;
centos)
PKG_MANAGER="yum"
;;
arch)
# Arch package manager
PKG_MANAGER="pacman"
;;
esac
# Check whether NapCat is already available on PATH
check_napcat() {
if command -v napcat &>/dev/null; then
NAPCAT_INSTALLED=true
else
NAPCAT_INSTALLED=false
fi
}
check_napcat
# Collect the required packages that are missing and ask whether to install
# them; the actual installation happens later, after the final confirmation
install_packages() {
missing_packages=()
# Check the common packages plus the ones specific to this distribution
for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
case "$PKG_MANAGER" in
apt)
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
;;
yum)
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
;;
pacman)
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
;;
esac
done
if [[ ${#missing_packages[@]} -gt 0 ]]; then
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装" 10 60
if [[ $? -eq 0 ]]; then
IS_INSTALL_DEPENDENCIES=true
else
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续" 10 60 || exit 1
fi
fi
}
install_packages
# Ask whether NapCat should be installed (skipped when already installed)
install_napcat() {
[[ $NAPCAT_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60 && {
IS_INSTALL_NAPCAT=true
}
}
# Only offer NapCat on non-Arch systems
[[ "$ID" != "arch" ]] && install_napcat
# Python version check: 3.10 or newer is required
check_python() {
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
exit 1
fi
}
# Skip the version check when python3 is not installed yet
if command -v python3 &>/dev/null; then
check_python
fi
# Choose the branch to install; "custom" prompts for a branch name
choose_branch() {
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
"main" "稳定版本(推荐)" ON \
"dev" "开发版(不知道什么意思就别选)" OFF \
"classical" "经典版0.6.0以前的版本)" OFF \
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 操作取消!" 10 60
exit 1
fi
if [[ "$BRANCH" == "custom" ]]; then
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 输入取消!" 10 60
exit 1
fi
if [[ -z "$BRANCH" ]]; then
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
exit 1
fi
fi
}
choose_branch
# Choose the install path; an empty answer either exits or falls back to
# the default directory
choose_install_dir() {
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
[[ -z "$INSTALL_DIR" ]] && {
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
}
}
choose_install_dir
# Final confirmation: summarise the pending changes before doing anything
confirm_install() {
local confirm_msg="请确认以下更改:\n\n"
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
confirm_msg+="🔀 分支: $BRANCH\n"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
}
confirm_install
# Start installing: missing packages first (only when the user agreed)
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
case "$PKG_MANAGER" in
apt)
apt update && apt install -y "${missing_packages[@]}"
;;
yum)
yum install -y "${missing_packages[@]}" --nobest
;;
pacman)
pacman -S --noconfirm "${missing_packages[@]}"
;;
esac
fi
# NapCat: download the installer script and run it
if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
fi
echo -e "${GREEN}创建安装目录...${RESET}"
mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR" || exit 1
# Everything Python-related is installed into this virtual environment
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
python3 -m venv venv
source venv/bin/activate
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
echo -e "${RED}克隆MaiCore仓库失败${RESET}"
exit 1
}
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}安装Python依赖...${RESET}"
pip install -r MaiBot/requirements.txt
cd MaiBot
pip install uv
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
echo -e "${GREEN}安装maim_message依赖...${RESET}"
cd maim_message
uv pip install -i https://mirrors.aliyun.com/pypi/simple -e .
cd ..
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
cd MaiBot-Napcat-Adapter
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
# The user agreed to the documents at the start of this function, so store
# their current hashes as confirmed
echo -e "${GREEN}同意协议...${RESET}"
# Hash of the current EULA
current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}')
# Hash of the current privacy policy
current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}')
echo -n $current_md5 > MaiBot/eula.confirmed
echo -n $current_md5_privacy > MaiBot/privacy.confirmed
# systemd unit for MaiCore itself
echo -e "${GREEN}创建系统服务...${RESET}"
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
[Unit]
Description=MaiCore
After=network.target ${SERVICE_NAME_NBADAPTER}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
# The WebUI unit below is disabled (kept commented out)
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
# [Unit]
# Description=MaiCore WebUI
# After=network.target ${SERVICE_NAME}.service
# [Service]
# Type=simple
# WorkingDirectory=${INSTALL_DIR}/MaiBot
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
# Restart=always
# RestartSec=10s
# [Install]
# WantedBy=multi-user.target
# EOF
# systemd unit for the NapCat adapter
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
[Unit]
Description=MaiBot Napcat Adapter
After=network.target mongod.service ${SERVICE_NAME}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
# Save the install info (read back by load_install_info)
echo "INSTALLER_VERSION=${INSTALLER_VERSION}" > /etc/maicore_install.conf
echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf
echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf
# NOTE(review): this message also names ${SERVICE_NAME_WEB}, whose unit is
# commented out above and therefore never created -- confirm the wording.
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成\n已创建系统服务${SERVICE_NAME}${SERVICE_NAME_WEB}${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务\n启动服务systemctl start ${SERVICE_NAME}\n查看状态systemctl status ${SERVICE_NAME}" 14 60
}
# ----------- Main execution flow -----------
# Root privileges are required.
if [[ $(id -u) -ne 0 ]]; then
    echo -e "${RED}请使用root用户运行此脚本${RESET}"
    exit 1
fi

if ! check_installed; then
    # First run: install, then offer to start the service right away.
    run_installation
    if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务" 10 60; then
        systemctl start ${SERVICE_NAME}
        whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
    fi
else
    # Already installed: re-check the agreements, then show the menu.
    load_install_info
    check_eula
    show_menu
fi

View File

@@ -265,7 +265,8 @@ class AntiPromptInjector:
async with get_db_session() as session:
# 删除对应的消息记录
stmt = delete(Messages).where(Messages.message_id == message_id)
result = session.execute(stmt)
# 注意: 异步会话需要 await 执行,否则 result 是 coroutine无法获取 rowcount
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:
@@ -296,7 +297,7 @@ class AntiPromptInjector:
.where(Messages.message_id == message_id)
.values(processed_plain_text=new_content, display_message=new_content)
)
result = session.execute(stmt)
result = await session.execute(stmt)
await session.commit()
if result.rowcount > 0:

View File

@@ -5,9 +5,9 @@
"""
import datetime
from typing import Any
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast
from sqlalchemy import select
from sqlalchemy import select, delete
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
from src.common.logger import get_logger
@@ -16,8 +16,30 @@ from src.config.config import global_config
logger = get_logger("anti_injector.statistics")
TNum = TypeVar("TNum", int, float)
def _add_optional(a: Optional[TNum], b: TNum) -> TNum:
"""安全相加:左值可能为 None。
Args:
a: 可能为 None 的当前值
b: 要累加的增量(非 None
Returns:
新的累加结果(与 b 同类型)
"""
if a is None:
return b
return cast(TNum, a + b) # a 不为 None此处显式 cast 便于类型检查
class AntiInjectionStatistics:
"""反注入系统统计管理类"""
"""反注入系统统计管理类
主要改进:
- 对 "可能为 None" 的数值字段做集中安全处理,减少在业务逻辑里反复判空。
- 补充类型注解便于静态检查器Pylance/Pyright识别。
"""
def __init__(self):
"""初始化统计管理器"""
@@ -25,8 +47,12 @@ class AntiInjectionStatistics:
"""当前会话开始时间"""
@staticmethod
async def get_or_create_stats():
"""获取或创建统计记录"""
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined]
"""获取或创建统计记录
Returns:
AntiInjectionStats | None: 成功返回模型实例,否则 None
"""
try:
async with get_db_session() as session:
# 获取最新的统计记录,如果没有则创建
@@ -46,8 +72,15 @@ class AntiInjectionStatistics:
return None
@staticmethod
async def update_stats(**kwargs):
"""更新统计数据"""
async def update_stats(**kwargs: Any) -> None:
"""更新统计数据(批量可选字段)
支持字段:
- processing_time_delta: float 累加到 processing_time_total
- last_processing_time: float 设置 last_process_time
- total_messages / detected_injections / blocked_messages / shielded_messages / error_count: 累加
- 其他任意字段:直接赋值(若模型存在该属性)
"""
try:
async with get_db_session() as session:
stats = (
@@ -63,13 +96,12 @@ class AntiInjectionStatistics:
for key, value in kwargs.items():
if key == "processing_time_delta":
# 处理时间累加 - 确保不为 None
if stats.processing_time_total is None:
stats.processing_time_total = 0.0
stats.processing_time_total += value
delta = float(value)
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined]
continue
elif key == "last_processing_time":
# 直接设置最后处理时间
stats.last_process_time = value
stats.last_process_time = float(value)
continue
elif hasattr(stats, key):
if key in [
@@ -79,12 +111,10 @@ class AntiInjectionStatistics:
"shielded_messages",
"error_count",
]:
# 累加类型的字段 - 确保不为None
current_value = getattr(stats, key)
if current_value is None:
setattr(stats, key, value)
else:
setattr(stats, key, current_value + value)
# 累加类型的字段 - 统一用辅助函数
current_value = cast(Optional[int], getattr(stats, key))
increment = int(value)
setattr(stats, key, _add_optional(current_value, increment))
else:
# 直接设置的字段
setattr(stats, key, value)
@@ -114,10 +144,11 @@ class AntiInjectionStatistics:
stats = await self.get_or_create_stats()
# 计算派生统计信息 - 处理 None 值
total_messages = stats.total_messages or 0
detected_injections = stats.detected_injections or 0
processing_time_total = stats.processing_time_total or 0.0
total_messages = stats.total_messages or 0 # type: ignore[attr-defined]
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]
detection_rate = (detected_injections / total_messages * 100) if total_messages > 0 else 0
avg_processing_time = (processing_time_total / total_messages) if total_messages > 0 else 0
@@ -127,17 +158,22 @@ class AntiInjectionStatistics:
current_time = datetime.datetime.now()
uptime = current_time - self.session_start_time
last_proc = stats.last_process_time # type: ignore[attr-defined]
blocked_messages = stats.blocked_messages or 0 # type: ignore[attr-defined]
shielded_messages = stats.shielded_messages or 0 # type: ignore[attr-defined]
error_count = stats.error_count or 0 # type: ignore[attr-defined]
return {
"status": "enabled",
"uptime": str(uptime),
"total_messages": total_messages,
"detected_injections": detected_injections,
"blocked_messages": stats.blocked_messages or 0,
"shielded_messages": stats.shielded_messages or 0,
"blocked_messages": blocked_messages,
"shielded_messages": shielded_messages,
"detection_rate": f"{detection_rate:.2f}%",
"average_processing_time": f"{avg_processing_time:.3f}s",
"last_processing_time": f"{stats.last_process_time:.3f}s" if stats.last_process_time else "0.000s",
"error_count": stats.error_count or 0,
"last_processing_time": f"{last_proc:.3f}s" if last_proc else "0.000s",
"error_count": error_count,
}
except Exception as e:
logger.error(f"获取统计信息失败: {e}")
@@ -149,7 +185,7 @@ class AntiInjectionStatistics:
try:
async with get_db_session() as session:
# 删除现有统计记录
await session.execute(select(AntiInjectionStats).delete())
await session.execute(delete(AntiInjectionStats))
await session.commit()
logger.info("统计信息已重置")
except Exception as e:

View File

@@ -51,7 +51,7 @@ class UserBanManager:
remaining_time = ban_duration - (datetime.datetime.now() - ban_record.created_at)
return False, None, f"用户被封禁中,剩余时间: {remaining_time}"
else:
# 封禁已过期,重置违规次数
# 封禁已过期,重置违规次数与时间(模型已使用 Mapped 类型,可直接赋值)
ban_record.violation_num = 0
ban_record.created_at = datetime.datetime.now()
await session.commit()
@@ -92,7 +92,6 @@ class UserBanManager:
await session.commit()
# 检查是否需要自动封禁
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
logger.warning(f"用户 {platform}:{user_id} 违规次数达到 {ban_record.violation_num},触发自动封禁")
# 只有在首次达到阈值时才更新封禁开始时间

View File

@@ -377,11 +377,12 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], r
class EmojiManager:
_instance = None
_initialized: bool = False # 显式声明,避免属性未定义错误
def __new__(cls) -> "EmojiManager":
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
# 类属性已声明,无需再次赋值
return cls._instance
def __init__(self) -> None:
@@ -399,7 +400,8 @@ class EmojiManager:
self.emoji_num_max = global_config.emoji.max_reg_num
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表使用类型注解明确列表元素类型
logger.info("启动表情包管理器")
self._initialized = True
logger.info("启动表情包管理器")
def shutdown(self) -> None:
@@ -752,8 +754,8 @@ class EmojiManager:
try:
emoji_record = await self.get_emoji_from_db(emoji_hash)
if emoji_record and emoji_record[0].emotion:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
return emoji_record.emotion
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") # type: ignore # type: ignore
return emoji_record.emotion # type: ignore
except Exception as e:
logger.error(f"从数据库查询表情包描述时出错: {e}")

View File

@@ -7,6 +7,7 @@ import asyncio
import time
from typing import Any
from src.chat.message_manager.adaptive_stream_manager import StreamPriority
from src.chat.chatter_manager import ChatterManager
from src.chat.energy_system import energy_manager
from src.common.data_models.message_manager_data_model import StreamContext

View File

@@ -1,6 +1,14 @@
"""SQLAlchemy数据库模型定义
替换Peewee ORM使用SQLAlchemy提供更好的连接池管理和错误恢复能力
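Note: older models used the classic `field_name = Column(Type, ...)` style. This file is being
migrated to the typed declarative style recommended by SQLAlchemy 2.0:
field_name: Mapped[PyType] = mapped_column(Type, ...)
With this style, IDEs / Pylance infer the real Python type of instance attributes instead of treating them as non-assignable Column objects.
The migration began with the model that triggered type-checker errors (BanUser); any model not yet converted keeps the classic style.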
"""
import datetime
@@ -103,31 +111,31 @@ class ChatStreams(Base):
__tablename__ = "chat_streams"
id = Column(Integer, primary_key=True, autoincrement=True)
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
create_time = Column(Float, nullable=False)
group_platform = Column(Text, nullable=True)
group_id = Column(get_string_field(100), nullable=True, index=True)
group_name = Column(Text, nullable=True)
last_active_time = Column(Float, nullable=False)
platform = Column(Text, nullable=False)
user_platform = Column(Text, nullable=False)
user_id = Column(get_string_field(100), nullable=False, index=True)
user_nickname = Column(Text, nullable=False)
user_cardname = Column(Text, nullable=True)
energy_value = Column(Float, nullable=True, default=5.0)
sleep_pressure = Column(Float, nullable=True, default=0.0)
focus_energy = Column(Float, nullable=True, default=0.5)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
stream_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, unique=True, index=True)
create_time: Mapped[float] = mapped_column(Float, nullable=False)
group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
group_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
platform: Mapped[str] = mapped_column(Text, nullable=False)
user_platform: Mapped[str] = mapped_column(Text, nullable=False)
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
energy_value: Mapped[float | None] = mapped_column(Float, nullable=True, default=5.0)
sleep_pressure: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
focus_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
# 动态兴趣度系统字段
base_interest_energy = Column(Float, nullable=True, default=0.5)
message_interest_total = Column(Float, nullable=True, default=0.0)
message_count = Column(Integer, nullable=True, default=0)
action_count = Column(Integer, nullable=True, default=0)
reply_count = Column(Integer, nullable=True, default=0)
last_interaction_time = Column(Float, nullable=True, default=None)
consecutive_no_reply = Column(Integer, nullable=True, default=0)
base_interest_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
message_interest_total: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
message_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
action_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
reply_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
last_interaction_time: Mapped[float | None] = mapped_column(Float, nullable=True, default=None)
consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
# 消息打断系统字段
interruption_count = Column(Integer, nullable=True, default=0)
interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
__table_args__ = (
Index("idx_chatstreams_stream_id", "stream_id"),
@@ -141,20 +149,20 @@ class LLMUsage(Base):
__tablename__ = "llm_usage"
id = Column(Integer, primary_key=True, autoincrement=True)
model_name = Column(get_string_field(100), nullable=False, index=True)
model_assign_name = Column(get_string_field(100), index=True) # 添加索引
model_api_provider = Column(get_string_field(100), index=True) # 添加索引
user_id = Column(get_string_field(50), nullable=False, index=True)
request_type = Column(get_string_field(50), nullable=False, index=True)
endpoint = Column(Text, nullable=False)
prompt_tokens = Column(Integer, nullable=False)
completion_tokens = Column(Integer, nullable=False)
time_cost = Column(Float, nullable=True)
total_tokens = Column(Integer, nullable=False)
cost = Column(Float, nullable=False)
status = Column(Text, nullable=False)
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
model_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
# No explicit nullable= here: SQLAlchemy derives it from the annotation, so these stay
# Optional to keep the columns nullable as they were with the classic Column(...) form.
model_assign_name: Mapped[str | None] = mapped_column(get_string_field(100), index=True)
model_api_provider: Mapped[str | None] = mapped_column(get_string_field(100), index=True)
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
request_type: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
endpoint: Mapped[str] = mapped_column(Text, nullable=False)
prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
time_cost: Mapped[float | None] = mapped_column(Float, nullable=True)
total_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
cost: Mapped[float] = mapped_column(Float, nullable=False)
status: Mapped[str] = mapped_column(Text, nullable=False)
timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
__table_args__ = (
Index("idx_llmusage_model_name", "model_name"),
@@ -172,19 +180,19 @@ class Emoji(Base):
__tablename__ = "emoji"
id = Column(Integer, primary_key=True, autoincrement=True)
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
format = Column(Text, nullable=False)
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
query_count = Column(Integer, nullable=False, default=0)
is_registered = Column(Boolean, nullable=False, default=False)
is_banned = Column(Boolean, nullable=False, default=False)
emotion = Column(Text, nullable=True)
record_time = Column(Float, nullable=False)
register_time = Column(Float, nullable=True)
usage_count = Column(Integer, nullable=False, default=0)
last_used_time = Column(Float, nullable=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
full_path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
format: Mapped[str] = mapped_column(Text, nullable=False)
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
query_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
is_registered: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_banned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
emotion: Mapped[str | None] = mapped_column(Text, nullable=True)
record_time: Mapped[float] = mapped_column(Float, nullable=False)
register_time: Mapped[float | None] = mapped_column(Float, nullable=True)
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
last_used_time: Mapped[float | None] = mapped_column(Float, nullable=True)
__table_args__ = (
Index("idx_emoji_full_path", "full_path"),
@@ -197,50 +205,50 @@ class Messages(Base):
__tablename__ = "messages"
id = Column(Integer, primary_key=True, autoincrement=True)
message_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
reply_to = Column(Text, nullable=True)
interest_value = Column(Float, nullable=True)
key_words = Column(Text, nullable=True)
key_words_lite = Column(Text, nullable=True)
is_mentioned = Column(Boolean, nullable=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
message_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
time: Mapped[float] = mapped_column(Float, nullable=False)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
reply_to: Mapped[str | None] = mapped_column(Text, nullable=True)
interest_value: Mapped[float | None] = mapped_column(Float, nullable=True)
key_words: Mapped[str | None] = mapped_column(Text, nullable=True)
key_words_lite: Mapped[str | None] = mapped_column(Text, nullable=True)
is_mentioned: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
# 从 chat_info 扁平化而来的字段
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
chat_info_user_platform = Column(Text, nullable=False)
chat_info_user_id = Column(Text, nullable=False)
chat_info_user_nickname = Column(Text, nullable=False)
chat_info_user_cardname = Column(Text, nullable=True)
chat_info_group_platform = Column(Text, nullable=True)
chat_info_group_id = Column(Text, nullable=True)
chat_info_group_name = Column(Text, nullable=True)
chat_info_create_time = Column(Float, nullable=False)
chat_info_last_active_time = Column(Float, nullable=False)
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_user_platform: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_user_id: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_info_group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_info_group_id: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_info_group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_info_create_time: Mapped[float] = mapped_column(Float, nullable=False)
chat_info_last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
# 从顶层 user_info 扁平化而来的字段
user_platform = Column(Text, nullable=True)
user_id = Column(get_string_field(100), nullable=True, index=True)
user_nickname = Column(Text, nullable=True)
user_cardname = Column(Text, nullable=True)
user_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
user_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
user_nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
processed_plain_text = Column(Text, nullable=True)
display_message = Column(Text, nullable=True)
memorized_times = Column(Integer, nullable=False, default=0)
priority_mode = Column(Text, nullable=True)
priority_info = Column(Text, nullable=True)
additional_config = Column(Text, nullable=True)
is_emoji = Column(Boolean, nullable=False, default=False)
is_picid = Column(Boolean, nullable=False, default=False)
is_command = Column(Boolean, nullable=False, default=False)
is_notify = Column(Boolean, nullable=False, default=False)
processed_plain_text: Mapped[str | None] = mapped_column(Text, nullable=True)
display_message: Mapped[str | None] = mapped_column(Text, nullable=True)
memorized_times: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
priority_mode: Mapped[str | None] = mapped_column(Text, nullable=True)
priority_info: Mapped[str | None] = mapped_column(Text, nullable=True)
additional_config: Mapped[str | None] = mapped_column(Text, nullable=True)
is_emoji: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_picid: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_command: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_notify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# 兴趣度系统字段
actions = Column(Text, nullable=True) # JSON格式存储动作列表
should_reply = Column(Boolean, nullable=True, default=False)
should_act = Column(Boolean, nullable=True, default=False)
actions: Mapped[str | None] = mapped_column(Text, nullable=True)
should_reply: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
should_act: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
__table_args__ = (
Index("idx_messages_message_id", "message_id"),
@@ -257,17 +265,17 @@ class ActionRecords(Base):
__tablename__ = "action_records"
id = Column(Integer, primary_key=True, autoincrement=True)
action_id = Column(get_string_field(100), nullable=False, index=True)
time = Column(Float, nullable=False)
action_name = Column(Text, nullable=False)
action_data = Column(Text, nullable=False)
action_done = Column(Boolean, nullable=False, default=False)
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
action_prompt_display = Column(Text, nullable=False)
chat_id = Column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id = Column(Text, nullable=False)
chat_info_platform = Column(Text, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
action_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
time: Mapped[float] = mapped_column(Float, nullable=False)
action_name: Mapped[str] = mapped_column(Text, nullable=False)
action_data: Mapped[str] = mapped_column(Text, nullable=False)
action_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
action_build_into_prompt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
action_prompt_display: Mapped[str] = mapped_column(Text, nullable=False)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
__table_args__ = (
Index("idx_actionrecords_action_id", "action_id"),
@@ -281,15 +289,15 @@ class Images(Base):
__tablename__ = "images"
id = Column(Integer, primary_key=True, autoincrement=True)
image_id = Column(Text, nullable=False, default="")
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=True)
path = Column(get_string_field(500), nullable=False, unique=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
type = Column(Text, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
image_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True)
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
type: Mapped[str] = mapped_column(Text, nullable=False)
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
__table_args__ = (
Index("idx_images_emoji_hash", "emoji_hash"),
@@ -302,11 +310,11 @@ class ImageDescriptions(Base):
__tablename__ = "image_descriptions"
id = Column(Integer, primary_key=True, autoincrement=True)
type = Column(Text, nullable=False)
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
description = Column(Text, nullable=False)
timestamp = Column(Float, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
type: Mapped[str] = mapped_column(Text, nullable=False)
image_description_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
@@ -316,20 +324,20 @@ class Videos(Base):
__tablename__ = "videos"
id = Column(Integer, primary_key=True, autoincrement=True)
video_id = Column(Text, nullable=False, default="")
video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
description = Column(Text, nullable=True)
count = Column(Integer, nullable=False, default=1)
timestamp = Column(Float, nullable=False)
vlm_processed = Column(Boolean, nullable=False, default=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
video_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
video_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True, unique=True)
description: Mapped[str | None] = mapped_column(Text, nullable=True)
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# 视频特有属性
duration = Column(Float, nullable=True) # 视频时长(秒)
frame_count = Column(Integer, nullable=True) # 总帧数
fps = Column(Float, nullable=True) # 帧率
resolution = Column(Text, nullable=True) # 分辨率
file_size = Column(Integer, nullable=True) # 文件大小(字节)
duration: Mapped[float | None] = mapped_column(Float, nullable=True)
frame_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
fps: Mapped[float | None] = mapped_column(Float, nullable=True)
resolution: Mapped[str | None] = mapped_column(Text, nullable=True)
file_size: Mapped[int | None] = mapped_column(Integer, nullable=True)
__table_args__ = (
Index("idx_videos_video_hash", "video_hash"),
@@ -342,11 +350,11 @@ class OnlineTime(Base):
__tablename__ = "online_time"
id = Column(Integer, primary_key=True, autoincrement=True)
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
duration = Column(Integer, nullable=False)
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp = Column(DateTime, nullable=False, index=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
# Pass a callable so the value is computed per insert; str(datetime.datetime.now)
# would store the repr of the function object, evaluated once at import time.
timestamp: Mapped[str] = mapped_column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
duration: Mapped[int] = mapped_column(Integer, nullable=False)
start_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
end_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True)
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
@@ -356,22 +364,22 @@ class PersonInfo(Base):
__tablename__ = "person_info"
id = Column(Integer, primary_key=True, autoincrement=True)
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
person_name = Column(Text, nullable=True)
name_reason = Column(Text, nullable=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
nickname = Column(Text, nullable=True)
impression = Column(Text, nullable=True)
short_impression = Column(Text, nullable=True)
points = Column(Text, nullable=True)
forgotten_points = Column(Text, nullable=True)
info_list = Column(Text, nullable=True)
know_times = Column(Float, nullable=True)
know_since = Column(Float, nullable=True)
last_know = Column(Float, nullable=True)
attitude = Column(Integer, nullable=True, default=50)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
person_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
person_name: Mapped[str | None] = mapped_column(Text, nullable=True)
name_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
platform: Mapped[str] = mapped_column(Text, nullable=False)
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
impression: Mapped[str | None] = mapped_column(Text, nullable=True)
short_impression: Mapped[str | None] = mapped_column(Text, nullable=True)
points: Mapped[str | None] = mapped_column(Text, nullable=True)
forgotten_points: Mapped[str | None] = mapped_column(Text, nullable=True)
info_list: Mapped[str | None] = mapped_column(Text, nullable=True)
know_times: Mapped[float | None] = mapped_column(Float, nullable=True)
know_since: Mapped[float | None] = mapped_column(Float, nullable=True)
last_know: Mapped[float | None] = mapped_column(Float, nullable=True)
attitude: Mapped[int | None] = mapped_column(Integer, nullable=True, default=50)
__table_args__ = (
Index("idx_personinfo_person_id", "person_id"),
@@ -384,13 +392,13 @@ class BotPersonalityInterests(Base):
__tablename__ = "bot_personality_interests"
id = Column(Integer, primary_key=True, autoincrement=True)
personality_id = Column(get_string_field(100), nullable=False, index=True)
personality_description = Column(Text, nullable=False)
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
version = Column(Integer, nullable=False, default=1)
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
personality_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
personality_description: Mapped[str] = mapped_column(Text, nullable=False)
interest_tags: Mapped[str] = mapped_column(Text, nullable=False)
embedding_model: Mapped[str] = mapped_column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
last_updated: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
__table_args__ = (
Index("idx_botpersonality_personality_id", "personality_id"),
@@ -404,13 +412,13 @@ class Memory(Base):
__tablename__ = "memory"
id = Column(Integer, primary_key=True, autoincrement=True)
memory_id = Column(get_string_field(64), nullable=False, index=True)
chat_id = Column(Text, nullable=True)
memory_text = Column(Text, nullable=True)
keywords = Column(Text, nullable=True)
create_time = Column(Float, nullable=True)
last_view_time = Column(Float, nullable=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
memory_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
chat_id: Mapped[str | None] = mapped_column(Text, nullable=True)
memory_text: Mapped[str | None] = mapped_column(Text, nullable=True)
keywords: Mapped[str | None] = mapped_column(Text, nullable=True)
create_time: Mapped[float | None] = mapped_column(Float, nullable=True)
last_view_time: Mapped[float | None] = mapped_column(Float, nullable=True)
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
@@ -437,19 +445,19 @@ class ThinkingLog(Base):
__tablename__ = "thinking_logs"
id = Column(Integer, primary_key=True, autoincrement=True)
chat_id = Column(get_string_field(64), nullable=False, index=True)
trigger_text = Column(Text, nullable=True)
response_text = Column(Text, nullable=True)
trigger_info_json = Column(Text, nullable=True)
response_info_json = Column(Text, nullable=True)
timing_results_json = Column(Text, nullable=True)
chat_history_json = Column(Text, nullable=True)
chat_history_in_thinking_json = Column(Text, nullable=True)
chat_history_after_response_json = Column(Text, nullable=True)
heartflow_data_json = Column(Text, nullable=True)
reasoning_data_json = Column(Text, nullable=True)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
trigger_text: Mapped[str | None] = mapped_column(Text, nullable=True)
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
trigger_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
response_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
timing_results_json: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_history_json: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_history_in_thinking_json: Mapped[str | None] = mapped_column(Text, nullable=True)
chat_history_after_response_json: Mapped[str | None] = mapped_column(Text, nullable=True)
heartflow_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
reasoning_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
@@ -459,13 +467,13 @@ class GraphNodes(Base):
__tablename__ = "graph_nodes"
id = Column(Integer, primary_key=True, autoincrement=True)
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items = Column(Text, nullable=False)
hash = Column(Text, nullable=False)
weight = Column(Float, nullable=False, default=1.0)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
concept: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
memory_items: Mapped[str] = mapped_column(Text, nullable=False)
hash: Mapped[str] = mapped_column(Text, nullable=False)
weight: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
created_time: Mapped[float] = mapped_column(Float, nullable=False)
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
@@ -475,13 +483,13 @@ class GraphEdges(Base):
__tablename__ = "graph_edges"
id = Column(Integer, primary_key=True, autoincrement=True)
source = Column(get_string_field(255), nullable=False, index=True)
target = Column(get_string_field(255), nullable=False, index=True)
strength = Column(Integer, nullable=False)
hash = Column(Text, nullable=False)
created_time = Column(Float, nullable=False)
last_modified = Column(Float, nullable=False)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
source: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
target: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
strength: Mapped[int] = mapped_column(Integer, nullable=False)
hash: Mapped[str] = mapped_column(Text, nullable=False)
created_time: Mapped[float] = mapped_column(Float, nullable=False)
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
__table_args__ = (
Index("idx_graphedges_source", "source"),
@@ -494,11 +502,11 @@ class Schedule(Base):
__tablename__ = "schedule"
id = Column(Integer, primary_key=True, autoincrement=True)
date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式
schedule_data = Column(Text, nullable=False) # JSON格式的日程数据
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
date: Mapped[str] = mapped_column(get_string_field(10), nullable=False, unique=True, index=True)
schedule_data: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
__table_args__ = (Index("idx_schedule_date", "date"),)
@@ -508,17 +516,15 @@ class MaiZoneScheduleStatus(Base):
__tablename__ = "maizone_schedule_status"
id = Column(Integer, primary_key=True, autoincrement=True)
datetime_hour = Column(
get_string_field(13), nullable=False, unique=True, index=True
) # YYYY-MM-DD HH格式精确到小时
activity = Column(Text, nullable=False) # 该小时的活动内容
is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理
processed_at = Column(DateTime, nullable=True) # 处理时间
story_content = Column(Text, nullable=True) # 生成的说说内容
send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
datetime_hour: Mapped[str] = mapped_column(get_string_field(13), nullable=False, unique=True, index=True)
activity: Mapped[str] = mapped_column(Text, nullable=False)
is_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
processed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True)
story_content: Mapped[str | None] = mapped_column(Text, nullable=True)
send_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
__table_args__ = (
Index("idx_maizone_datetime_hour", "datetime_hour"),
@@ -527,16 +533,20 @@ class MaiZoneScheduleStatus(Base):
class BanUser(Base):
"""被禁用用户模型"""
"""被禁用用户模型
使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型,
避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。
"""
__tablename__ = "ban_users"
id = Column(Integer, primary_key=True, autoincrement=True)
platform = Column(Text, nullable=False)
user_id = Column(get_string_field(50), nullable=False, index=True)
violation_num = Column(Integer, nullable=False, default=0)
reason = Column(Text, nullable=False)
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
platform: Mapped[str] = mapped_column(Text, nullable=False)
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
reason: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
__table_args__ = (
Index("idx_violation_num", "violation_num"),
@@ -551,38 +561,38 @@ class AntiInjectionStats(Base):
__tablename__ = "anti_injection_stats"
id = Column(Integer, primary_key=True, autoincrement=True)
total_messages = Column(Integer, nullable=False, default=0)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""总处理消息数"""
detected_injections = Column(Integer, nullable=False, default=0)
detected_injections: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""检测到的注入攻击数"""
blocked_messages = Column(Integer, nullable=False, default=0)
blocked_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""被阻止的消息数"""
shielded_messages = Column(Integer, nullable=False, default=0)
shielded_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""被加盾的消息数"""
processing_time_total = Column(Float, nullable=False, default=0.0)
processing_time_total: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
"""总处理时间"""
total_process_time = Column(Float, nullable=False, default=0.0)
total_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
"""累计总处理时间"""
last_process_time = Column(Float, nullable=False, default=0.0)
last_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
"""最近一次处理时间"""
error_count = Column(Integer, nullable=False, default=0)
error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""错误计数"""
start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
start_time: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
"""统计开始时间"""
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
"""记录创建时间"""
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
"""记录更新时间"""
__table_args__ = (
@@ -596,26 +606,26 @@ class CacheEntries(Base):
__tablename__ = "cache_entries"
id = Column(Integer, primary_key=True, autoincrement=True)
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
cache_key: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
"""缓存键,包含工具名、参数和代码哈希"""
cache_value = Column(Text, nullable=False)
cache_value: Mapped[str] = mapped_column(Text, nullable=False)
"""缓存的数据JSON格式"""
expires_at = Column(Float, nullable=False, index=True)
expires_at: Mapped[float] = mapped_column(Float, nullable=False, index=True)
"""过期时间戳"""
tool_name = Column(get_string_field(100), nullable=False, index=True)
tool_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
"""工具名称"""
created_at = Column(Float, nullable=False, default=lambda: time.time())
created_at: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
"""创建时间戳"""
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
last_accessed: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
"""最后访问时间戳"""
access_count = Column(Integer, nullable=False, default=0)
access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
"""访问次数"""
__table_args__ = (
@@ -631,18 +641,16 @@ class MonthlyPlan(Base):
__tablename__ = "monthly_plans"
id = Column(Integer, primary_key=True, autoincrement=True)
plan_text = Column(Text, nullable=False)
target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM"
status = Column(
get_string_field(20), nullable=False, default="active", index=True
) # 'active', 'completed', 'archived'
usage_count = Column(Integer, nullable=False, default=0)
last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
plan_text: Mapped[str] = mapped_column(Text, nullable=False)
target_month: Mapped[str] = mapped_column(String(7), nullable=False, index=True)
status: Mapped[str] = mapped_column(get_string_field(20), nullable=False, default="active", index=True)
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
is_deleted = Column(Boolean, nullable=False, default=False)
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
__table_args__ = (
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
@@ -807,12 +815,12 @@ class PermissionNodes(Base):
__tablename__ = "permission_nodes"
id = Column(Integer, primary_key=True, autoincrement=True)
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
description = Column(Text, nullable=False) # 权限描述
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
node_name: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
description: Mapped[str] = mapped_column(Text, nullable=False)
plugin_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
default_granted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
__table_args__ = (
Index("idx_permission_plugin", "plugin_name"),
@@ -825,13 +833,13 @@ class UserPermissions(Base):
__tablename__ = "user_permissions"
id = Column(Integer, primary_key=True, autoincrement=True)
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID
permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称
granted = Column(Boolean, default=True, nullable=False) # 是否授权
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
permission_node: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
granted: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
granted_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
granted_by: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
__table_args__ = (
Index("idx_user_platform_id", "platform", "user_id"),
@@ -845,13 +853,13 @@ class UserRelationships(Base):
__tablename__ = "user_relationships"
id = Column(Integer, primary_key=True, autoincrement=True)
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
user_name = Column(get_string_field(100), nullable=True) # 用户名
relationship_text = Column(Text, nullable=True) # 关系印象描述
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True)
relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1)
last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time)
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
__table_args__ = (
Index("idx_user_relationship_id", "user_id"),

View File

@@ -0,0 +1,872 @@
"""SQLAlchemy数据库模型定义
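Replaces the Peewee ORM with SQLAlchemy for better connection-pool management and error recovery.

Note: some older models still use the classic `name = Column(Type, ...)` style. This file is being
migrated step by step to the typed declarative style recommended by SQLAlchemy 2.0:

    field_name: Mapped[PyType] = mapped_column(Type, ...)

This lets IDEs / Pylance infer the real Python type of instance attributes instead of treating them
as non-assignable Column objects. So far only the models that triggered type-checker problems
(Expression, BanUser) have been migrated; the remaining models are left unchanged to limit the
scope of a single change.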
"""
import datetime
import os
import time
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
from src.common.database.connection_pool_manager import get_connection_pool_manager
from src.common.logger import get_logger
logger = get_logger("sqlalchemy_models")
# 创建基类
Base = declarative_base()
async def enable_sqlite_wal_mode(engine):
    """Enable WAL journaling on a SQLite engine and apply the companion PRAGMAs.

    The statements are executed in order inside a single ``engine.begin()`` block.
    Any exception is caught and logged as a warning; nothing is re-raised, so the
    caller continues with whatever configuration the database already has.

    Args:
        engine: An async engine exposing ``begin()`` as an async context manager.
    """
    # Order matters only for readability; each PRAGMA is independent.
    pragma_statements = (
        "PRAGMA journal_mode = WAL",  # write-ahead logging
        "PRAGMA synchronous = NORMAL",  # moderate sync level: balance speed and safety
        "PRAGMA foreign_keys = ON",  # enforce foreign-key constraints
        "PRAGMA busy_timeout = 60000",  # 60 seconds, to avoid lock errors
    )
    try:
        async with engine.begin() as connection:
            for statement in pragma_statements:
                await connection.execute(text(statement))
        logger.info("[SQLite] WAL 模式已启用,并发性能已优化")
    except Exception as exc:
        logger.warning(f"[SQLite] 启用 WAL 模式失败: {exc},将使用默认配置")
async def maintain_sqlite_database():
    """Re-apply the SQLite performance PRAGMAs on the shared engine.

    Obtains the engine from ``initialize_database()``; if no engine is returned the
    function exits silently. Otherwise it checks the current journal mode, switches
    back to WAL when it has changed, and re-applies the remaining PRAGMAs.
    All exceptions are caught and logged as warnings.
    """
    # PRAGMAs that are re-applied unconditionally on every maintenance run.
    tuning_statements = (
        "PRAGMA synchronous = NORMAL",
        "PRAGMA busy_timeout = 60000",
        "PRAGMA foreign_keys = ON",
    )
    try:
        # The second element (session factory) is not needed here.
        engine, _session_factory = await initialize_database()
        if not engine:
            return

        async with engine.begin() as connection:
            # Make sure WAL mode is still active; restore it if not.
            mode_result = await connection.execute(text("PRAGMA journal_mode"))
            if mode_result.scalar() != "wal":
                await connection.execute(text("PRAGMA journal_mode = WAL"))
                logger.info("[SQLite] WAL 模式已重新启用")

            for statement in tuning_statements:
                await connection.execute(text(statement))

            # Periodic cleanup is optional and currently disabled:
            # await connection.execute(text("PRAGMA optimize"))

        logger.info("[SQLite] 数据库维护完成")
    except Exception as exc:
        logger.warning(f"[SQLite] 数据库维护失败: {exc}")
def get_sqlite_performance_config():
    """Return the SQLite performance-tuning settings.

    Returns:
        A new dict on every call, mapping PRAGMA names to the values to apply.
    """
    settings = {}
    settings["journal_mode"] = "WAL"  # better concurrency
    settings["synchronous"] = "NORMAL"  # balance performance and safety
    settings["busy_timeout"] = 60000  # 60-second timeout
    settings["foreign_keys"] = "ON"  # enforce foreign-key constraints
    settings["cache_size"] = -10000  # ~10 MB cache
    settings["temp_store"] = "MEMORY"  # keep temporary storage in memory
    settings["mmap_size"] = 268435456  # 256 MB memory map
    return settings
# Helper for MySQL-compatible string column types
def get_string_field(max_length=255, **kwargs):
    """Pick a string column type appropriate for the configured database.

    MySQL needs a length-bounded VARCHAR for indexed columns, while other
    backends (SQLite) can simply use Text.

    Args:
        max_length: VARCHAR length used when the database type is ``"mysql"``.
        **kwargs: Forwarded unchanged to the chosen SQLAlchemy type constructor.

    Returns:
        ``String(max_length, **kwargs)`` for MySQL, otherwise ``Text(**kwargs)``.
    """
    # Imported lazily inside the function, as in the original code.
    from src.config.config import global_config

    uses_mysql = global_config.database.database_type == "mysql"
    if not uses_mysql:
        return Text(**kwargs)
    return String(max_length, **kwargs)
class ChatStreams(Base):
    """Chat stream model (table ``chat_streams``)."""

    __tablename__ = "chat_streams"

    id = Column(Integer, primary_key=True, autoincrement=True)
    stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
    create_time = Column(Float, nullable=False)
    # Group-related columns are nullable.
    group_platform = Column(Text, nullable=True)
    group_id = Column(get_string_field(100), nullable=True, index=True)
    group_name = Column(Text, nullable=True)
    last_active_time = Column(Float, nullable=False)
    platform = Column(Text, nullable=False)
    user_platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(100), nullable=False, index=True)
    user_nickname = Column(Text, nullable=False)
    user_cardname = Column(Text, nullable=True)
    energy_value = Column(Float, nullable=True, default=5.0)
    sleep_pressure = Column(Float, nullable=True, default=0.0)
    focus_energy = Column(Float, nullable=True, default=0.5)
    # Dynamic interest-level system fields
    base_interest_energy = Column(Float, nullable=True, default=0.5)
    message_interest_total = Column(Float, nullable=True, default=0.0)
    message_count = Column(Integer, nullable=True, default=0)
    action_count = Column(Integer, nullable=True, default=0)
    reply_count = Column(Integer, nullable=True, default=0)
    last_interaction_time = Column(Float, nullable=True, default=None)
    consecutive_no_reply = Column(Integer, nullable=True, default=0)
    # Message-interruption system fields
    interruption_count = Column(Integer, nullable=True, default=0)

    __table_args__ = (
        Index("idx_chatstreams_stream_id", "stream_id"),
        Index("idx_chatstreams_user_id", "user_id"),
        Index("idx_chatstreams_group_id", "group_id"),
    )
class LLMUsage(Base):
    """LLM usage record model (table ``llm_usage``)."""

    __tablename__ = "llm_usage"

    id = Column(Integer, primary_key=True, autoincrement=True)
    model_name = Column(get_string_field(100), nullable=False, index=True)
    # No explicit ``nullable`` on the next two columns, so SQLAlchemy's default applies.
    model_assign_name = Column(get_string_field(100), index=True)  # indexed
    model_api_provider = Column(get_string_field(100), index=True)  # indexed
    user_id = Column(get_string_field(50), nullable=False, index=True)
    request_type = Column(get_string_field(50), nullable=False, index=True)
    endpoint = Column(Text, nullable=False)
    prompt_tokens = Column(Integer, nullable=False)
    completion_tokens = Column(Integer, nullable=False)
    time_cost = Column(Float, nullable=True)
    total_tokens = Column(Integer, nullable=False)
    cost = Column(Float, nullable=False)
    status = Column(Text, nullable=False)
    # Defaults to the local time at insert (callable is passed, not called).
    timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)

    __table_args__ = (
        Index("idx_llmusage_model_name", "model_name"),
        Index("idx_llmusage_model_assign_name", "model_assign_name"),
        Index("idx_llmusage_model_api_provider", "model_api_provider"),
        Index("idx_llmusage_time_cost", "time_cost"),
        Index("idx_llmusage_user_id", "user_id"),
        Index("idx_llmusage_request_type", "request_type"),
        Index("idx_llmusage_timestamp", "timestamp"),
    )
class Emoji(Base):
    """Emoji / sticker model (table ``emoji``)."""

    __tablename__ = "emoji"

    id = Column(Integer, primary_key=True, autoincrement=True)
    full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
    format = Column(Text, nullable=False)
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    query_count = Column(Integer, nullable=False, default=0)
    is_registered = Column(Boolean, nullable=False, default=False)
    is_banned = Column(Boolean, nullable=False, default=False)
    emotion = Column(Text, nullable=True)
    record_time = Column(Float, nullable=False)
    register_time = Column(Float, nullable=True)
    usage_count = Column(Integer, nullable=False, default=0)
    last_used_time = Column(Float, nullable=True)

    __table_args__ = (
        Index("idx_emoji_full_path", "full_path"),
        Index("idx_emoji_hash", "emoji_hash"),
    )
class Messages(Base):
    """Message model (table ``messages``)."""

    __tablename__ = "messages"

    id = Column(Integer, primary_key=True, autoincrement=True)
    message_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    reply_to = Column(Text, nullable=True)
    interest_value = Column(Float, nullable=True)
    key_words = Column(Text, nullable=True)
    key_words_lite = Column(Text, nullable=True)
    is_mentioned = Column(Boolean, nullable=True)
    # Fields flattened from chat_info
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)
    chat_info_user_platform = Column(Text, nullable=False)
    chat_info_user_id = Column(Text, nullable=False)
    chat_info_user_nickname = Column(Text, nullable=False)
    chat_info_user_cardname = Column(Text, nullable=True)
    chat_info_group_platform = Column(Text, nullable=True)
    chat_info_group_id = Column(Text, nullable=True)
    chat_info_group_name = Column(Text, nullable=True)
    chat_info_create_time = Column(Float, nullable=False)
    chat_info_last_active_time = Column(Float, nullable=False)
    # Fields flattened from the top-level user_info
    user_platform = Column(Text, nullable=True)
    user_id = Column(get_string_field(100), nullable=True, index=True)
    user_nickname = Column(Text, nullable=True)
    user_cardname = Column(Text, nullable=True)
    processed_plain_text = Column(Text, nullable=True)
    display_message = Column(Text, nullable=True)
    memorized_times = Column(Integer, nullable=False, default=0)
    priority_mode = Column(Text, nullable=True)
    priority_info = Column(Text, nullable=True)
    additional_config = Column(Text, nullable=True)
    is_emoji = Column(Boolean, nullable=False, default=False)
    is_picid = Column(Boolean, nullable=False, default=False)
    is_command = Column(Boolean, nullable=False, default=False)
    is_notify = Column(Boolean, nullable=False, default=False)
    # Interest-level system fields
    actions = Column(Text, nullable=True)  # action list stored as JSON
    should_reply = Column(Boolean, nullable=True, default=False)
    should_act = Column(Boolean, nullable=True, default=False)

    __table_args__ = (
        Index("idx_messages_message_id", "message_id"),
        Index("idx_messages_chat_id", "chat_id"),
        Index("idx_messages_time", "time"),
        Index("idx_messages_user_id", "user_id"),
        Index("idx_messages_should_reply", "should_reply"),
        Index("idx_messages_should_act", "should_act"),
    )
class ActionRecords(Base):
    """Action record model (table ``action_records``)."""

    __tablename__ = "action_records"

    id = Column(Integer, primary_key=True, autoincrement=True)
    action_id = Column(get_string_field(100), nullable=False, index=True)
    time = Column(Float, nullable=False)
    action_name = Column(Text, nullable=False)
    action_data = Column(Text, nullable=False)
    action_done = Column(Boolean, nullable=False, default=False)
    action_build_into_prompt = Column(Boolean, nullable=False, default=False)
    action_prompt_display = Column(Text, nullable=False)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    chat_info_stream_id = Column(Text, nullable=False)
    chat_info_platform = Column(Text, nullable=False)

    __table_args__ = (
        Index("idx_actionrecords_action_id", "action_id"),
        Index("idx_actionrecords_chat_id", "chat_id"),
        Index("idx_actionrecords_time", "time"),
    )
class Images(Base):
    """Image information model (table ``images``)."""

    __tablename__ = "images"

    id = Column(Integer, primary_key=True, autoincrement=True)
    image_id = Column(Text, nullable=False, default="")
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=True)
    path = Column(get_string_field(500), nullable=False, unique=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    type = Column(Text, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index("idx_images_emoji_hash", "emoji_hash"),
        Index("idx_images_path", "path"),
    )
class ImageDescriptions(Base):
    """Image description model (table ``image_descriptions``)."""

    __tablename__ = "image_descriptions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    type = Column(Text, nullable=False)
    image_description_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    timestamp = Column(Float, nullable=False)

    __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
class Videos(Base):
    """Video information model (table ``videos``)."""

    __tablename__ = "videos"

    id = Column(Integer, primary_key=True, autoincrement=True)
    video_id = Column(Text, nullable=False, default="")
    video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
    description = Column(Text, nullable=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)
    # Video-specific attributes
    duration = Column(Float, nullable=True)  # video duration (seconds)
    frame_count = Column(Integer, nullable=True)  # total number of frames
    fps = Column(Float, nullable=True)  # frame rate
    resolution = Column(Text, nullable=True)  # resolution
    file_size = Column(Integer, nullable=True)  # file size (bytes)

    __table_args__ = (
        Index("idx_videos_video_hash", "video_hash"),
        Index("idx_videos_timestamp", "timestamp"),
    )
class OnlineTime(Base):
    """Online-duration record model (table ``online_time``)."""

    __tablename__ = "online_time"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # The default must be a callable evaluated per INSERT. Passing
    # ``str(datetime.datetime.now)`` would stringify the method object itself once at
    # import time, so every row would get the same meaningless "<built-in method now ...>" text.
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
    duration = Column(Integer, nullable=False)
    start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
    end_timestamp = Column(DateTime, nullable=False, index=True)

    __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
class PersonInfo(Base):
    """Person information model (table ``person_info``)."""

    __tablename__ = "person_info"

    id = Column(Integer, primary_key=True, autoincrement=True)
    person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
    person_name = Column(Text, nullable=True)
    name_reason = Column(Text, nullable=True)
    platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    nickname = Column(Text, nullable=True)
    impression = Column(Text, nullable=True)
    short_impression = Column(Text, nullable=True)
    points = Column(Text, nullable=True)
    forgotten_points = Column(Text, nullable=True)
    info_list = Column(Text, nullable=True)
    know_times = Column(Float, nullable=True)
    know_since = Column(Float, nullable=True)
    last_know = Column(Float, nullable=True)
    attitude = Column(Integer, nullable=True, default=50)

    __table_args__ = (
        Index("idx_personinfo_person_id", "person_id"),
        Index("idx_personinfo_user_id", "user_id"),
    )
class BotPersonalityInterests(Base):
    """Bot personality interest-tag model (table ``bot_personality_interests``)."""

    __tablename__ = "bot_personality_interests"

    id = Column(Integer, primary_key=True, autoincrement=True)
    personality_id = Column(get_string_field(100), nullable=False, index=True)
    personality_description = Column(Text, nullable=False)
    interest_tags = Column(Text, nullable=False)  # list of interest tags stored as JSON
    embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
    version = Column(Integer, nullable=False, default=1)
    last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)

    __table_args__ = (
        Index("idx_botpersonality_personality_id", "personality_id"),
        Index("idx_botpersonality_version", "version"),
        Index("idx_botpersonality_last_updated", "last_updated"),
    )
class Memory(Base):
    """Memory model (table ``memory``)."""

    __tablename__ = "memory"

    id = Column(Integer, primary_key=True, autoincrement=True)
    memory_id = Column(get_string_field(64), nullable=False, index=True)
    chat_id = Column(Text, nullable=True)
    memory_text = Column(Text, nullable=True)
    keywords = Column(Text, nullable=True)
    create_time = Column(Float, nullable=True)
    last_view_time = Column(Float, nullable=True)

    __table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
class Expression(Base):
    """Expression-style model (table ``expression``).

    Declared with the typed ``Mapped[...] = mapped_column(...)`` style.
    """

    __tablename__ = "expression"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    situation: Mapped[str] = mapped_column(Text, nullable=False)
    style: Mapped[str] = mapped_column(Text, nullable=False)
    # Stored as a Float column, not an integer counter.
    count: Mapped[float] = mapped_column(Float, nullable=False)
    last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
    chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
    type: Mapped[str] = mapped_column(Text, nullable=False)
    create_date: Mapped[float | None] = mapped_column(Float, nullable=True)

    __table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
class ThinkingLog(Base):
    """Thinking log model (table ``thinking_logs``)."""

    __tablename__ = "thinking_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    trigger_text = Column(Text, nullable=True)
    response_text = Column(Text, nullable=True)
    # The ``*_json`` columns are plain Text; presumably they hold serialized JSON — verify against writers.
    trigger_info_json = Column(Text, nullable=True)
    response_info_json = Column(Text, nullable=True)
    timing_results_json = Column(Text, nullable=True)
    chat_history_json = Column(Text, nullable=True)
    chat_history_in_thinking_json = Column(Text, nullable=True)
    chat_history_after_response_json = Column(Text, nullable=True)
    heartflow_data_json = Column(Text, nullable=True)
    reasoning_data_json = Column(Text, nullable=True)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
class GraphNodes(Base):
    """Memory-graph node model (table ``graph_nodes``)."""

    __tablename__ = "graph_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
    memory_items = Column(Text, nullable=False)
    hash = Column(Text, nullable=False)
    weight = Column(Float, nullable=False, default=1.0)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (Index("idx_graphnodes_concept", "concept"),)
class GraphEdges(Base):
    """Memory-graph edge model (table ``graph_edges``)."""

    __tablename__ = "graph_edges"

    id = Column(Integer, primary_key=True, autoincrement=True)
    source = Column(get_string_field(255), nullable=False, index=True)
    target = Column(get_string_field(255), nullable=False, index=True)
    strength = Column(Integer, nullable=False)
    hash = Column(Text, nullable=False)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (
        Index("idx_graphedges_source", "source"),
        Index("idx_graphedges_target", "target"),
    )
class Schedule(Base):
    """Schedule model (table ``schedule``)."""

    __tablename__ = "schedule"

    id = Column(Integer, primary_key=True, autoincrement=True)
    date = Column(get_string_field(10), nullable=False, unique=True, index=True)  # YYYY-MM-DD format
    schedule_data = Column(Text, nullable=False)  # schedule data in JSON format
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # ``onupdate`` refreshes the value whenever the row is updated through SQLAlchemy.
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (Index("idx_schedule_date", "date"),)
class MaiZoneScheduleStatus(Base):
    """MaiZone schedule processing-status model (``maizone_schedule_status``)."""

    __tablename__ = "maizone_schedule_status"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Hour slot in "YYYY-MM-DD HH" format (hour precision); one row per slot.
    datetime_hour = Column(
        get_string_field(13),
        nullable=False,
        unique=True,
        index=True,
    )
    # Activity content for that hour.
    activity = Column(Text, nullable=False)
    # Whether the slot has been processed.
    is_processed = Column(Boolean, nullable=False, default=False)
    # When the slot was processed.
    processed_at = Column(DateTime, nullable=True)
    # Generated post content.
    story_content = Column(Text, nullable=True)
    # Whether sending succeeded.
    send_success = Column(Boolean, nullable=False, default=False)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # Refreshed on every UPDATE through ``onupdate``.
    updated_at = Column(
        DateTime,
        nullable=False,
        default=datetime.datetime.now,
        onupdate=datetime.datetime.now,
    )

    # NOTE(review): ``datetime_hour`` already declares index=True, so the first
    # index below looks redundant — confirm before changing the schema.
    __table_args__ = (
        Index("idx_maizone_datetime_hour", "datetime_hour"),
        Index("idx_maizone_is_processed", "is_processed"),
    )
class BanUser(Base):
    """Banned-user model, stored in the ``ban_users`` table.

    Written in the SQLAlchemy 2.0 typed-annotation style (``Mapped`` /
    ``mapped_column``) so that static type checkers see the real field types
    and do not report attribute assignments in business code as assignments
    to ``Column[...]``.
    """

    __tablename__ = "ban_users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    platform: Mapped[str] = mapped_column(Text, nullable=False)
    user_id: Mapped[str] = mapped_column(
        get_string_field(50),
        nullable=False,
        index=True,
    )
    violation_num: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
        default=0,
        index=True,
    )
    reason: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime,
        nullable=False,
        default=datetime.datetime.now,
    )

    # NOTE(review): ``user_id`` and ``violation_num`` already declare
    # index=True, so the first two indexes look redundant. ``platform`` is a
    # Text column; MySQL normally rejects indexes on TEXT columns without a
    # prefix length — confirm the ``platform`` indexes are created on MySQL.
    __table_args__ = (
        Index("idx_violation_num", "violation_num"),
        Index("idx_banuser_user_id", "user_id"),
        Index("idx_banuser_platform", "platform"),
        Index("idx_banuser_platform_user_id", "platform", "user_id"),
    )
class AntiInjectionStats(Base):
    """Anti-injection system statistics model (``anti_injection_stats``)."""

    __tablename__ = "anti_injection_stats"

    id = Column(Integer, primary_key=True, autoincrement=True)

    # --- counters -----------------------------------------------------------
    # Total number of processed messages.
    total_messages = Column(Integer, nullable=False, default=0)
    # Number of detected injection attacks.
    detected_injections = Column(Integer, nullable=False, default=0)
    # Number of blocked messages.
    blocked_messages = Column(Integer, nullable=False, default=0)
    # Number of shielded messages.
    shielded_messages = Column(Integer, nullable=False, default=0)
    # Error count.
    error_count = Column(Integer, nullable=False, default=0)

    # --- timings ------------------------------------------------------------
    # Total processing time.
    processing_time_total = Column(Float, nullable=False, default=0.0)
    # Accumulated total processing time.
    total_process_time = Column(Float, nullable=False, default=0.0)
    # Processing time of the most recent run.
    last_process_time = Column(Float, nullable=False, default=0.0)

    # --- timestamps ---------------------------------------------------------
    # When statistics collection started.
    start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # When the record was created.
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # When the record was last updated; refreshed through ``onupdate``.
    updated_at = Column(
        DateTime,
        nullable=False,
        default=datetime.datetime.now,
        onupdate=datetime.datetime.now,
    )

    __table_args__ = (
        Index("idx_anti_injection_stats_created_at", "created_at"),
        Index("idx_anti_injection_stats_updated_at", "updated_at"),
    )
class CacheEntries(Base):
    """Tool cache entry model, stored in the ``cache_entries`` table."""

    __tablename__ = "cache_entries"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Cache key; contains the tool name, the arguments and a code hash.
    cache_key = Column(
        get_string_field(500),
        nullable=False,
        unique=True,
        index=True,
    )
    # Cached data in JSON format.
    cache_value = Column(Text, nullable=False)
    # Expiry timestamp.
    expires_at = Column(Float, nullable=False, index=True)
    # Name of the tool that produced the entry.
    tool_name = Column(
        get_string_field(100),
        nullable=False,
        index=True,
    )
    # Creation timestamp; the default is evaluated at insert time.
    created_at = Column(Float, nullable=False, default=lambda: time.time())
    # Last-access timestamp; the default is evaluated at insert time.
    last_accessed = Column(Float, nullable=False, default=lambda: time.time())
    # Number of times the entry has been accessed.
    access_count = Column(Integer, nullable=False, default=0)

    # NOTE(review): ``cache_key``, ``expires_at`` and ``tool_name`` already
    # declare index=True, so the first three indexes look redundant — confirm
    # before changing the schema.
    __table_args__ = (
        Index("idx_cache_entries_key", "cache_key"),
        Index("idx_cache_entries_expires_at", "expires_at"),
        Index("idx_cache_entries_tool_name", "tool_name"),
        Index("idx_cache_entries_created_at", "created_at"),
    )
class MonthlyPlan(Base):
    """Monthly plan model, stored in the ``monthly_plans`` table."""

    __tablename__ = "monthly_plans"

    id = Column(Integer, primary_key=True, autoincrement=True)
    plan_text = Column(Text, nullable=False)
    # Month the plan belongs to, formatted as "YYYY-MM".
    target_month = Column(String(7), nullable=False, index=True)
    # One of 'active', 'completed', 'archived'.
    status = Column(
        get_string_field(20),
        nullable=False,
        default="active",
        index=True,
    )
    usage_count = Column(Integer, nullable=False, default=0)
    # Date of last use, formatted as "YYYY-MM-DD".
    last_used_date = Column(String(10), nullable=True, index=True)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    # Deprecated: kept only for compatibility with existing data.
    is_deleted = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index("idx_monthlyplan_target_month_status", "target_month", "status"),
        Index("idx_monthlyplan_last_used_date", "last_used_date"),
        Index("idx_monthlyplan_usage_count", "usage_count"),
        # Legacy index retained for compatibility.
        Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
    )
# Database engine and session management.
# Both globals start as None; initialize_database() assigns them on its first
# call and returns the cached objects on later calls.
_engine = None
_SessionLocal = None
def get_database_url():
    """Build the async SQLAlchemy connection URL from the global database config.

    When ``database_type`` is ``"mysql"`` a ``mysql+aiomysql://`` URL is
    produced, using a Unix socket if ``mysql_unix_socket`` is set and a TCP
    host/port otherwise. Any other ``database_type`` is treated as SQLite: a
    relative ``sqlite_path`` is resolved against the directory three levels
    above this file, the parent directory of the database file is created if
    missing, and a ``sqlite+aiosqlite:///`` URL is returned.

    Returns:
        The connection URL as a string.
    """
    from src.config.config import global_config

    db_config = global_config.database

    if db_config.database_type != "mysql":
        # SQLite branch.
        if os.path.isabs(db_config.sqlite_path):
            db_file = db_config.sqlite_path
        else:
            # Relative paths are resolved against the project root.
            project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_file = os.path.join(project_root, db_config.sqlite_path)

        # Make sure the directory that will hold the database file exists.
        os.makedirs(os.path.dirname(db_file), exist_ok=True)
        return f"sqlite+aiosqlite:///{db_file}"

    # MySQL branch: URL-encode the credentials so special characters survive.
    from urllib.parse import quote_plus

    user_part = quote_plus(db_config.mysql_user)
    password_part = quote_plus(db_config.mysql_password)
    credentials = f"mysql+aiomysql://{user_part}:{password_part}"

    if db_config.mysql_unix_socket:
        # Unix-socket connection: no host, socket path goes in the query string.
        socket_part = quote_plus(db_config.mysql_unix_socket)
        location = f"@/{db_config.mysql_database}"
        query = f"?unix_socket={socket_part}&charset={db_config.mysql_charset}"
        return credentials + location + query

    # Standard TCP connection.
    location = f"@{db_config.mysql_host}:{db_config.mysql_port}/{db_config.mysql_database}"
    query = f"?charset={db_config.mysql_charset}"
    return credentials + location + query
async def initialize_database():
    """Create the async engine and session factory once and cache them.

    On the first call this builds the engine from ``get_database_url()`` and
    the global database config, creates the session factory, runs
    ``check_and_migrate_database()`` and, for SQLite, enables WAL mode. The
    module globals are assigned before the migration is awaited, so any call
    made while it is still running takes the early-return path.

    Returns:
        A ``(engine, session_factory)`` tuple.
    """
    global _engine, _SessionLocal

    # Already initialised: hand back the cached objects.
    if _engine is not None:
        return _engine, _SessionLocal

    database_url = get_database_url()

    from src.config.config import global_config

    db_config = global_config.database

    if db_config.database_type == "mysql":
        # MySQL: pool settings for the async engine's default pool.
        backend_options = {
            "pool_size": db_config.connection_pool_size,
            "max_overflow": db_config.connection_pool_size * 2,
            "pool_timeout": db_config.connection_timeout,
            "pool_recycle": 3600,  # recycle connections after one hour
            "pool_pre_ping": True,  # ping before handing out a connection
            "connect_args": {
                "autocommit": db_config.mysql_autocommit,
                "charset": db_config.mysql_charset,
                "connect_timeout": db_config.connection_timeout,
            },
        }
    else:
        # SQLite: no pool parameters are passed (the original note says
        # aiosqlite does not support them); only driver connect arguments.
        backend_options = {
            "connect_args": {
                "check_same_thread": False,
                "timeout": 60,  # longer lock wait
            },
        }

    engine_kwargs: dict[str, Any] = {
        "echo": False,  # SQL logging off for production
        "future": True,
        **backend_options,
    }

    _engine = create_async_engine(database_url, **engine_kwargs)
    _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)

    # The migration routine handles both table creation and column additions.
    from src.common.database.db_migration import check_and_migrate_database

    await check_and_migrate_database()

    # For SQLite, switch to WAL mode for better concurrency.
    if db_config.database_type == "sqlite":
        await enable_sqlite_wal_mode(_engine)

    logger.info(f"SQLAlchemy异步数据库初始化成功: {db_config.database_type}")
    return _engine, _SessionLocal
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager that yields a database session.

    The database is initialised on demand through ``initialize_database()``.
    The session is obtained from the connection pool manager returned by
    ``get_connection_pool_manager()``. For SQLite, ``busy_timeout`` and
    ``foreign_keys`` PRAGMAs are applied on a best-effort basis before the
    session is handed to the caller.

    Yields:
        An ``AsyncSession``. ``None`` is never yielded.

    Raises:
        RuntimeError: If initialisation returned no session factory.
        Exception: Any error raised by ``initialize_database()`` is logged and
            re-raised unchanged, so callers must handle the failure themselves.
    """
    try:
        _, session_factory = await initialize_database()
        if not session_factory:
            raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
    except Exception as e:
        logger.error(f"数据库初始化失败,无法创建会话: {e}")
        raise

    # Obtain the session through the connection pool manager.
    pool_manager = get_connection_pool_manager()
    async with pool_manager.get_session(session_factory) as session:
        from src.config.config import global_config

        if global_config.database.database_type == "sqlite":
            # Best effort: a failure here is only logged at debug level and
            # the session is still yielded.
            try:
                await session.execute(text("PRAGMA busy_timeout = 60000"))
                await session.execute(text("PRAGMA foreign_keys = ON"))
            except Exception as e:
                logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")

        yield session
async def get_engine():
    """Return the async database engine, initialising the database if needed."""
    engine_and_factory = await initialize_database()
    # initialize_database() returns (engine, session_factory); keep the engine.
    return engine_and_factory[0]
class PermissionNodes(Base):
    """Permission node model, stored in the ``permission_nodes`` table."""

    __tablename__ = "permission_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Permission node name; unique per table.
    node_name = Column(get_string_field(255), nullable=False, unique=True, index=True)
    # Human-readable description of the permission.
    description = Column(Text, nullable=False)
    # Plugin that owns the node.
    plugin_name = Column(get_string_field(100), nullable=False, index=True)
    # Whether the permission is granted by default.
    default_granted = Column(Boolean, default=False, nullable=False)
    # Creation time as a naive UTC datetime. ``datetime.utcnow`` is deprecated
    # since Python 3.12, so the same value is built from an aware "now" with
    # the tzinfo stripped.
    created_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
        nullable=False,
    )

    # NOTE(review): both columns already declare index=True, so these explicit
    # indexes look redundant — confirm before changing the schema.
    __table_args__ = (
        Index("idx_permission_plugin", "plugin_name"),
        Index("idx_permission_node", "node_name"),
    )
class UserPermissions(Base):
    """User permission model, stored in the ``user_permissions`` table."""

    __tablename__ = "user_permissions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # Platform type.
    platform = Column(get_string_field(50), nullable=False, index=True)
    # User ID.
    user_id = Column(get_string_field(100), nullable=False, index=True)
    # Name of the permission node this row refers to.
    permission_node = Column(get_string_field(255), nullable=False, index=True)
    # Whether the permission is granted.
    granted = Column(Boolean, default=True, nullable=False)
    # Grant time as a naive UTC datetime. ``datetime.utcnow`` is deprecated
    # since Python 3.12, so the same value is built from an aware "now" with
    # the tzinfo stripped.
    granted_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
    # Information about who granted the permission.
    granted_by = Column(get_string_field(100), nullable=True)

    __table_args__ = (
        Index("idx_user_platform_id", "platform", "user_id"),
        Index("idx_user_permission", "platform", "user_id", "permission_node"),
        Index("idx_permission_granted", "permission_node", "granted"),
    )
class UserRelationships(Base):
    """User relationship model: relationship data between a user and the bot."""

    __tablename__ = "user_relationships"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # User ID; unique, so one relationship row per user.
    user_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
    # User name.
    user_name = Column(get_string_field(100), nullable=True)
    # Free-text description of the relationship impression.
    relationship_text = Column(Text, nullable=True)
    # Relationship score (0-1 per the original comment).
    relationship_score = Column(Float, nullable=False, default=0.3)
    # Last-update time from ``time.time``, evaluated at insert time.
    last_updated = Column(Float, nullable=False, default=time.time)
    # Creation time as a naive UTC datetime. ``datetime.utcnow`` is deprecated
    # since Python 3.12, so the same value is built from an aware "now" with
    # the tzinfo stripped.
    created_at = Column(
        DateTime,
        default=lambda: datetime.datetime.now(datetime.timezone.utc).replace(tzinfo=None),
        nullable=False,
    )

    # NOTE(review): ``user_id`` already declares index=True, so the first index
    # looks redundant — confirm before changing the schema.
    __table_args__ = (
        Index("idx_user_relationship_id", "user_id"),
        Index("idx_relationship_score", "relationship_score"),
        Index("idx_relationship_updated", "last_updated"),
    )