refactor: 清理项目结构并修复类型注解问题

修复 SQLAlchemy 模型的类型注解,使用 Mapped 类型避免类型检查器错误
- 修正异步数据库操作中缺少 await 的问题
- 优化反注入统计系统的数值字段处理逻辑
- 添加缺失的导入语句修复模块依赖问题
This commit is contained in:
雅诺狐
2025-10-07 11:35:12 +08:00
parent 167e4d2520
commit 875ee4813c
19 changed files with 1466 additions and 3997 deletions

View File

@@ -0,0 +1,220 @@
"""批量将经典 SQLAlchemy 模型字段写法
field = Column(Integer, nullable=False, default=0)
转换为 2.0 推荐的带类型注解写法:
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
脚本特点:
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
3. 已经是 Mapped 写法的行会跳过。
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
5. nullable=True 时自动添加 "| None"
6. 保留 Column(...) 内的原始参数顺序与内容。
7. 生成 .bak 备份文件,确保可回滚。
8. 支持 --dry-run 查看差异,不写回文件。
局限/注意:
- 简单基于正则/括号计数,不解析完整 AST非常规写法(例如变量中构造 Column 再赋值)不会处理。
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
- 不自动添加 from __future__ import annotations如需 Python 3.10 以下更先进类型表达式,请自行处理。
使用方式: (在项目根目录执行)
python scripts/convert_sqlalchemy_models.py \
--file src/common/database/sqlalchemy_models.py --dry-run
确认无误后去掉 --dry-run 真实写入。
"""
from __future__ import annotations
import argparse
import re
import shutil
from pathlib import Path
from typing import Any
TYPE_MAP = {
"Integer": "int",
"Float": "float",
"Boolean": "bool",
"Text": "str",
"String": "str",
"DateTime": "datetime.datetime",
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
"get_string_field": "str",
}
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
"""检测从 start_index 开始的 Column(...) 语句跨越的行范围 (包含结束行)。
使用括号计数法处理多行。
返回 (start, end) 行号 (包含 end)。"""
line = lines[start_index]
if "Column(" not in line:
return None
open_parens = line.count("(") - line.count(")")
i = start_index
while open_parens > 0 and i + 1 < len(lines):
i += 1
l2 = lines[i]
open_parens += l2.count("(") - l2.count(")")
return (start_index, i)
def extract_column_body(block_lines: list[str]) -> str:
"""提取 Column(...) 内部参数文本 (去掉首尾 Column( 和 最后一个 ) )。"""
joined = "\n".join(block_lines)
# 找到第一次出现 Column(
start_pos = joined.find("Column(")
if start_pos == -1:
return ""
inner = joined[start_pos + len("Column(") :]
# 去掉最后一个 ) —— 简单方式: 找到最后一个 ) 并截断
last_paren = inner.rfind(")")
if last_paren != -1:
inner = inner[:last_paren]
return inner.strip()
def guess_python_type(column_body: str) -> str:
# 简单取第一个类型标识符 (去掉前导装饰/空格)
# 可能形式: Integer, Text, get_string_field(50), DateTime, Boolean
# 利用正则抓取第一个标识符
m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
if not m:
return "Any"
type_token = m.group(1)
py_type = TYPE_MAP.get(type_token, "Any")
# nullable 检测
if "nullable=True" in column_body or "nullable = True" in column_body:
# 避免重复 Optional
if py_type != "Any" and not py_type.endswith(" | None"):
py_type = f"{py_type} | None"
elif py_type == "Any":
py_type = "Any | None"
return py_type
def convert_block(block_lines: list[str]) -> list[str]:
first_line = block_lines[0]
m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
if not m_name:
return block_lines
indent = m_name.group("indent")
name = m_name.group("name")
body = extract_column_body(block_lines)
py_type = guess_python_type(body)
# 构造新的多行 mapped_column 写法
# 保留内部参数的换行缩进: 重新缩进为 indent + 4 空格 (延续原风格: 在 indent 基础上再加 4 空格)
inner_lines = body.split("\n")
if len(inner_lines) == 1:
new_line = f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n"
return [new_line]
else:
# 多行情况
ind2 = indent + " "
rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(",]
for il in inner_lines:
if il.strip():
rebuilt.append(f"{ind2}{il.rstrip()}")
rebuilt.append(f"{indent})\n")
return [l + ("\n" if not l.endswith("\n") else "") for l in rebuilt]
def ensure_imports(content: str) -> str:
if "Mapped," in content or "Mapped[" in content:
# 已经可能存在导入
if "from sqlalchemy.orm import Mapped, mapped_column" not in content:
# 简单插到第一个 import sqlalchemy 之后
lines = content.splitlines()
for i, line in enumerate(lines):
if "sqlalchemy" in line and line.startswith("from sqlalchemy"):
lines.insert(i + 1, "from sqlalchemy.orm import Mapped, mapped_column")
return "\n".join(lines)
return content
def process_file(path: Path) -> tuple[str, str]:
original = path.read_text(encoding="utf-8")
lines = original.splitlines(keepends=True)
i = 0
out: list[str] = []
changed = 0
while i < len(lines):
line = lines[i]
# 跳过已是 Mapped 风格
if ALREADY_MAPPED_RE.match(line):
out.append(line)
i += 1
continue
if "= Column(" in line and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", line):
start, end = detect_column_block(lines, i) or (i, i)
block = lines[start : end + 1]
converted = convert_block(block)
out.extend(converted)
i = end + 1
# 如果转换结果与原始不同,计数
if "".join(converted) != "".join(block):
changed += 1
else:
out.append(line)
i += 1
new_content = "".join(out)
new_content = ensure_imports(new_content)
# 在文件末尾或头部预留统计信息打印(不写入文件,只返回)
return original, new_content if changed else original
def main():
parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
parser.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
parser.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
args = parser.parse_args()
target = Path(args.file)
if not target.exists():
raise SystemExit(f"文件不存在: {target}")
original, new_content = process_file(target)
if original == new_content:
print("[INFO] 没有需要转换的内容或转换后无差异。")
return
# 简单差异输出 (行对比)
if args.dry_run or not args.write:
print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
import difflib
diff = difflib.unified_diff(
original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
)
count = 0
for d in diff:
print(d)
count += 1
if count == 0:
print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
if not args.write:
print("\n未写回。若确认无误,添加 --write 执行替换。")
return
backup = target.with_suffix(target.suffix + ".bak")
shutil.copyfile(target, backup)
target.write_text(new_content, encoding="utf-8")
print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
if __name__ == "__main__": # pragma: no cover
main()

View File

@@ -1,284 +0,0 @@
import os
import sys
import time
from datetime import datetime
# Add project root to Python path
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import Messages, ChatStreams # noqa
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def format_timestamp(timestamp: float) -> str:
"""Format timestamp to readable date string"""
try:
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
except (ValueError, OSError):
return "未知时间"
def calculate_interest_value_distribution(messages) -> dict[str, int]:
"""Calculate distribution of interest_value"""
distribution = {
"0.000-0.010": 0,
"0.010-0.050": 0,
"0.050-0.100": 0,
"0.100-0.500": 0,
"0.500-1.000": 0,
"1.000-2.000": 0,
"2.000-5.000": 0,
"5.000-10.000": 0,
"10.000+": 0,
}
for msg in messages:
if msg.interest_value is None or msg.interest_value == 0.0:
continue
value = float(msg.interest_value)
if value < 0.010:
distribution["0.000-0.010"] += 1
elif value < 0.050:
distribution["0.010-0.050"] += 1
elif value < 0.100:
distribution["0.050-0.100"] += 1
elif value < 0.500:
distribution["0.100-0.500"] += 1
elif value < 1.000:
distribution["0.500-1.000"] += 1
elif value < 2.000:
distribution["1.000-2.000"] += 1
elif value < 5.000:
distribution["2.000-5.000"] += 1
elif value < 10.000:
distribution["5.000-10.000"] += 1
else:
distribution["10.000+"] += 1
return distribution
def get_interest_value_stats(messages) -> dict[str, float]:
"""Calculate basic statistics for interest_value"""
values = [
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
]
if not values:
return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
values.sort()
count = len(values)
return {
"count": count,
"min": min(values),
"max": max(values),
"avg": sum(values) / count,
"median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
}
def get_available_chats() -> list[tuple[str, str, int]]:
"""Get all available chats with message counts"""
try:
# 获取所有有消息的chat_id
chat_counts = {}
for msg in Messages.select(Messages.chat_id).distinct():
chat_id = msg.chat_id
count = (
Messages.select()
.where(
(Messages.chat_id == chat_id)
& (Messages.interest_value.is_null(False))
& (Messages.interest_value != 0.0)
)
.count()
)
if count > 0:
chat_counts[chat_id] = count
# 获取聊天名称
result = []
for chat_id, count in chat_counts.items():
chat_name = get_chat_name(chat_id)
result.append((chat_id, chat_name, count))
# 按消息数量排序
result.sort(key=lambda x: x[2], reverse=True)
return result
except Exception as e:
print(f"获取聊天列表失败: {e}")
return []
def get_time_range_input() -> tuple[float | None, float | None]:
"""Get time range input from user"""
print("\n时间范围选择:")
print("1. 最近1天")
print("2. 最近3天")
print("3. 最近7天")
print("4. 最近30天")
print("5. 自定义时间范围")
print("6. 不限制时间")
choice = input("请选择时间范围 (1-6): ").strip()
now = time.time()
if choice == "1":
return now - 24 * 3600, now
elif choice == "2":
return now - 3 * 24 * 3600, now
elif choice == "3":
return now - 7 * 24 * 3600, now
elif choice == "4":
return now - 30 * 24 * 3600, now
elif choice == "5":
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
start_str = input().strip()
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
end_str = input().strip()
try:
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
return start_time, end_time
except ValueError:
print("时间格式错误,将不限制时间范围")
return None, None
else:
return None, None
def analyze_interest_values(
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
) -> None:
"""Analyze interest values with optional filters"""
# 构建查询条件
query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
if chat_id:
query = query.where(Messages.chat_id == chat_id)
if start_time:
query = query.where(Messages.time >= start_time)
if end_time:
query = query.where(Messages.time <= end_time)
messages = list(query)
if not messages:
print("没有找到符合条件的消息")
return
# 计算统计信息
distribution = calculate_interest_value_distribution(messages)
stats = get_interest_value_stats(messages)
# 显示结果
print("\n=== Interest Value 分析结果 ===")
if chat_id:
print(f"聊天: {get_chat_name(chat_id)}")
else:
print("聊天: 全部聊天")
if start_time and end_time:
print(f"时间范围: {format_timestamp(start_time)}{format_timestamp(end_time)}")
elif start_time:
print(f"时间范围: {format_timestamp(start_time)} 之后")
elif end_time:
print(f"时间范围: {format_timestamp(end_time)} 之前")
else:
print("时间范围: 不限制")
print("\n基本统计:")
print(f"有效消息数量: {stats['count']} (排除null和0值)")
print(f"最小值: {stats['min']:.3f}")
print(f"最大值: {stats['max']:.3f}")
print(f"平均值: {stats['avg']:.3f}")
print(f"中位数: {stats['median']:.3f}")
print("\nInterest Value 分布:")
total = stats["count"]
for range_name, count in distribution.items():
if count > 0:
percentage = count / total * 100
print(f"{range_name}: {count} ({percentage:.2f}%)")
def interactive_menu() -> None:
"""Interactive menu for interest value analysis"""
while True:
print("\n" + "=" * 50)
print("Interest Value 分析工具")
print("=" * 50)
print("1. 分析全部聊天")
print("2. 选择特定聊天分析")
print("q. 退出")
choice = input("\n请选择分析模式 (1-2, q): ").strip()
if choice.lower() == "q":
print("再见!")
break
chat_id = None
if choice == "2":
# 显示可用的聊天列表
chats = get_available_chats()
if not chats:
print("没有找到有interest_value数据的聊天")
continue
print(f"\n可用的聊天 (共{len(chats)}个):")
for i, (_cid, name, count) in enumerate(chats, 1):
print(f"{i}. {name} ({count}条有效消息)")
try:
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
if 1 <= chat_choice <= len(chats):
chat_id = chats[chat_choice - 1][0]
else:
print("无效选择")
continue
except ValueError:
print("请输入有效数字")
continue
elif choice != "1":
print("无效选择")
continue
# 获取时间范围
start_time, end_time = get_time_range_input()
# 执行分析
analyze_interest_values(chat_id, start_time, end_time)
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

File diff suppressed because it is too large Load Diff

View File

@@ -1,239 +0,0 @@
"""
插件Manifest管理命令行工具
提供插件manifest文件的创建、验证和管理功能
"""
import argparse
import os
import sys
from pathlib import Path
import orjson
from src.common.logger import get_logger
from src.plugin_system.utils.manifest_utils import (
ManifestValidator,
)
# 添加项目根目录到Python路径
project_root = Path(__file__).parent.parent.parent.parent
sys.path.insert(0, str(project_root))
logger = get_logger("manifest_tool")
def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool:
"""创建最小化的manifest文件
Args:
plugin_dir: 插件目录
plugin_name: 插件名称
description: 插件描述
author: 插件作者
Returns:
bool: 是否创建成功
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if os.path.exists(manifest_path):
print(f"❌ Manifest文件已存在: {manifest_path}")
return False
# 创建最小化manifest
minimal_manifest = {
"manifest_version": 1,
"name": plugin_name,
"version": "1.0.0",
"description": description or f"{plugin_name}插件",
"author": {"name": author or "Unknown"},
}
try:
with open(manifest_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(minimal_manifest, option=orjson.OPT_INDENT_2).decode("utf-8"))
print(f"✅ 已创建最小化manifest文件: {manifest_path}")
return True
except Exception as e:
print(f"❌ 创建manifest文件失败: {e}")
return False
def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool:
"""创建完整的manifest模板文件
Args:
plugin_dir: 插件目录
plugin_name: 插件名称
Returns:
bool: 是否创建成功
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if os.path.exists(manifest_path):
print(f"❌ Manifest文件已存在: {manifest_path}")
return False
# 创建完整模板
complete_manifest = {
"manifest_version": 1,
"name": plugin_name,
"version": "1.0.0",
"description": f"{plugin_name}插件描述",
"author": {"name": "插件作者", "url": "https://github.com/your-username"},
"license": "MIT",
"host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
"homepage_url": "https://github.com/your-repo",
"repository_url": "https://github.com/your-repo",
"keywords": ["keyword1", "keyword2"],
"categories": ["Category1"],
"default_locale": "zh-CN",
"locales_path": "_locales",
"plugin_info": {
"is_built_in": False,
"plugin_type": "general",
"components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}],
},
}
try:
with open(manifest_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(complete_manifest, option=orjson.OPT_INDENT_2).decode("utf-8"))
print(f"✅ 已创建完整manifest模板: {manifest_path}")
print("💡 请根据实际情况修改manifest文件中的内容")
return True
except Exception as e:
print(f"❌ 创建manifest文件失败: {e}")
return False
def validate_manifest_file(plugin_dir: str) -> bool:
"""验证manifest文件
Args:
plugin_dir: 插件目录
Returns:
bool: 是否验证通过
"""
manifest_path = os.path.join(plugin_dir, "_manifest.json")
if not os.path.exists(manifest_path):
print(f"❌ 未找到manifest文件: {manifest_path}")
return False
try:
with open(manifest_path, encoding="utf-8") as f:
manifest_data = orjson.loads(f.read())
validator = ManifestValidator()
is_valid = validator.validate_manifest(manifest_data)
# 显示验证结果
print("📋 Manifest验证结果:")
print(validator.get_validation_report())
if is_valid:
print("✅ Manifest文件验证通过")
else:
print("❌ Manifest文件验证失败")
return is_valid
except orjson.JSONDecodeError as e:
print(f"❌ Manifest文件格式错误: {e}")
return False
except Exception as e:
print(f"❌ 验证过程中发生错误: {e}")
return False
def scan_plugins_without_manifest(root_dir: str) -> None:
"""扫描缺少manifest文件的插件
Args:
root_dir: 扫描的根目录
"""
print(f"🔍 扫描目录: {root_dir}")
plugins_without_manifest = []
for root, dirs, files in os.walk(root_dir):
# 跳过隐藏目录和__pycache__
dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
# 检查是否包含plugin.py文件标识为插件目录
if "plugin.py" in files:
manifest_path = os.path.join(root, "_manifest.json")
if not os.path.exists(manifest_path):
plugins_without_manifest.append(root)
if plugins_without_manifest:
print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:")
for plugin_dir in plugins_without_manifest:
plugin_name = os.path.basename(plugin_dir)
print(f" - {plugin_name}: {plugin_dir}")
print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件")
else:
print("✅ 所有插件都有manifest文件")
def main():
"""主函数"""
parser = argparse.ArgumentParser(description="插件Manifest管理工具")
subparsers = parser.add_subparsers(dest="command", help="可用命令")
# 创建最小化manifest命令
create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件")
create_minimal_parser.add_argument("plugin_dir", help="插件目录路径")
create_minimal_parser.add_argument("--name", help="插件名称")
create_minimal_parser.add_argument("--description", help="插件描述")
create_minimal_parser.add_argument("--author", help="插件作者")
# 创建完整manifest命令
create_complete_parser = subparsers.add_parser("create-complete", help="创建完整manifest模板")
create_complete_parser.add_argument("plugin_dir", help="插件目录路径")
create_complete_parser.add_argument("--name", help="插件名称")
# 验证manifest命令
validate_parser = subparsers.add_parser("validate", help="验证manifest文件")
validate_parser.add_argument("plugin_dir", help="插件目录路径")
# 扫描插件命令
scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件")
scan_parser.add_argument("root_dir", help="扫描的根目录路径")
args = parser.parse_args()
if not args.command:
parser.print_help()
return
try:
if args.command == "create-minimal":
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "")
sys.exit(0 if success else 1)
elif args.command == "create-complete":
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
success = create_complete_manifest(args.plugin_dir, plugin_name)
sys.exit(0 if success else 1)
elif args.command == "validate":
success = validate_manifest_file(args.plugin_dir)
sys.exit(0 if success else 1)
elif args.command == "scan":
scan_plugins_without_manifest(args.root_dir)
except Exception as e:
print(f"❌ 执行命令时发生错误: {e}")
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -1,922 +0,0 @@
import os
# import time
import pickle
import sys # 新增系统模块导入
from pathlib import Path
import orjson
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from peewee import Field, IntegrityError, Model
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
# Rich 进度条和显示组件
from rich.console import Console
from rich.panel import Panel
from rich.progress import (
BarColumn,
Progress,
SpinnerColumn,
TaskProgressColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
)
from rich.table import Table
# from rich.text import Text
from src.common.database.database import db
from src.common.database.sqlalchemy_models import (
ChatStreams,
Emoji,
GraphEdges,
GraphNodes,
ImageDescriptions,
Images,
Knowledges,
Messages,
PersonInfo,
ThinkingLog,
)
from src.common.logger import get_logger
logger = get_logger("mongodb_to_sqlite")
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@dataclass
class MigrationConfig:
"""迁移配置类"""
mongo_collection: str
target_model: type[Model]
field_mapping: dict[str, str]
batch_size: int = 500
enable_validation: bool = True
skip_duplicates: bool = True
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
# 数据验证相关类已移除 - 用户要求不要数据验证
@dataclass
class MigrationCheckpoint:
"""迁移断点数据"""
collection_name: str
processed_count: int
last_processed_id: Any
timestamp: datetime
batch_errors: list[dict[str, Any]] = field(default_factory=list)
@dataclass
class MigrationStats:
"""迁移统计信息"""
total_documents: int = 0
processed_count: int = 0
success_count: int = 0
error_count: int = 0
skipped_count: int = 0
duplicate_count: int = 0
validation_errors: int = 0
batch_insert_count: int = 0
errors: list[dict[str, Any]] = field(default_factory=list)
start_time: datetime | None = None
end_time: datetime | None = None
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
"""添加错误记录"""
self.errors.append(
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
)
self.error_count += 1
def add_validation_error(self, doc_id: Any, field: str, error: str):
"""添加验证错误"""
self.add_error(doc_id, f"验证失败 - {field}: {error}")
self.validation_errors += 1
class MongoToSQLiteMigrator:
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
self.mongo_uri = mongo_uri or self._build_mongo_uri()
self.mongo_client: MongoClient | None = None
self.mongo_db = None
# 迁移配置
self.migration_configs = self._initialize_migration_configs()
# 进度条控制台
self.console = Console()
# 检查点目录
self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
self.checkpoint_dir.mkdir(exist_ok=True)
# 验证规则已禁用
self.validation_rules = self._initialize_validation_rules()
def _build_mongo_uri(self) -> str:
"""构建MongoDB连接URI"""
if mongo_uri := os.getenv("MONGODB_URI"):
return mongo_uri
user = os.getenv("MONGODB_USER")
password = os.getenv("MONGODB_PASS")
host = os.getenv("MONGODB_HOST", "localhost")
port = os.getenv("MONGODB_PORT", "27017")
auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
if user and password:
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
else:
return f"mongodb://{host}:{port}/{self.database_name}"
def _initialize_migration_configs(self) -> list[MigrationConfig]:
"""初始化迁移配置"""
return [ # 表情包迁移配置
MigrationConfig(
mongo_collection="emoji",
target_model=Emoji,
field_mapping={
"full_path": "full_path",
"format": "format",
"hash": "emoji_hash",
"description": "description",
"emotion": "emotion",
"usage_count": "usage_count",
"last_used_time": "last_used_time",
# record_time字段将在转换时自动设置为当前时间
},
enable_validation=False, # 禁用数据验证
unique_fields=["full_path", "emoji_hash"],
),
# 聊天流迁移配置
MigrationConfig(
mongo_collection="chat_streams",
target_model=ChatStreams,
field_mapping={
"stream_id": "stream_id",
"create_time": "create_time",
"group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null而新的数据库不允许为null所以私聊聊天流是没法迁移的等更新吧。
"group_info.group_id": "group_id", # 同上
"group_info.group_name": "group_name", # 同上
"last_active_time": "last_active_time",
"platform": "platform",
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname",
},
enable_validation=False, # 禁用数据验证
unique_fields=["stream_id"],
),
# 消息迁移配置
MigrationConfig(
mongo_collection="messages",
target_model=Messages,
field_mapping={
"message_id": "message_id",
"time": "time",
"chat_id": "chat_id",
"chat_info.stream_id": "chat_info_stream_id",
"chat_info.platform": "chat_info_platform",
"chat_info.user_info.platform": "chat_info_user_platform",
"chat_info.user_info.user_id": "chat_info_user_id",
"chat_info.user_info.user_nickname": "chat_info_user_nickname",
"chat_info.user_info.user_cardname": "chat_info_user_cardname",
"chat_info.group_info.platform": "chat_info_group_platform",
"chat_info.group_info.group_id": "chat_info_group_id",
"chat_info.group_info.group_name": "chat_info_group_name",
"chat_info.create_time": "chat_info_create_time",
"chat_info.last_active_time": "chat_info_last_active_time",
"user_info.platform": "user_platform",
"user_info.user_id": "user_id",
"user_info.user_nickname": "user_nickname",
"user_info.user_cardname": "user_cardname",
"processed_plain_text": "processed_plain_text",
"memorized_times": "memorized_times",
},
enable_validation=False, # 禁用数据验证
unique_fields=["message_id"],
),
# 图片迁移配置
MigrationConfig(
mongo_collection="images",
target_model=Images,
field_mapping={
"hash": "emoji_hash",
"description": "description",
"path": "path",
"timestamp": "timestamp",
"type": "type",
},
unique_fields=["path"],
),
# 图片描述迁移配置
MigrationConfig(
mongo_collection="image_descriptions",
target_model=ImageDescriptions,
field_mapping={
"type": "type",
"hash": "image_description_hash",
"description": "description",
"timestamp": "timestamp",
},
unique_fields=["image_description_hash", "type"],
),
# 个人信息迁移配置
MigrationConfig(
mongo_collection="person_info",
target_model=PersonInfo,
field_mapping={
"person_id": "person_id",
"person_name": "person_name",
"name_reason": "name_reason",
"platform": "platform",
"user_id": "user_id",
"nickname": "nickname",
"relationship_value": "relationship_value",
"konw_time": "know_time",
},
unique_fields=["person_id"],
),
# 知识库迁移配置
MigrationConfig(
mongo_collection="knowledges",
target_model=Knowledges,
field_mapping={"content": "content", "embedding": "embedding"},
unique_fields=["content"], # 假设内容唯一
),
# 思考日志迁移配置
MigrationConfig(
mongo_collection="thinking_log",
target_model=ThinkingLog,
field_mapping={
"chat_id": "chat_id",
"trigger_text": "trigger_text",
"response_text": "response_text",
"trigger_info": "trigger_info_json",
"response_info": "response_info_json",
"timing_results": "timing_results_json",
"chat_history": "chat_history_json",
"chat_history_in_thinking": "chat_history_in_thinking_json",
"chat_history_after_response": "chat_history_after_response_json",
"heartflow_data": "heartflow_data_json",
"reasoning_data": "reasoning_data_json",
},
unique_fields=["chat_id", "trigger_text"],
),
# 图节点迁移配置
MigrationConfig(
mongo_collection="graph_data.nodes",
target_model=GraphNodes,
field_mapping={
"concept": "concept",
"memory_items": "memory_items",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified",
},
unique_fields=["concept"],
),
# 图边迁移配置
MigrationConfig(
mongo_collection="graph_data.edges",
target_model=GraphEdges,
field_mapping={
"source": "source",
"target": "target",
"strength": "strength",
"hash": "hash",
"created_time": "created_time",
"last_modified": "last_modified",
},
unique_fields=["source", "target"], # 组合唯一性
),
]
def _initialize_validation_rules(self) -> dict[str, Any]:
"""数据验证已禁用 - 返回空字典"""
return {}
def connect_mongodb(self) -> bool:
"""连接到MongoDB"""
try:
self.mongo_client = MongoClient(
self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
)
# 测试连接
self.mongo_client.admin.command("ping")
self.mongo_db = self.mongo_client[self.database_name]
logger.info(f"成功连接到MongoDB: {self.database_name}")
return True
except ConnectionFailure as e:
logger.error(f"MongoDB连接失败: {e}")
return False
except Exception as e:
logger.error(f"MongoDB连接异常: {e}")
return False
def disconnect_mongodb(self):
"""断开MongoDB连接"""
if self.mongo_client:
self.mongo_client.close()
logger.info("MongoDB连接已关闭")
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
"""获取嵌套字段的值"""
if "." not in field_path:
return document.get(field_path)
parts = field_path.split(".")
value = document
for part in parts:
if isinstance(value, dict):
value = value.get(part)
else:
return None
if value is None:
break
return value
def _convert_field_value(self, value: Any, target_field: Field) -> Any:
"""根据目标字段类型转换值"""
if value is None:
return None
field_type = target_field.__class__.__name__
try:
if target_field.name == "record_time" and field_type == "DateTimeField":
return datetime.now()
if field_type in ["CharField", "TextField"]:
if isinstance(value, list | dict):
return orjson.dumps(value, ensure_ascii=False)
return str(value) if value is not None else ""
elif field_type == "IntegerField":
if isinstance(value, str):
# 处理字符串数字
clean_value = value.strip()
if clean_value.replace(".", "").replace("-", "").isdigit():
return int(float(clean_value))
return 0
return int(value) if value is not None else 0
elif field_type in ["FloatField", "DoubleField"]:
return float(value) if value is not None else 0.0
elif field_type == "BooleanField":
if isinstance(value, str):
return value.lower() in ("true", "1", "yes", "on")
return bool(value)
elif field_type == "DateTimeField":
if isinstance(value, int | float):
return datetime.fromtimestamp(value)
elif isinstance(value, str):
try:
# 尝试解析ISO格式日期
return datetime.fromisoformat(value.replace("Z", "+00:00"))
except ValueError:
try:
# 尝试解析时间戳字符串
return datetime.fromtimestamp(float(value))
except ValueError:
return datetime.now()
return datetime.now()
return value
except (ValueError, TypeError) as e:
logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
return self._get_default_value_for_field(target_field)
def _get_default_value_for_field(self, field: Field) -> Any:
"""获取字段的默认值"""
field_type = field.__class__.__name__
if hasattr(field, "default") and field.default is not None:
return field.default
if field.null:
return None
# 根据字段类型返回默认值
if field_type in ["CharField", "TextField"]:
return ""
elif field_type == "IntegerField":
return 0
elif field_type in ["FloatField", "DoubleField"]:
return 0.0
elif field_type == "BooleanField":
return False
elif field_type == "DateTimeField":
return datetime.now()
return None
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
"""数据验证已禁用 - 始终返回True"""
return True
def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
"""保存迁移断点"""
checkpoint = MigrationCheckpoint(
collection_name=collection_name,
processed_count=processed_count,
last_processed_id=last_id,
timestamp=datetime.now(),
)
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
try:
with open(checkpoint_file, "wb") as f:
pickle.dump(checkpoint, f)
except Exception as e:
logger.warning(f"保存断点失败: {e}")
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
"""加载迁移断点"""
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
if not checkpoint_file.exists():
return None
try:
with open(checkpoint_file, "rb") as f:
return pickle.load(f)
except Exception as e:
logger.warning(f"加载断点失败: {e}")
return None
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
"""批量插入数据"""
if not data_list:
return 0
success_count = 0
try:
with db.atomic():
# 分批插入避免SQL语句过长
batch_size = 100
for i in range(0, len(data_list), batch_size):
batch = data_list[i : i + batch_size]
model.insert_many(batch).execute()
success_count += len(batch)
except Exception as e:
logger.error(f"批量插入失败: {e}")
# 如果批量插入失败,尝试逐个插入
for data in data_list:
try:
model.create(**data)
success_count += 1
except Exception:
pass # 忽略单个插入失败
return success_count
def _check_duplicate_by_unique_fields(
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
) -> bool:
"""根据唯一字段检查重复"""
if not unique_fields:
return False
try:
query = model.select()
for field_name in unique_fields:
if field_name in data and data[field_name] is not None:
field_obj = getattr(model, field_name)
query = query.where(field_obj == data[field_name])
return query.exists()
except Exception as e:
logger.debug(f"重复检查失败: {e}")
return False
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
"""使用ORM创建模型实例"""
try:
# 过滤掉不存在的字段
valid_data = {}
for field_name, value in data.items():
if hasattr(model, field_name):
valid_data[field_name] = value
else:
logger.debug(f"跳过未知字段: {field_name}")
# 创建实例
instance = model.create(**valid_data)
return instance
except IntegrityError as e:
# 处理唯一约束冲突等完整性错误
logger.debug(f"完整性约束冲突: {e}")
return None
except Exception as e:
logger.error(f"创建模型实例失败: {e}")
return None
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
"""迁移单个集合 - 使用优化的批量插入和进度条"""
stats = MigrationStats()
stats.start_time = datetime.now()
# 检查是否有断点
checkpoint = self._load_checkpoint(config.mongo_collection)
start_from_id = checkpoint.last_processed_id if checkpoint else None
if checkpoint:
stats.processed_count = checkpoint.processed_count
logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")
logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
try:
# 获取MongoDB集合
mongo_collection = self.mongo_db[config.mongo_collection]
# 构建查询条件(用于断点恢复)
query = {}
if start_from_id:
query = {"_id": {"$gt": start_from_id}}
stats.total_documents = mongo_collection.count_documents(query)
if stats.total_documents == 0:
logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
return stats
logger.info(f"待迁移文档数量: {stats.total_documents}")
# 创建Rich进度条
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
BarColumn(),
TaskProgressColumn(),
TimeElapsedColumn(),
TimeRemainingColumn(),
console=self.console,
refresh_per_second=10,
) as progress:
task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
# 批量处理数据
batch_data = []
batch_count = 0
last_processed_id = None
for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
try:
doc_id = mongo_doc.get("_id", "unknown")
last_processed_id = doc_id
# 构建目标数据
target_data = {}
for mongo_field, sqlite_field in config.field_mapping.items():
value = self._get_nested_value(mongo_doc, mongo_field)
# 获取目标字段对象并转换类型
if hasattr(config.target_model, sqlite_field):
field_obj = getattr(config.target_model, sqlite_field)
converted_value = self._convert_field_value(value, field_obj)
target_data[sqlite_field] = converted_value
# 数据验证已禁用
# if config.enable_validation:
# if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
# stats.skipped_count += 1
# continue
# 重复检查
if config.skip_duplicates and self._check_duplicate_by_unique_fields(
config.target_model, target_data, config.unique_fields
):
stats.duplicate_count += 1
stats.skipped_count += 1
logger.debug(f"跳过重复记录: {doc_id}")
continue
# 添加到批量数据
batch_data.append(target_data)
stats.processed_count += 1
# 执行批量插入
if len(batch_data) >= config.batch_size:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
# 保存断点
self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
batch_data.clear()
batch_count += 1
# 更新进度条
progress.update(task, advance=config.batch_size)
except Exception as e:
doc_id = mongo_doc.get("_id", "unknown")
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
# 处理剩余的批量数据
if batch_data:
success_count = self._batch_insert(config.target_model, batch_data)
stats.success_count += success_count
stats.batch_insert_count += 1
progress.update(task, advance=len(batch_data))
# 完成进度条
progress.update(task, completed=stats.total_documents)
stats.end_time = datetime.now()
duration = stats.end_time - stats.start_time
logger.info(
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
)
# 清理断点文件
checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
if checkpoint_file.exists():
checkpoint_file.unlink()
except Exception as e:
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
stats.add_error("collection_error", str(e))
return stats
def migrate_all(self) -> dict[str, MigrationStats]:
"""执行所有迁移任务"""
logger.info("开始执行数据库迁移...")
if not self.connect_mongodb():
logger.error("无法连接到MongoDB迁移终止")
return {}
all_stats = {}
try:
# 创建总体进度表格
total_collections = len(self.migration_configs)
self.console.print(
Panel(
f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
f"[yellow]总集合数: {total_collections}[/yellow]",
title="迁移开始",
expand=False,
)
)
for idx, config in enumerate(self.migration_configs, 1):
self.console.print(
f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
)
stats = self.migrate_collection(config)
all_stats[config.mongo_collection] = stats
# 显示单个集合的快速统计
if stats.processed_count > 0:
success_rate = stats.success_count / stats.processed_count * 100
if success_rate >= 95:
status_emoji = ""
status_color = "bright_green"
elif success_rate >= 80:
status_emoji = "⚠️"
status_color = "yellow"
else:
status_emoji = ""
status_color = "red"
self.console.print(
f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
)
# 错误率检查
if stats.processed_count > 0:
error_rate = stats.error_count / stats.processed_count
if error_rate > 0.1: # 错误率超过10%
self.console.print(
f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
f"({stats.error_count}/{stats.processed_count})[/red]"
)
finally:
self.disconnect_mongodb()
self._print_migration_summary(all_stats)
return all_stats
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
"""使用Rich打印美观的迁移汇总信息"""
# 计算总体统计
total_processed = sum(stats.processed_count for stats in all_stats.values())
total_success = sum(stats.success_count for stats in all_stats.values())
total_errors = sum(stats.error_count for stats in all_stats.values())
total_skipped = sum(stats.skipped_count for stats in all_stats.values())
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
# 计算总耗时
total_duration_seconds = 0
for stats in all_stats.values():
if stats.start_time and stats.end_time:
duration = stats.end_time - stats.start_time
total_duration_seconds += duration.total_seconds()
# 创建详细统计表格
table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
table.add_column("集合名称", style="cyan", width=20)
table.add_column("文档总数", justify="right", style="blue")
table.add_column("处理数量", justify="right", style="green")
table.add_column("成功数量", justify="right", style="green")
table.add_column("错误数量", justify="right", style="red")
table.add_column("跳过数量", justify="right", style="yellow")
table.add_column("重复数量", justify="right", style="bright_yellow")
table.add_column("验证错误", justify="right", style="red")
table.add_column("批次数", justify="right", style="purple")
table.add_column("成功率", justify="right", style="bright_green")
table.add_column("耗时(秒)", justify="right", style="blue")
for collection_name, stats in all_stats.items():
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
duration = 0
if stats.start_time and stats.end_time:
duration = (stats.end_time - stats.start_time).total_seconds()
# 根据成功率设置颜色
if success_rate >= 95:
success_rate_style = "[bright_green]"
elif success_rate >= 80:
success_rate_style = "[yellow]"
else:
success_rate_style = "[red]"
table.add_row(
collection_name,
str(stats.total_documents),
str(stats.processed_count),
str(stats.success_count),
f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
str(stats.batch_insert_count),
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
f"{duration:.2f}",
)
# 添加总计行
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
if total_success_rate >= 95:
total_rate_style = "[bright_green]"
elif total_success_rate >= 80:
total_rate_style = "[yellow]"
else:
total_rate_style = "[red]"
table.add_section()
table.add_row(
"[bold]总计[/bold]",
f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]",
f"[bold]{total_processed}[/bold]",
f"[bold]{total_success}[/bold]",
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
if total_duplicates > 0
else "[bold]0[/bold]",
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
f"[bold]{total_batch_inserts}[/bold]",
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
f"[bold]{total_duration_seconds:.2f}[/bold]",
)
self.console.print(table)
# 创建状态面板
status_items = []
if total_errors > 0:
status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
if total_validation_errors > 0:
status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
if total_duplicates > 0:
status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
if total_success_rate >= 95:
status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
elif total_success_rate >= 80:
status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
else:
status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
if status_items:
status_panel = Panel(
"\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
)
self.console.print(status_panel)
# 性能统计面板
avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
performance_info = (
f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f}\n"
f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
)
performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
self.console.print(performance_panel)
def add_migration_config(self, config: MigrationConfig):
"""添加新的迁移配置"""
self.migration_configs.append(config)
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
"""迁移单个指定的集合"""
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
if not config:
logger.error(f"未找到集合 {collection_name} 的迁移配置")
return None
if not self.connect_mongodb():
logger.error("无法连接到MongoDB")
return None
try:
stats = self.migrate_collection(config)
self._print_migration_summary({collection_name: stats})
return stats
finally:
self.disconnect_mongodb()
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
"""导出错误报告"""
error_report = {
"timestamp": datetime.now().isoformat(),
"summary": {
collection: {
"total": stats.total_documents,
"processed": stats.processed_count,
"success": stats.success_count,
"errors": stats.error_count,
"skipped": stats.skipped_count,
"duplicates": stats.duplicate_count,
}
for collection, stats in all_stats.items()
},
"errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
}
try:
with open(filepath, "w", encoding="utf-8") as f:
orjson.dumps(error_report, f, ensure_ascii=False, indent=2)
logger.info(f"错误报告已导出到: {filepath}")
except Exception as e:
logger.error(f"导出错误报告失败: {e}")
def main():
"""主程序入口"""
migrator = MongoToSQLiteMigrator()
# 执行迁移
migration_results = migrator.migrate_all()
# 导出错误报告(如果有错误)
if any(stats.error_count > 0 for stats in migration_results.values()):
error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
migrator.export_error_report(migration_results, error_report_path)
logger.info("数据迁移完成!")
if __name__ == "__main__":
main()

View File

@@ -1,556 +0,0 @@
#!/bin/bash
# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
# 请小心使用任何一键脚本!
INSTALLER_VERSION="0.0.5-refactor"
LANG=C.UTF-8
# 如无法访问GitHub请修改此处镜像地址
GITHUB_REPO="https://ghfast.top/https://github.com"
# 颜色输出
GREEN="\e[32m"
RED="\e[31m"
RESET="\e[0m"
# 需要的基本软件包
declare -A REQUIRED_PACKAGES=(
["common"]="git sudo python3 curl gnupg"
["debian"]="python3-venv python3-pip build-essential"
["ubuntu"]="python3-venv python3-pip build-essential"
["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make"
["arch"]="python-virtualenv python-pip base-devel"
)
# 默认项目目录
DEFAULT_INSTALL_DIR="/opt/maicore"
# 服务名称
SERVICE_NAME="maicore"
SERVICE_NAME_WEB="maicore-web"
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
IS_INSTALL_NAPCAT=false
IS_INSTALL_DEPENDENCIES=false
# 检查是否已安装
check_installed() {
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
}
# 加载安装信息
load_install_info() {
if [[ -f /etc/maicore_install.conf ]]; then
source /etc/maicore_install.conf
else
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
BRANCH="refactor"
fi
}
# 显示管理菜单
show_menu() {
while true; do
choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
"1" "启动MaiCore" \
"2" "停止MaiCore" \
"3" "重启MaiCore" \
"4" "启动NapCat Adapter" \
"5" "停止NapCat Adapter" \
"6" "重启NapCat Adapter" \
"7" "拉取最新MaiCore仓库" \
"8" "切换分支" \
"9" "退出" 3>&1 1>&2 2>&3)
[[ $? -ne 0 ]] && exit 0
case "$choice" in
1)
systemctl start ${SERVICE_NAME}
whiptail --msgbox "✅MaiCore已启动" 10 60
;;
2)
systemctl stop ${SERVICE_NAME}
whiptail --msgbox "🛑MaiCore已停止" 10 60
;;
3)
systemctl restart ${SERVICE_NAME}
whiptail --msgbox "🔄MaiCore已重启" 10 60
;;
4)
systemctl start ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "✅NapCat Adapter已启动" 10 60
;;
5)
systemctl stop ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
;;
6)
systemctl restart ${SERVICE_NAME_NBADAPTER}
whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
;;
7)
update_dependencies
;;
8)
switch_branch
;;
9)
exit 0
;;
*)
whiptail --msgbox "无效选项!" 10 60
;;
esac
done
}
# 更新依赖
update_dependencies() {
whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
systemctl stop ${SERVICE_NAME}
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
if ! git pull origin "${BRANCH}"; then
whiptail --msgbox "🚫 代码更新失败!" 10 60
return 1
fi
source "${INSTALL_DIR}/venv/bin/activate"
if ! pip install -r requirements.txt; then
whiptail --msgbox "🚫 依赖安装失败!" 10 60
deactivate
return 1
fi
deactivate
whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
}
# 切换分支
switch_branch() {
new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
[[ -z "$new_branch" ]] && {
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
return 1
}
cd "${INSTALL_DIR}/MaiBot" || {
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
return 1
}
if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
return 1
fi
if ! git checkout "${new_branch}"; then
whiptail --msgbox "🚫 分支切换失败!" 10 60
return 1
fi
if ! git pull origin "${new_branch}"; then
whiptail --msgbox "🚫 代码拉取失败!" 10 60
return 1
fi
systemctl stop ${SERVICE_NAME}
source "${INSTALL_DIR}/venv/bin/activate"
pip install -r requirements.txt
deactivate
sed -i "s/^BRANCH=.*/BRANCH=${new_branch}/" /etc/maicore_install.conf
BRANCH="${new_branch}"
check_eula
whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} " 10 60
}
check_eula() {
# 首先计算当前EULA的MD5值
current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}')
# 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}')
# 如果当前的md5值为空则直接返回
if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
fi
# 检查eula.confirmed文件是否存在
if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then
# 如果存在则检查其中包含的md5与current_md5是否一致
confirmed_md5=$(cat ${INSTALL_DIR}/MaiBot/eula.confirmed)
else
confirmed_md5=""
fi
# 检查privacy.confirmed文件是否存在
if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then
# 如果存在则检查其中包含的md5与current_md5是否一致
confirmed_md5_privacy=$(cat ${INSTALL_DIR}/MaiBot/privacy.confirmed)
else
confirmed_md5_privacy=""
fi
# 如果EULA或隐私条款有更新提示用户重新确认
if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议 \n\n " 12 70
if [[ $? -eq 0 ]]; then
echo -n $current_md5 > ${INSTALL_DIR}/MaiBot/eula.confirmed
echo -n $current_md5_privacy > ${INSTALL_DIR}/MaiBot/privacy.confirmed
else
exit 1
fi
fi
}
# ----------- 主安装流程 -----------
run_installation() {
# 1/6: 检测是否安装 whiptail
if ! command -v whiptail &>/dev/null; then
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
if command -v apt-get &>/dev/null; then
apt-get update && apt-get install -y whiptail
elif command -v pacman &>/dev/null; then
pacman -Syu --noconfirm whiptail
elif command -v yum &>/dev/null; then
yum install -y whiptail
else
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
exit 1
fi
fi
whiptail --title " 提示" --msgbox "如果您没有特殊需求请优先使用docker方式部署。" 10 60
# 协议确认
if ! (whiptail --title " [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议" 12 70); then
exit 1
fi
# 欢迎信息
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore将自动进入安装流程安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段代码可能随时更改\n文档未完善有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险请自行了解谨慎使用\n由于持续迭代可能存在一些已知或未知的bug\n由于开发中可能消耗较多token\n\n本脚本可能更新不及时如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
# 系统检查
check_system() {
if [[ "$(id -u)" -ne 0 ]]; then
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
exit 1
fi
if [[ -f /etc/os-release ]]; then
source /etc/os-release
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
return
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
return
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
return
elif [[ "$ID" == "arch" ]]; then
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
return
else
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
exit 1
fi
else
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
exit 1
fi
}
check_system
# 设置包管理器
case "$ID" in
debian|ubuntu)
PKG_MANAGER="apt"
;;
centos)
PKG_MANAGER="yum"
;;
arch)
# 添加arch包管理器
PKG_MANAGER="pacman"
;;
esac
# 检查NapCat
check_napcat() {
if command -v napcat &>/dev/null; then
NAPCAT_INSTALLED=true
else
NAPCAT_INSTALLED=false
fi
}
check_napcat
# 安装必要软件包
install_packages() {
missing_packages=()
# 检查 common 及当前系统专属依赖
for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
case "$PKG_MANAGER" in
apt)
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
;;
yum)
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
;;
pacman)
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
;;
esac
done
if [[ ${#missing_packages[@]} -gt 0 ]]; then
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装" 10 60
if [[ $? -eq 0 ]]; then
IS_INSTALL_DEPENDENCIES=true
else
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续" 10 60 || exit 1
fi
fi
}
install_packages
# 安装NapCat
install_napcat() {
[[ $NAPCAT_INSTALLED == true ]] && return
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat是否安装\n如果您想使用远程NapCat请跳过此步。" 10 60 && {
IS_INSTALL_NAPCAT=true
}
}
# 仅在非Arch系统上安装NapCat
[[ "$ID" != "arch" ]] && install_napcat
# Python版本检查
check_python() {
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
exit 1
fi
}
# 如果没安装python则不检查python版本
if command -v python3 &>/dev/null; then
check_python
fi
# 选择分支
choose_branch() {
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
"main" "稳定版本(推荐)" ON \
"dev" "开发版(不知道什么意思就别选)" OFF \
"classical" "经典版0.6.0以前的版本)" OFF \
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 操作取消!" 10 60
exit 1
fi
if [[ "$BRANCH" == "custom" ]]; then
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
RETVAL=$?
if [ $RETVAL -ne 0 ]; then
whiptail --msgbox "🚫 输入取消!" 10 60
exit 1
fi
if [[ -z "$BRANCH" ]]; then
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
exit 1
fi
fi
}
choose_branch
# 选择安装路径
choose_install_dir() {
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
[[ -z "$INSTALL_DIR" ]] && {
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
}
}
choose_install_dir
# 确认安装
confirm_install() {
local confirm_msg="请确认以下更改:\n\n"
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
confirm_msg+="🔀 分支: $BRANCH\n"
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
confirm_msg+="\n注意本脚本默认使用ghfast.top为GitHub进行加速如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
}
confirm_install
# 开始安装
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
case "$PKG_MANAGER" in
apt)
apt update && apt install -y "${missing_packages[@]}"
;;
yum)
yum install -y "${missing_packages[@]}" --nobest
;;
pacman)
pacman -S --noconfirm "${missing_packages[@]}"
;;
esac
fi
if [[ $IS_INSTALL_NAPCAT == true ]]; then
echo -e "${GREEN}安装 NapCat...${RESET}"
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
fi
echo -e "${GREEN}创建安装目录...${RESET}"
mkdir -p "$INSTALL_DIR"
cd "$INSTALL_DIR" || exit 1
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
python3 -m venv venv
source venv/bin/activate
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
echo -e "${RED}克隆MaiCore仓库失败${RESET}"
exit 1
}
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
exit 1
}
echo -e "${GREEN}安装Python依赖...${RESET}"
pip install -r MaiBot/requirements.txt
cd MaiBot
pip install uv
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
echo -e "${GREEN}安装maim_message依赖...${RESET}"
cd maim_message
uv pip install -i https://mirrors.aliyun.com/pypi/simple -e .
cd ..
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
cd MaiBot-Napcat-Adapter
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
cd ..
echo -e "${GREEN}同意协议...${RESET}"
# 首先计算当前EULA的MD5值
current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}')
# 首先计算当前隐私条款文件的哈希值
current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}')
echo -n $current_md5 > MaiBot/eula.confirmed
echo -n $current_md5_privacy > MaiBot/privacy.confirmed
echo -e "${GREEN}创建系统服务...${RESET}"
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
[Unit]
Description=MaiCore
After=network.target ${SERVICE_NAME_NBADAPTER}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
# [Unit]
# Description=MaiCore WebUI
# After=network.target ${SERVICE_NAME}.service
# [Service]
# Type=simple
# WorkingDirectory=${INSTALL_DIR}/MaiBot
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
# Restart=always
# RestartSec=10s
# [Install]
# WantedBy=multi-user.target
# EOF
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
[Unit]
Description=MaiBot Napcat Adapter
After=network.target mongod.service ${SERVICE_NAME}.service
[Service]
Type=simple
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
Restart=always
RestartSec=10s
[Install]
WantedBy=multi-user.target
EOF
systemctl daemon-reload
# 保存安装信息
echo "INSTALLER_VERSION=${INSTALLER_VERSION}" > /etc/maicore_install.conf
echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf
echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成\n已创建系统服务${SERVICE_NAME}${SERVICE_NAME_WEB}${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务\n启动服务systemctl start ${SERVICE_NAME}\n查看状态systemctl status ${SERVICE_NAME}" 14 60
}
# ----------- 主执行流程 -----------
# 检查root权限
[[ $(id -u) -ne 0 ]] && {
echo -e "${RED}请使用root用户运行此脚本${RESET}"
exit 1
}
# 如果已安装显示菜单,并检查协议是否更新
if check_installed; then
load_install_info
check_eula
show_menu
else
run_installation
# 安装完成后询问是否启动
if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务" 10 60; then
systemctl start ${SERVICE_NAME}
whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
fi
fi