refactor: 清理项目结构并修复类型注解问题
修复 SQLAlchemy 模型的类型注解,使用 Mapped 类型避免类型检查器错误 - 修正异步数据库操作中缺少 await 的问题 - 优化反注入统计系统的数值字段处理逻辑 - 添加缺失的导入语句修复模块依赖问题
This commit is contained in:
2
bot.py
2
bot.py
@@ -543,7 +543,7 @@ class MaiBotMain:
|
|||||||
"""设置时区"""
|
"""设置时区"""
|
||||||
try:
|
try:
|
||||||
if platform.system().lower() != "windows":
|
if platform.system().lower() != "windows":
|
||||||
time.tzset()
|
time.tzset() # type: ignore
|
||||||
logger.info("时区设置完成")
|
logger.info("时区设置完成")
|
||||||
else:
|
else:
|
||||||
logger.info("Windows系统,跳过时区设置")
|
logger.info("Windows系统,跳过时区设置")
|
||||||
|
|||||||
@@ -1,10 +0,0 @@
|
|||||||
from src.plugin_system.base.plugin_metadata import PluginMetadata
|
|
||||||
|
|
||||||
__plugin_meta__ = PluginMetadata(
|
|
||||||
name="Echo Example Plugin",
|
|
||||||
description="An example plugin that echoes messages.",
|
|
||||||
usage="!echo [message]",
|
|
||||||
version="1.0.0",
|
|
||||||
author="Your Name",
|
|
||||||
license="MIT",
|
|
||||||
)
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
{
|
|
||||||
"manifest_version": 1,
|
|
||||||
"format_version": "1.0.0",
|
|
||||||
"name": "Echo 示例插件",
|
|
||||||
"description": "展示增强命令系统的Echo命令示例插件",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"author": {
|
|
||||||
"name": "MoFox"
|
|
||||||
},
|
|
||||||
"license": "MIT",
|
|
||||||
"keywords": ["echo", "example", "command"],
|
|
||||||
"categories": ["utility", "example"],
|
|
||||||
"host_application": {
|
|
||||||
"name": "MaiBot",
|
|
||||||
"min_version": "0.10.0"
|
|
||||||
},
|
|
||||||
"entry_points": {
|
|
||||||
"main": "plugin.py"
|
|
||||||
},
|
|
||||||
"plugin_info": {
|
|
||||||
"is_built_in": false,
|
|
||||||
"plugin_type": "example",
|
|
||||||
"components": [
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "echo",
|
|
||||||
"description": "回显命令,支持别名 say, repeat"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "hello",
|
|
||||||
"description": "问候命令,支持别名 hi, greet"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "info",
|
|
||||||
"description": "显示插件信息,支持别名 about"
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "command",
|
|
||||||
"name": "test",
|
|
||||||
"description": "测试命令,展示参数解析功能"
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"features": [
|
|
||||||
"增强命令系统示例",
|
|
||||||
"无需正则表达式的命令定义",
|
|
||||||
"命令别名支持",
|
|
||||||
"参数解析功能",
|
|
||||||
"聊天类型限制"
|
|
||||||
]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,204 +0,0 @@
|
|||||||
"""
|
|
||||||
Echo 示例插件
|
|
||||||
|
|
||||||
展示增强命令系统的使用方法
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from src.plugin_system import (
|
|
||||||
BasePlugin,
|
|
||||||
ChatType,
|
|
||||||
CommandArgs,
|
|
||||||
ConfigField,
|
|
||||||
PlusCommand,
|
|
||||||
PlusCommandInfo,
|
|
||||||
register_plugin,
|
|
||||||
)
|
|
||||||
from src.plugin_system.base.component_types import PythonDependency
|
|
||||||
|
|
||||||
|
|
||||||
class EchoCommand(PlusCommand):
|
|
||||||
"""Echo命令示例"""
|
|
||||||
|
|
||||||
command_name = "echo"
|
|
||||||
command_description = "回显命令"
|
|
||||||
command_aliases = ["say", "repeat"]
|
|
||||||
priority = 5
|
|
||||||
chat_type_allow = ChatType.ALL
|
|
||||||
intercept_message = True
|
|
||||||
|
|
||||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
|
||||||
"""执行echo命令"""
|
|
||||||
if args.is_empty():
|
|
||||||
await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>")
|
|
||||||
return True, "参数不足", True
|
|
||||||
|
|
||||||
content = args.get_raw()
|
|
||||||
|
|
||||||
# 检查内容长度限制
|
|
||||||
max_length = self.get_config("commands.max_content_length", 500)
|
|
||||||
if len(content) > max_length:
|
|
||||||
await self.send_text(f"❌ 内容过长,最大允许 {max_length} 字符")
|
|
||||||
return True, "内容过长", True
|
|
||||||
|
|
||||||
await self.send_text(f"🔊 {content}")
|
|
||||||
|
|
||||||
return True, "Echo命令执行成功", True
|
|
||||||
|
|
||||||
|
|
||||||
class HelloCommand(PlusCommand):
|
|
||||||
"""Hello命令示例"""
|
|
||||||
|
|
||||||
command_name = "hello"
|
|
||||||
command_description = "问候命令"
|
|
||||||
command_aliases = ["hi", "greet"]
|
|
||||||
priority = 3
|
|
||||||
chat_type_allow = ChatType.ALL
|
|
||||||
intercept_message = True
|
|
||||||
|
|
||||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
|
||||||
"""执行hello命令"""
|
|
||||||
if args.is_empty():
|
|
||||||
await self.send_text("👋 Hello! 很高兴见到你!")
|
|
||||||
else:
|
|
||||||
name = args.get_first()
|
|
||||||
await self.send_text(f"👋 Hello, {name}! 很高兴见到你!")
|
|
||||||
|
|
||||||
return True, "Hello命令执行成功", True
|
|
||||||
|
|
||||||
|
|
||||||
class InfoCommand(PlusCommand):
|
|
||||||
"""信息命令示例"""
|
|
||||||
|
|
||||||
command_name = "info"
|
|
||||||
command_description = "显示插件信息"
|
|
||||||
command_aliases = ["about"]
|
|
||||||
priority = 1
|
|
||||||
chat_type_allow = ChatType.ALL
|
|
||||||
intercept_message = True
|
|
||||||
|
|
||||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
|
||||||
"""执行info命令"""
|
|
||||||
info_text = (
|
|
||||||
"📋 Echo 示例插件信息\n"
|
|
||||||
"版本: 1.0.0\n"
|
|
||||||
"作者: MaiBot Team\n"
|
|
||||||
"描述: 展示增强命令系统的使用方法\n\n"
|
|
||||||
"🎯 可用命令:\n"
|
|
||||||
"• /echo|/say|/repeat <内容> - 回显内容\n"
|
|
||||||
"• /hello|/hi|/greet [名字] - 问候\n"
|
|
||||||
"• /info|/about - 显示此信息\n"
|
|
||||||
"• /test <子命令> [参数] - 测试各种功能"
|
|
||||||
)
|
|
||||||
await self.send_text(info_text)
|
|
||||||
|
|
||||||
return True, "Info命令执行成功", True
|
|
||||||
|
|
||||||
|
|
||||||
class TestCommand(PlusCommand):
|
|
||||||
"""测试命令示例,展示参数解析功能"""
|
|
||||||
|
|
||||||
command_name = "test"
|
|
||||||
command_description = "测试命令,展示参数解析功能"
|
|
||||||
command_aliases = ["t"]
|
|
||||||
priority = 2
|
|
||||||
chat_type_allow = ChatType.ALL
|
|
||||||
intercept_message = True
|
|
||||||
|
|
||||||
async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]:
|
|
||||||
"""执行test命令"""
|
|
||||||
if args.is_empty():
|
|
||||||
help_text = (
|
|
||||||
"🧪 测试命令帮助\n"
|
|
||||||
"用法: /test <子命令> [参数]\n\n"
|
|
||||||
"可用子命令:\n"
|
|
||||||
"• args - 显示参数解析结果\n"
|
|
||||||
"• flags - 测试标志参数\n"
|
|
||||||
"• count - 计算参数数量\n"
|
|
||||||
"• join - 连接所有参数"
|
|
||||||
)
|
|
||||||
await self.send_text(help_text)
|
|
||||||
return True, "显示帮助", True
|
|
||||||
|
|
||||||
subcommand = args.get_first().lower()
|
|
||||||
|
|
||||||
if subcommand == "args":
|
|
||||||
result = (
|
|
||||||
f"🔍 参数解析结果:\n"
|
|
||||||
f"原始字符串: '{args.get_raw()}'\n"
|
|
||||||
f"解析后参数: {args.get_args()}\n"
|
|
||||||
f"参数数量: {args.count()}\n"
|
|
||||||
f"第一个参数: '{args.get_first()}'\n"
|
|
||||||
f"剩余参数: '{args.get_remaining()}'"
|
|
||||||
)
|
|
||||||
await self.send_text(result)
|
|
||||||
|
|
||||||
elif subcommand == "flags":
|
|
||||||
result = (
|
|
||||||
f"🏴 标志测试结果:\n"
|
|
||||||
f"包含 --verbose: {args.has_flag('--verbose')}\n"
|
|
||||||
f"包含 -v: {args.has_flag('-v')}\n"
|
|
||||||
f"--output 的值: '{args.get_flag_value('--output', '未设置')}'\n"
|
|
||||||
f"--name 的值: '{args.get_flag_value('--name', '未设置')}'"
|
|
||||||
)
|
|
||||||
await self.send_text(result)
|
|
||||||
|
|
||||||
elif subcommand == "count":
|
|
||||||
count = args.count() - 1 # 减去子命令本身
|
|
||||||
await self.send_text(f"📊 除子命令外的参数数量: {count}")
|
|
||||||
|
|
||||||
elif subcommand == "join":
|
|
||||||
remaining = args.get_remaining()
|
|
||||||
if remaining:
|
|
||||||
await self.send_text(f"🔗 连接结果: {remaining}")
|
|
||||||
else:
|
|
||||||
await self.send_text("❌ 没有可连接的参数")
|
|
||||||
|
|
||||||
else:
|
|
||||||
await self.send_text(f"❓ 未知的子命令: {subcommand}")
|
|
||||||
|
|
||||||
return True, "Test命令执行成功", True
|
|
||||||
|
|
||||||
|
|
||||||
@register_plugin
|
|
||||||
class EchoExamplePlugin(BasePlugin):
|
|
||||||
"""Echo 示例插件"""
|
|
||||||
|
|
||||||
plugin_name: str = "echo_example_plugin"
|
|
||||||
enable_plugin: bool = True
|
|
||||||
dependencies: list[str] = []
|
|
||||||
python_dependencies: list[Union[str, "PythonDependency"]] = []
|
|
||||||
config_file_name: str = "config.toml"
|
|
||||||
|
|
||||||
config_schema = {
|
|
||||||
"plugin": {
|
|
||||||
"enabled": ConfigField(bool, default=True, description="是否启用插件"),
|
|
||||||
"config_version": ConfigField(str, default="1.0.0", description="配置文件版本"),
|
|
||||||
},
|
|
||||||
"commands": {
|
|
||||||
"echo_enabled": ConfigField(bool, default=True, description="是否启用 Echo 命令"),
|
|
||||||
"cooldown": ConfigField(int, default=0, description="命令冷却时间(秒)"),
|
|
||||||
"max_content_length": ConfigField(int, default=500, description="最大回显内容长度"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
config_section_descriptions = {
|
|
||||||
"plugin": "插件基本配置",
|
|
||||||
"commands": "命令相关配置",
|
|
||||||
}
|
|
||||||
|
|
||||||
def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]:
|
|
||||||
"""获取插件组件"""
|
|
||||||
components = []
|
|
||||||
|
|
||||||
if self.get_config("plugin.enabled", True):
|
|
||||||
# 添加所有命令,直接使用PlusCommand类
|
|
||||||
if self.get_config("commands.echo_enabled", True):
|
|
||||||
components.append((EchoCommand.get_plus_command_info(), EchoCommand))
|
|
||||||
|
|
||||||
components.append((HelloCommand.get_plus_command_info(), HelloCommand))
|
|
||||||
components.append((InfoCommand.get_plus_command_info(), InfoCommand))
|
|
||||||
components.append((TestCommand.get_plus_command_info(), TestCommand))
|
|
||||||
|
|
||||||
return components
|
|
||||||
@@ -107,7 +107,7 @@ class HelloWorldPlugin(BasePlugin):
|
|||||||
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))
|
components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool))
|
||||||
|
|
||||||
if self.get_config("components.hello_command_enabled", True):
|
if self.get_config("components.hello_command_enabled", True):
|
||||||
components.append((HelloCommand.get_command_info(), HelloCommand))
|
components.append((HelloCommand.get_plus_command_info(), HelloCommand))
|
||||||
|
|
||||||
if self.get_config("components.random_emoji_action_enabled", True):
|
if self.get_config("components.random_emoji_action_enabled", True):
|
||||||
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction))
|
||||||
|
|||||||
31
pyrightconfig.json
Normal file
31
pyrightconfig.json
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
{
|
||||||
|
"$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json",
|
||||||
|
"include": [
|
||||||
|
"src",
|
||||||
|
"bot.py",
|
||||||
|
"__main__.py"
|
||||||
|
],
|
||||||
|
"exclude": [
|
||||||
|
"**/__pycache__",
|
||||||
|
"data",
|
||||||
|
"logs",
|
||||||
|
"tests",
|
||||||
|
"target",
|
||||||
|
"*.egg-info"
|
||||||
|
],
|
||||||
|
"typeCheckingMode": "standard",
|
||||||
|
"reportMissingImports": false,
|
||||||
|
"reportMissingTypeStubs": false,
|
||||||
|
"reportMissingModuleSource": false,
|
||||||
|
"diagnosticSeverityOverrides": {
|
||||||
|
"reportMissingImports": "none",
|
||||||
|
"reportMissingTypeStubs": "none",
|
||||||
|
"reportMissingModuleSource": "none"
|
||||||
|
},
|
||||||
|
"pythonVersion": "3.12",
|
||||||
|
"venvPath": ".",
|
||||||
|
"venv": ".venv",
|
||||||
|
"executionEnvironments": [
|
||||||
|
{"root": "src"}
|
||||||
|
]
|
||||||
|
}
|
||||||
220
scripts/convert_sqlalchemy_models.py
Normal file
220
scripts/convert_sqlalchemy_models.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""批量将经典 SQLAlchemy 模型字段写法
|
||||||
|
|
||||||
|
field = Column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
转换为 2.0 推荐的带类型注解写法:
|
||||||
|
|
||||||
|
field: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
|
||||||
|
脚本特点:
|
||||||
|
1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。
|
||||||
|
2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。
|
||||||
|
3. 已经是 Mapped 写法的行会跳过。
|
||||||
|
4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。
|
||||||
|
5. nullable=True 时自动添加 "| None"。
|
||||||
|
6. 保留 Column(...) 内的原始参数顺序与内容。
|
||||||
|
7. 生成 .bak 备份文件,确保可回滚。
|
||||||
|
8. 支持 --dry-run 查看差异,不写回文件。
|
||||||
|
|
||||||
|
局限/注意:
|
||||||
|
- 简单基于正则/括号计数,不解析完整 AST;非常规写法(例如变量中构造 Column 再赋值)不会处理。
|
||||||
|
- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。
|
||||||
|
- 不自动添加 from __future__ import annotations;如需 Python 3.10 以下更先进类型表达式,请自行处理。
|
||||||
|
|
||||||
|
使用方式: (在项目根目录执行)
|
||||||
|
|
||||||
|
python scripts/convert_sqlalchemy_models.py \
|
||||||
|
--file src/common/database/sqlalchemy_models.py --dry-run
|
||||||
|
|
||||||
|
确认无误后去掉 --dry-run 真实写入。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import re
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
TYPE_MAP = {
|
||||||
|
"Integer": "int",
|
||||||
|
"Float": "float",
|
||||||
|
"Boolean": "bool",
|
||||||
|
"Text": "str",
|
||||||
|
"String": "str",
|
||||||
|
"DateTime": "datetime.datetime",
|
||||||
|
# 自定义帮助函数 get_string_field(...) 也返回字符串类型
|
||||||
|
"get_string_field": "str",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(")
|
||||||
|
ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[")
|
||||||
|
|
||||||
|
|
||||||
|
def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None:
|
||||||
|
"""检测从 start_index 开始的 Column(...) 语句跨越的行范围 (包含结束行)。
|
||||||
|
|
||||||
|
使用括号计数法处理多行。
|
||||||
|
返回 (start, end) 行号 (包含 end)。"""
|
||||||
|
line = lines[start_index]
|
||||||
|
if "Column(" not in line:
|
||||||
|
return None
|
||||||
|
open_parens = line.count("(") - line.count(")")
|
||||||
|
i = start_index
|
||||||
|
while open_parens > 0 and i + 1 < len(lines):
|
||||||
|
i += 1
|
||||||
|
l2 = lines[i]
|
||||||
|
open_parens += l2.count("(") - l2.count(")")
|
||||||
|
return (start_index, i)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_column_body(block_lines: list[str]) -> str:
|
||||||
|
"""提取 Column(...) 内部参数文本 (去掉首尾 Column( 和 最后一个 ) )。"""
|
||||||
|
joined = "\n".join(block_lines)
|
||||||
|
# 找到第一次出现 Column(
|
||||||
|
start_pos = joined.find("Column(")
|
||||||
|
if start_pos == -1:
|
||||||
|
return ""
|
||||||
|
inner = joined[start_pos + len("Column(") :]
|
||||||
|
# 去掉最后一个 ) —— 简单方式: 找到最后一个 ) 并截断
|
||||||
|
last_paren = inner.rfind(")")
|
||||||
|
if last_paren != -1:
|
||||||
|
inner = inner[:last_paren]
|
||||||
|
return inner.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def guess_python_type(column_body: str) -> str:
|
||||||
|
# 简单取第一个类型标识符 (去掉前导装饰/空格)
|
||||||
|
# 可能形式: Integer, Text, get_string_field(50), DateTime, Boolean
|
||||||
|
# 利用正则抓取第一个标识符
|
||||||
|
m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body)
|
||||||
|
if not m:
|
||||||
|
return "Any"
|
||||||
|
type_token = m.group(1)
|
||||||
|
py_type = TYPE_MAP.get(type_token, "Any")
|
||||||
|
# nullable 检测
|
||||||
|
if "nullable=True" in column_body or "nullable = True" in column_body:
|
||||||
|
# 避免重复 Optional
|
||||||
|
if py_type != "Any" and not py_type.endswith(" | None"):
|
||||||
|
py_type = f"{py_type} | None"
|
||||||
|
elif py_type == "Any":
|
||||||
|
py_type = "Any | None"
|
||||||
|
return py_type
|
||||||
|
|
||||||
|
|
||||||
|
def convert_block(block_lines: list[str]) -> list[str]:
|
||||||
|
first_line = block_lines[0]
|
||||||
|
m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line)
|
||||||
|
if not m_name:
|
||||||
|
return block_lines
|
||||||
|
indent = m_name.group("indent")
|
||||||
|
name = m_name.group("name")
|
||||||
|
body = extract_column_body(block_lines)
|
||||||
|
py_type = guess_python_type(body)
|
||||||
|
# 构造新的多行 mapped_column 写法
|
||||||
|
# 保留内部参数的换行缩进: 重新缩进为 indent + 4 空格 (延续原风格: 在 indent 基础上再加 4 空格)
|
||||||
|
inner_lines = body.split("\n")
|
||||||
|
if len(inner_lines) == 1:
|
||||||
|
new_line = f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n"
|
||||||
|
return [new_line]
|
||||||
|
else:
|
||||||
|
# 多行情况
|
||||||
|
ind2 = indent + " "
|
||||||
|
rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(",]
|
||||||
|
for il in inner_lines:
|
||||||
|
if il.strip():
|
||||||
|
rebuilt.append(f"{ind2}{il.rstrip()}")
|
||||||
|
rebuilt.append(f"{indent})\n")
|
||||||
|
return [l + ("\n" if not l.endswith("\n") else "") for l in rebuilt]
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_imports(content: str) -> str:
|
||||||
|
if "Mapped," in content or "Mapped[" in content:
|
||||||
|
# 已经可能存在导入
|
||||||
|
if "from sqlalchemy.orm import Mapped, mapped_column" not in content:
|
||||||
|
# 简单插到第一个 import sqlalchemy 之后
|
||||||
|
lines = content.splitlines()
|
||||||
|
for i, line in enumerate(lines):
|
||||||
|
if "sqlalchemy" in line and line.startswith("from sqlalchemy"):
|
||||||
|
lines.insert(i + 1, "from sqlalchemy.orm import Mapped, mapped_column")
|
||||||
|
return "\n".join(lines)
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
def process_file(path: Path) -> tuple[str, str]:
|
||||||
|
original = path.read_text(encoding="utf-8")
|
||||||
|
lines = original.splitlines(keepends=True)
|
||||||
|
i = 0
|
||||||
|
out: list[str] = []
|
||||||
|
changed = 0
|
||||||
|
while i < len(lines):
|
||||||
|
line = lines[i]
|
||||||
|
# 跳过已是 Mapped 风格
|
||||||
|
if ALREADY_MAPPED_RE.match(line):
|
||||||
|
out.append(line)
|
||||||
|
i += 1
|
||||||
|
continue
|
||||||
|
if "= Column(" in line and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", line):
|
||||||
|
start, end = detect_column_block(lines, i) or (i, i)
|
||||||
|
block = lines[start : end + 1]
|
||||||
|
converted = convert_block(block)
|
||||||
|
out.extend(converted)
|
||||||
|
i = end + 1
|
||||||
|
# 如果转换结果与原始不同,计数
|
||||||
|
if "".join(converted) != "".join(block):
|
||||||
|
changed += 1
|
||||||
|
else:
|
||||||
|
out.append(line)
|
||||||
|
i += 1
|
||||||
|
new_content = "".join(out)
|
||||||
|
new_content = ensure_imports(new_content)
|
||||||
|
# 在文件末尾或头部预留统计信息打印(不写入文件,只返回)
|
||||||
|
return original, new_content if changed else original
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法")
|
||||||
|
parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件")
|
||||||
|
parser.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回")
|
||||||
|
parser.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)")
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
target = Path(args.file)
|
||||||
|
if not target.exists():
|
||||||
|
raise SystemExit(f"文件不存在: {target}")
|
||||||
|
|
||||||
|
original, new_content = process_file(target)
|
||||||
|
|
||||||
|
if original == new_content:
|
||||||
|
print("[INFO] 没有需要转换的内容或转换后无差异。")
|
||||||
|
return
|
||||||
|
|
||||||
|
# 简单差异输出 (行对比)
|
||||||
|
if args.dry_run or not args.write:
|
||||||
|
print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):")
|
||||||
|
import difflib
|
||||||
|
|
||||||
|
diff = difflib.unified_diff(
|
||||||
|
original.splitlines(), new_content.splitlines(), fromfile="original", tofile="converted", lineterm=""
|
||||||
|
)
|
||||||
|
count = 0
|
||||||
|
for d in diff:
|
||||||
|
print(d)
|
||||||
|
count += 1
|
||||||
|
if count == 0:
|
||||||
|
print("[INFO] 差异为空 (可能未匹配到 Column 定义)。")
|
||||||
|
if not args.write:
|
||||||
|
print("\n未写回。若确认无误,添加 --write 执行替换。")
|
||||||
|
return
|
||||||
|
|
||||||
|
backup = target.with_suffix(target.suffix + ".bak")
|
||||||
|
shutil.copyfile(target, backup)
|
||||||
|
target.write_text(new_content, encoding="utf-8")
|
||||||
|
print(f"[DONE] 已写回: {target},备份文件: {backup.name}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__": # pragma: no cover
|
||||||
|
main()
|
||||||
@@ -1,284 +0,0 @@
|
|||||||
import os
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
# Add project root to Python path
|
|
||||||
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
|
||||||
sys.path.insert(0, project_root)
|
|
||||||
from src.common.database.database_model import Messages, ChatStreams # noqa
|
|
||||||
|
|
||||||
|
|
||||||
def get_chat_name(chat_id: str) -> str:
|
|
||||||
"""Get chat name from chat_id by querying ChatStreams table directly"""
|
|
||||||
try:
|
|
||||||
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
|
|
||||||
if chat_stream is None:
|
|
||||||
return f"未知聊天 ({chat_id})"
|
|
||||||
|
|
||||||
if chat_stream.group_name:
|
|
||||||
return f"{chat_stream.group_name} ({chat_id})"
|
|
||||||
elif chat_stream.user_nickname:
|
|
||||||
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
|
|
||||||
else:
|
|
||||||
return f"未知聊天 ({chat_id})"
|
|
||||||
except Exception:
|
|
||||||
return f"查询失败 ({chat_id})"
|
|
||||||
|
|
||||||
|
|
||||||
def format_timestamp(timestamp: float) -> str:
|
|
||||||
"""Format timestamp to readable date string"""
|
|
||||||
try:
|
|
||||||
return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S")
|
|
||||||
except (ValueError, OSError):
|
|
||||||
return "未知时间"
|
|
||||||
|
|
||||||
|
|
||||||
def calculate_interest_value_distribution(messages) -> dict[str, int]:
|
|
||||||
"""Calculate distribution of interest_value"""
|
|
||||||
distribution = {
|
|
||||||
"0.000-0.010": 0,
|
|
||||||
"0.010-0.050": 0,
|
|
||||||
"0.050-0.100": 0,
|
|
||||||
"0.100-0.500": 0,
|
|
||||||
"0.500-1.000": 0,
|
|
||||||
"1.000-2.000": 0,
|
|
||||||
"2.000-5.000": 0,
|
|
||||||
"5.000-10.000": 0,
|
|
||||||
"10.000+": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
for msg in messages:
|
|
||||||
if msg.interest_value is None or msg.interest_value == 0.0:
|
|
||||||
continue
|
|
||||||
|
|
||||||
value = float(msg.interest_value)
|
|
||||||
if value < 0.010:
|
|
||||||
distribution["0.000-0.010"] += 1
|
|
||||||
elif value < 0.050:
|
|
||||||
distribution["0.010-0.050"] += 1
|
|
||||||
elif value < 0.100:
|
|
||||||
distribution["0.050-0.100"] += 1
|
|
||||||
elif value < 0.500:
|
|
||||||
distribution["0.100-0.500"] += 1
|
|
||||||
elif value < 1.000:
|
|
||||||
distribution["0.500-1.000"] += 1
|
|
||||||
elif value < 2.000:
|
|
||||||
distribution["1.000-2.000"] += 1
|
|
||||||
elif value < 5.000:
|
|
||||||
distribution["2.000-5.000"] += 1
|
|
||||||
elif value < 10.000:
|
|
||||||
distribution["5.000-10.000"] += 1
|
|
||||||
else:
|
|
||||||
distribution["10.000+"] += 1
|
|
||||||
|
|
||||||
return distribution
|
|
||||||
|
|
||||||
|
|
||||||
def get_interest_value_stats(messages) -> dict[str, float]:
|
|
||||||
"""Calculate basic statistics for interest_value"""
|
|
||||||
values = [
|
|
||||||
float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0
|
|
||||||
]
|
|
||||||
|
|
||||||
if not values:
|
|
||||||
return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0}
|
|
||||||
|
|
||||||
values.sort()
|
|
||||||
count = len(values)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"count": count,
|
|
||||||
"min": min(values),
|
|
||||||
"max": max(values),
|
|
||||||
"avg": sum(values) / count,
|
|
||||||
"median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_available_chats() -> list[tuple[str, str, int]]:
|
|
||||||
"""Get all available chats with message counts"""
|
|
||||||
try:
|
|
||||||
# 获取所有有消息的chat_id
|
|
||||||
chat_counts = {}
|
|
||||||
for msg in Messages.select(Messages.chat_id).distinct():
|
|
||||||
chat_id = msg.chat_id
|
|
||||||
count = (
|
|
||||||
Messages.select()
|
|
||||||
.where(
|
|
||||||
(Messages.chat_id == chat_id)
|
|
||||||
& (Messages.interest_value.is_null(False))
|
|
||||||
& (Messages.interest_value != 0.0)
|
|
||||||
)
|
|
||||||
.count()
|
|
||||||
)
|
|
||||||
if count > 0:
|
|
||||||
chat_counts[chat_id] = count
|
|
||||||
|
|
||||||
# 获取聊天名称
|
|
||||||
result = []
|
|
||||||
for chat_id, count in chat_counts.items():
|
|
||||||
chat_name = get_chat_name(chat_id)
|
|
||||||
result.append((chat_id, chat_name, count))
|
|
||||||
|
|
||||||
# 按消息数量排序
|
|
||||||
result.sort(key=lambda x: x[2], reverse=True)
|
|
||||||
return result
|
|
||||||
except Exception as e:
|
|
||||||
print(f"获取聊天列表失败: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def get_time_range_input() -> tuple[float | None, float | None]:
|
|
||||||
"""Get time range input from user"""
|
|
||||||
print("\n时间范围选择:")
|
|
||||||
print("1. 最近1天")
|
|
||||||
print("2. 最近3天")
|
|
||||||
print("3. 最近7天")
|
|
||||||
print("4. 最近30天")
|
|
||||||
print("5. 自定义时间范围")
|
|
||||||
print("6. 不限制时间")
|
|
||||||
|
|
||||||
choice = input("请选择时间范围 (1-6): ").strip()
|
|
||||||
|
|
||||||
now = time.time()
|
|
||||||
|
|
||||||
if choice == "1":
|
|
||||||
return now - 24 * 3600, now
|
|
||||||
elif choice == "2":
|
|
||||||
return now - 3 * 24 * 3600, now
|
|
||||||
elif choice == "3":
|
|
||||||
return now - 7 * 24 * 3600, now
|
|
||||||
elif choice == "4":
|
|
||||||
return now - 30 * 24 * 3600, now
|
|
||||||
elif choice == "5":
|
|
||||||
print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
|
||||||
start_str = input().strip()
|
|
||||||
print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):")
|
|
||||||
end_str = input().strip()
|
|
||||||
|
|
||||||
try:
|
|
||||||
start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
|
||||||
end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp()
|
|
||||||
return start_time, end_time
|
|
||||||
except ValueError:
|
|
||||||
print("时间格式错误,将不限制时间范围")
|
|
||||||
return None, None
|
|
||||||
else:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_interest_values(
|
|
||||||
chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None
|
|
||||||
) -> None:
|
|
||||||
"""Analyze interest values with optional filters"""
|
|
||||||
|
|
||||||
# 构建查询条件
|
|
||||||
query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0))
|
|
||||||
|
|
||||||
if chat_id:
|
|
||||||
query = query.where(Messages.chat_id == chat_id)
|
|
||||||
|
|
||||||
if start_time:
|
|
||||||
query = query.where(Messages.time >= start_time)
|
|
||||||
|
|
||||||
if end_time:
|
|
||||||
query = query.where(Messages.time <= end_time)
|
|
||||||
|
|
||||||
messages = list(query)
|
|
||||||
|
|
||||||
if not messages:
|
|
||||||
print("没有找到符合条件的消息")
|
|
||||||
return
|
|
||||||
|
|
||||||
# 计算统计信息
|
|
||||||
distribution = calculate_interest_value_distribution(messages)
|
|
||||||
stats = get_interest_value_stats(messages)
|
|
||||||
|
|
||||||
# 显示结果
|
|
||||||
print("\n=== Interest Value 分析结果 ===")
|
|
||||||
if chat_id:
|
|
||||||
print(f"聊天: {get_chat_name(chat_id)}")
|
|
||||||
else:
|
|
||||||
print("聊天: 全部聊天")
|
|
||||||
|
|
||||||
if start_time and end_time:
|
|
||||||
print(f"时间范围: {format_timestamp(start_time)} 到 {format_timestamp(end_time)}")
|
|
||||||
elif start_time:
|
|
||||||
print(f"时间范围: {format_timestamp(start_time)} 之后")
|
|
||||||
elif end_time:
|
|
||||||
print(f"时间范围: {format_timestamp(end_time)} 之前")
|
|
||||||
else:
|
|
||||||
print("时间范围: 不限制")
|
|
||||||
|
|
||||||
print("\n基本统计:")
|
|
||||||
print(f"有效消息数量: {stats['count']} (排除null和0值)")
|
|
||||||
print(f"最小值: {stats['min']:.3f}")
|
|
||||||
print(f"最大值: {stats['max']:.3f}")
|
|
||||||
print(f"平均值: {stats['avg']:.3f}")
|
|
||||||
print(f"中位数: {stats['median']:.3f}")
|
|
||||||
|
|
||||||
print("\nInterest Value 分布:")
|
|
||||||
total = stats["count"]
|
|
||||||
for range_name, count in distribution.items():
|
|
||||||
if count > 0:
|
|
||||||
percentage = count / total * 100
|
|
||||||
print(f"{range_name}: {count} ({percentage:.2f}%)")
|
|
||||||
|
|
||||||
|
|
||||||
def interactive_menu() -> None:
|
|
||||||
"""Interactive menu for interest value analysis"""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
print("\n" + "=" * 50)
|
|
||||||
print("Interest Value 分析工具")
|
|
||||||
print("=" * 50)
|
|
||||||
print("1. 分析全部聊天")
|
|
||||||
print("2. 选择特定聊天分析")
|
|
||||||
print("q. 退出")
|
|
||||||
|
|
||||||
choice = input("\n请选择分析模式 (1-2, q): ").strip()
|
|
||||||
|
|
||||||
if choice.lower() == "q":
|
|
||||||
print("再见!")
|
|
||||||
break
|
|
||||||
|
|
||||||
chat_id = None
|
|
||||||
|
|
||||||
if choice == "2":
|
|
||||||
# 显示可用的聊天列表
|
|
||||||
chats = get_available_chats()
|
|
||||||
if not chats:
|
|
||||||
print("没有找到有interest_value数据的聊天")
|
|
||||||
continue
|
|
||||||
|
|
||||||
print(f"\n可用的聊天 (共{len(chats)}个):")
|
|
||||||
for i, (_cid, name, count) in enumerate(chats, 1):
|
|
||||||
print(f"{i}. {name} ({count}条有效消息)")
|
|
||||||
|
|
||||||
try:
|
|
||||||
chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip())
|
|
||||||
if 1 <= chat_choice <= len(chats):
|
|
||||||
chat_id = chats[chat_choice - 1][0]
|
|
||||||
else:
|
|
||||||
print("无效选择")
|
|
||||||
continue
|
|
||||||
except ValueError:
|
|
||||||
print("请输入有效数字")
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif choice != "1":
|
|
||||||
print("无效选择")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 获取时间范围
|
|
||||||
start_time, end_time = get_time_range_input()
|
|
||||||
|
|
||||||
# 执行分析
|
|
||||||
analyze_interest_values(chat_id, start_time, end_time)
|
|
||||||
|
|
||||||
input("\n按回车键继续...")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
interactive_menu()
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,239 +0,0 @@
|
|||||||
"""
|
|
||||||
插件Manifest管理命令行工具
|
|
||||||
|
|
||||||
提供插件manifest文件的创建、验证和管理功能
|
|
||||||
"""
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
from src.plugin_system.utils.manifest_utils import (
|
|
||||||
ManifestValidator,
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加项目根目录到Python路径
|
|
||||||
project_root = Path(__file__).parent.parent.parent.parent
|
|
||||||
sys.path.insert(0, str(project_root))
|
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger("manifest_tool")
|
|
||||||
|
|
||||||
|
|
||||||
def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool:
|
|
||||||
"""创建最小化的manifest文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_dir: 插件目录
|
|
||||||
plugin_name: 插件名称
|
|
||||||
description: 插件描述
|
|
||||||
author: 插件作者
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否创建成功
|
|
||||||
"""
|
|
||||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
|
||||||
|
|
||||||
if os.path.exists(manifest_path):
|
|
||||||
print(f"❌ Manifest文件已存在: {manifest_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 创建最小化manifest
|
|
||||||
minimal_manifest = {
|
|
||||||
"manifest_version": 1,
|
|
||||||
"name": plugin_name,
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": description or f"{plugin_name}插件",
|
|
||||||
"author": {"name": author or "Unknown"},
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(orjson.dumps(minimal_manifest, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
|
||||||
print(f"✅ 已创建最小化manifest文件: {manifest_path}")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 创建manifest文件失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool:
|
|
||||||
"""创建完整的manifest模板文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_dir: 插件目录
|
|
||||||
plugin_name: 插件名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否创建成功
|
|
||||||
"""
|
|
||||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
|
||||||
|
|
||||||
if os.path.exists(manifest_path):
|
|
||||||
print(f"❌ Manifest文件已存在: {manifest_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
# 创建完整模板
|
|
||||||
complete_manifest = {
|
|
||||||
"manifest_version": 1,
|
|
||||||
"name": plugin_name,
|
|
||||||
"version": "1.0.0",
|
|
||||||
"description": f"{plugin_name}插件描述",
|
|
||||||
"author": {"name": "插件作者", "url": "https://github.com/your-username"},
|
|
||||||
"license": "MIT",
|
|
||||||
"host_application": {"min_version": "1.0.0", "max_version": "4.0.0"},
|
|
||||||
"homepage_url": "https://github.com/your-repo",
|
|
||||||
"repository_url": "https://github.com/your-repo",
|
|
||||||
"keywords": ["keyword1", "keyword2"],
|
|
||||||
"categories": ["Category1"],
|
|
||||||
"default_locale": "zh-CN",
|
|
||||||
"locales_path": "_locales",
|
|
||||||
"plugin_info": {
|
|
||||||
"is_built_in": False,
|
|
||||||
"plugin_type": "general",
|
|
||||||
"components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}],
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(manifest_path, "w", encoding="utf-8") as f:
|
|
||||||
f.write(orjson.dumps(complete_manifest, option=orjson.OPT_INDENT_2).decode("utf-8"))
|
|
||||||
print(f"✅ 已创建完整manifest模板: {manifest_path}")
|
|
||||||
print("💡 请根据实际情况修改manifest文件中的内容")
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 创建manifest文件失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def validate_manifest_file(plugin_dir: str) -> bool:
|
|
||||||
"""验证manifest文件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
plugin_dir: 插件目录
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否验证通过
|
|
||||||
"""
|
|
||||||
manifest_path = os.path.join(plugin_dir, "_manifest.json")
|
|
||||||
|
|
||||||
if not os.path.exists(manifest_path):
|
|
||||||
print(f"❌ 未找到manifest文件: {manifest_path}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(manifest_path, encoding="utf-8") as f:
|
|
||||||
manifest_data = orjson.loads(f.read())
|
|
||||||
|
|
||||||
validator = ManifestValidator()
|
|
||||||
is_valid = validator.validate_manifest(manifest_data)
|
|
||||||
|
|
||||||
# 显示验证结果
|
|
||||||
print("📋 Manifest验证结果:")
|
|
||||||
print(validator.get_validation_report())
|
|
||||||
|
|
||||||
if is_valid:
|
|
||||||
print("✅ Manifest文件验证通过")
|
|
||||||
else:
|
|
||||||
print("❌ Manifest文件验证失败")
|
|
||||||
|
|
||||||
return is_valid
|
|
||||||
|
|
||||||
except orjson.JSONDecodeError as e:
|
|
||||||
print(f"❌ Manifest文件格式错误: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 验证过程中发生错误: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def scan_plugins_without_manifest(root_dir: str) -> None:
|
|
||||||
"""扫描缺少manifest文件的插件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
root_dir: 扫描的根目录
|
|
||||||
"""
|
|
||||||
print(f"🔍 扫描目录: {root_dir}")
|
|
||||||
|
|
||||||
plugins_without_manifest = []
|
|
||||||
|
|
||||||
for root, dirs, files in os.walk(root_dir):
|
|
||||||
# 跳过隐藏目录和__pycache__
|
|
||||||
dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"]
|
|
||||||
|
|
||||||
# 检查是否包含plugin.py文件(标识为插件目录)
|
|
||||||
if "plugin.py" in files:
|
|
||||||
manifest_path = os.path.join(root, "_manifest.json")
|
|
||||||
if not os.path.exists(manifest_path):
|
|
||||||
plugins_without_manifest.append(root)
|
|
||||||
|
|
||||||
if plugins_without_manifest:
|
|
||||||
print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:")
|
|
||||||
for plugin_dir in plugins_without_manifest:
|
|
||||||
plugin_name = os.path.basename(plugin_dir)
|
|
||||||
print(f" - {plugin_name}: {plugin_dir}")
|
|
||||||
print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件")
|
|
||||||
else:
|
|
||||||
print("✅ 所有插件都有manifest文件")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主函数"""
|
|
||||||
parser = argparse.ArgumentParser(description="插件Manifest管理工具")
|
|
||||||
subparsers = parser.add_subparsers(dest="command", help="可用命令")
|
|
||||||
|
|
||||||
# 创建最小化manifest命令
|
|
||||||
create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件")
|
|
||||||
create_minimal_parser.add_argument("plugin_dir", help="插件目录路径")
|
|
||||||
create_minimal_parser.add_argument("--name", help="插件名称")
|
|
||||||
create_minimal_parser.add_argument("--description", help="插件描述")
|
|
||||||
create_minimal_parser.add_argument("--author", help="插件作者")
|
|
||||||
|
|
||||||
# 创建完整manifest命令
|
|
||||||
create_complete_parser = subparsers.add_parser("create-complete", help="创建完整manifest模板")
|
|
||||||
create_complete_parser.add_argument("plugin_dir", help="插件目录路径")
|
|
||||||
create_complete_parser.add_argument("--name", help="插件名称")
|
|
||||||
|
|
||||||
# 验证manifest命令
|
|
||||||
validate_parser = subparsers.add_parser("validate", help="验证manifest文件")
|
|
||||||
validate_parser.add_argument("plugin_dir", help="插件目录路径")
|
|
||||||
|
|
||||||
# 扫描插件命令
|
|
||||||
scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件")
|
|
||||||
scan_parser.add_argument("root_dir", help="扫描的根目录路径")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
if not args.command:
|
|
||||||
parser.print_help()
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
if args.command == "create-minimal":
|
|
||||||
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
|
|
||||||
success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "")
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
|
|
||||||
elif args.command == "create-complete":
|
|
||||||
plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir))
|
|
||||||
success = create_complete_manifest(args.plugin_dir, plugin_name)
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
|
|
||||||
elif args.command == "validate":
|
|
||||||
success = validate_manifest_file(args.plugin_dir)
|
|
||||||
sys.exit(0 if success else 1)
|
|
||||||
|
|
||||||
elif args.command == "scan":
|
|
||||||
scan_plugins_without_manifest(args.root_dir)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
print(f"❌ 执行命令时发生错误: {e}")
|
|
||||||
sys.exit(1)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,922 +0,0 @@
|
|||||||
import os
|
|
||||||
|
|
||||||
# import time
|
|
||||||
import pickle
|
|
||||||
import sys # 新增系统模块导入
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import orjson
|
|
||||||
|
|
||||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from peewee import Field, IntegrityError, Model
|
|
||||||
from pymongo import MongoClient
|
|
||||||
from pymongo.errors import ConnectionFailure
|
|
||||||
|
|
||||||
# Rich 进度条和显示组件
|
|
||||||
from rich.console import Console
|
|
||||||
from rich.panel import Panel
|
|
||||||
from rich.progress import (
|
|
||||||
BarColumn,
|
|
||||||
Progress,
|
|
||||||
SpinnerColumn,
|
|
||||||
TaskProgressColumn,
|
|
||||||
TextColumn,
|
|
||||||
TimeElapsedColumn,
|
|
||||||
TimeRemainingColumn,
|
|
||||||
)
|
|
||||||
from rich.table import Table
|
|
||||||
|
|
||||||
# from rich.text import Text
|
|
||||||
from src.common.database.database import db
|
|
||||||
from src.common.database.sqlalchemy_models import (
|
|
||||||
ChatStreams,
|
|
||||||
Emoji,
|
|
||||||
GraphEdges,
|
|
||||||
GraphNodes,
|
|
||||||
ImageDescriptions,
|
|
||||||
Images,
|
|
||||||
Knowledges,
|
|
||||||
Messages,
|
|
||||||
PersonInfo,
|
|
||||||
ThinkingLog,
|
|
||||||
)
|
|
||||||
from src.common.logger import get_logger
|
|
||||||
|
|
||||||
logger = get_logger("mongodb_to_sqlite")
|
|
||||||
|
|
||||||
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MigrationConfig:
|
|
||||||
"""迁移配置类"""
|
|
||||||
|
|
||||||
mongo_collection: str
|
|
||||||
target_model: type[Model]
|
|
||||||
field_mapping: dict[str, str]
|
|
||||||
batch_size: int = 500
|
|
||||||
enable_validation: bool = True
|
|
||||||
skip_duplicates: bool = True
|
|
||||||
unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段
|
|
||||||
|
|
||||||
|
|
||||||
# 数据验证相关类已移除 - 用户要求不要数据验证
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MigrationCheckpoint:
|
|
||||||
"""迁移断点数据"""
|
|
||||||
|
|
||||||
collection_name: str
|
|
||||||
processed_count: int
|
|
||||||
last_processed_id: Any
|
|
||||||
timestamp: datetime
|
|
||||||
batch_errors: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MigrationStats:
|
|
||||||
"""迁移统计信息"""
|
|
||||||
|
|
||||||
total_documents: int = 0
|
|
||||||
processed_count: int = 0
|
|
||||||
success_count: int = 0
|
|
||||||
error_count: int = 0
|
|
||||||
skipped_count: int = 0
|
|
||||||
duplicate_count: int = 0
|
|
||||||
validation_errors: int = 0
|
|
||||||
batch_insert_count: int = 0
|
|
||||||
errors: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
start_time: datetime | None = None
|
|
||||||
end_time: datetime | None = None
|
|
||||||
|
|
||||||
def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None):
|
|
||||||
"""添加错误记录"""
|
|
||||||
self.errors.append(
|
|
||||||
{"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data}
|
|
||||||
)
|
|
||||||
self.error_count += 1
|
|
||||||
|
|
||||||
def add_validation_error(self, doc_id: Any, field: str, error: str):
|
|
||||||
"""添加验证错误"""
|
|
||||||
self.add_error(doc_id, f"验证失败 - {field}: {error}")
|
|
||||||
self.validation_errors += 1
|
|
||||||
|
|
||||||
|
|
||||||
class MongoToSQLiteMigrator:
|
|
||||||
"""MongoDB到SQLite数据迁移器 - 使用Peewee ORM"""
|
|
||||||
|
|
||||||
def __init__(self, mongo_uri: str | None = None, database_name: str | None = None):
|
|
||||||
self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot")
|
|
||||||
self.mongo_uri = mongo_uri or self._build_mongo_uri()
|
|
||||||
self.mongo_client: MongoClient | None = None
|
|
||||||
self.mongo_db = None
|
|
||||||
|
|
||||||
# 迁移配置
|
|
||||||
self.migration_configs = self._initialize_migration_configs()
|
|
||||||
|
|
||||||
# 进度条控制台
|
|
||||||
self.console = Console()
|
|
||||||
# 检查点目录
|
|
||||||
self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints"))
|
|
||||||
self.checkpoint_dir.mkdir(exist_ok=True)
|
|
||||||
|
|
||||||
# 验证规则已禁用
|
|
||||||
self.validation_rules = self._initialize_validation_rules()
|
|
||||||
|
|
||||||
def _build_mongo_uri(self) -> str:
|
|
||||||
"""构建MongoDB连接URI"""
|
|
||||||
if mongo_uri := os.getenv("MONGODB_URI"):
|
|
||||||
return mongo_uri
|
|
||||||
|
|
||||||
user = os.getenv("MONGODB_USER")
|
|
||||||
password = os.getenv("MONGODB_PASS")
|
|
||||||
host = os.getenv("MONGODB_HOST", "localhost")
|
|
||||||
port = os.getenv("MONGODB_PORT", "27017")
|
|
||||||
auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin")
|
|
||||||
|
|
||||||
if user and password:
|
|
||||||
return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}"
|
|
||||||
else:
|
|
||||||
return f"mongodb://{host}:{port}/{self.database_name}"
|
|
||||||
|
|
||||||
def _initialize_migration_configs(self) -> list[MigrationConfig]:
|
|
||||||
"""初始化迁移配置"""
|
|
||||||
return [ # 表情包迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="emoji",
|
|
||||||
target_model=Emoji,
|
|
||||||
field_mapping={
|
|
||||||
"full_path": "full_path",
|
|
||||||
"format": "format",
|
|
||||||
"hash": "emoji_hash",
|
|
||||||
"description": "description",
|
|
||||||
"emotion": "emotion",
|
|
||||||
"usage_count": "usage_count",
|
|
||||||
"last_used_time": "last_used_time",
|
|
||||||
# record_time字段将在转换时自动设置为当前时间
|
|
||||||
},
|
|
||||||
enable_validation=False, # 禁用数据验证
|
|
||||||
unique_fields=["full_path", "emoji_hash"],
|
|
||||||
),
|
|
||||||
# 聊天流迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="chat_streams",
|
|
||||||
target_model=ChatStreams,
|
|
||||||
field_mapping={
|
|
||||||
"stream_id": "stream_id",
|
|
||||||
"create_time": "create_time",
|
|
||||||
"group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。
|
|
||||||
"group_info.group_id": "group_id", # 同上
|
|
||||||
"group_info.group_name": "group_name", # 同上
|
|
||||||
"last_active_time": "last_active_time",
|
|
||||||
"platform": "platform",
|
|
||||||
"user_info.platform": "user_platform",
|
|
||||||
"user_info.user_id": "user_id",
|
|
||||||
"user_info.user_nickname": "user_nickname",
|
|
||||||
"user_info.user_cardname": "user_cardname",
|
|
||||||
},
|
|
||||||
enable_validation=False, # 禁用数据验证
|
|
||||||
unique_fields=["stream_id"],
|
|
||||||
),
|
|
||||||
# 消息迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="messages",
|
|
||||||
target_model=Messages,
|
|
||||||
field_mapping={
|
|
||||||
"message_id": "message_id",
|
|
||||||
"time": "time",
|
|
||||||
"chat_id": "chat_id",
|
|
||||||
"chat_info.stream_id": "chat_info_stream_id",
|
|
||||||
"chat_info.platform": "chat_info_platform",
|
|
||||||
"chat_info.user_info.platform": "chat_info_user_platform",
|
|
||||||
"chat_info.user_info.user_id": "chat_info_user_id",
|
|
||||||
"chat_info.user_info.user_nickname": "chat_info_user_nickname",
|
|
||||||
"chat_info.user_info.user_cardname": "chat_info_user_cardname",
|
|
||||||
"chat_info.group_info.platform": "chat_info_group_platform",
|
|
||||||
"chat_info.group_info.group_id": "chat_info_group_id",
|
|
||||||
"chat_info.group_info.group_name": "chat_info_group_name",
|
|
||||||
"chat_info.create_time": "chat_info_create_time",
|
|
||||||
"chat_info.last_active_time": "chat_info_last_active_time",
|
|
||||||
"user_info.platform": "user_platform",
|
|
||||||
"user_info.user_id": "user_id",
|
|
||||||
"user_info.user_nickname": "user_nickname",
|
|
||||||
"user_info.user_cardname": "user_cardname",
|
|
||||||
"processed_plain_text": "processed_plain_text",
|
|
||||||
"memorized_times": "memorized_times",
|
|
||||||
},
|
|
||||||
enable_validation=False, # 禁用数据验证
|
|
||||||
unique_fields=["message_id"],
|
|
||||||
),
|
|
||||||
# 图片迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="images",
|
|
||||||
target_model=Images,
|
|
||||||
field_mapping={
|
|
||||||
"hash": "emoji_hash",
|
|
||||||
"description": "description",
|
|
||||||
"path": "path",
|
|
||||||
"timestamp": "timestamp",
|
|
||||||
"type": "type",
|
|
||||||
},
|
|
||||||
unique_fields=["path"],
|
|
||||||
),
|
|
||||||
# 图片描述迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="image_descriptions",
|
|
||||||
target_model=ImageDescriptions,
|
|
||||||
field_mapping={
|
|
||||||
"type": "type",
|
|
||||||
"hash": "image_description_hash",
|
|
||||||
"description": "description",
|
|
||||||
"timestamp": "timestamp",
|
|
||||||
},
|
|
||||||
unique_fields=["image_description_hash", "type"],
|
|
||||||
),
|
|
||||||
# 个人信息迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="person_info",
|
|
||||||
target_model=PersonInfo,
|
|
||||||
field_mapping={
|
|
||||||
"person_id": "person_id",
|
|
||||||
"person_name": "person_name",
|
|
||||||
"name_reason": "name_reason",
|
|
||||||
"platform": "platform",
|
|
||||||
"user_id": "user_id",
|
|
||||||
"nickname": "nickname",
|
|
||||||
"relationship_value": "relationship_value",
|
|
||||||
"konw_time": "know_time",
|
|
||||||
},
|
|
||||||
unique_fields=["person_id"],
|
|
||||||
),
|
|
||||||
# 知识库迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="knowledges",
|
|
||||||
target_model=Knowledges,
|
|
||||||
field_mapping={"content": "content", "embedding": "embedding"},
|
|
||||||
unique_fields=["content"], # 假设内容唯一
|
|
||||||
),
|
|
||||||
# 思考日志迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="thinking_log",
|
|
||||||
target_model=ThinkingLog,
|
|
||||||
field_mapping={
|
|
||||||
"chat_id": "chat_id",
|
|
||||||
"trigger_text": "trigger_text",
|
|
||||||
"response_text": "response_text",
|
|
||||||
"trigger_info": "trigger_info_json",
|
|
||||||
"response_info": "response_info_json",
|
|
||||||
"timing_results": "timing_results_json",
|
|
||||||
"chat_history": "chat_history_json",
|
|
||||||
"chat_history_in_thinking": "chat_history_in_thinking_json",
|
|
||||||
"chat_history_after_response": "chat_history_after_response_json",
|
|
||||||
"heartflow_data": "heartflow_data_json",
|
|
||||||
"reasoning_data": "reasoning_data_json",
|
|
||||||
},
|
|
||||||
unique_fields=["chat_id", "trigger_text"],
|
|
||||||
),
|
|
||||||
# 图节点迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="graph_data.nodes",
|
|
||||||
target_model=GraphNodes,
|
|
||||||
field_mapping={
|
|
||||||
"concept": "concept",
|
|
||||||
"memory_items": "memory_items",
|
|
||||||
"hash": "hash",
|
|
||||||
"created_time": "created_time",
|
|
||||||
"last_modified": "last_modified",
|
|
||||||
},
|
|
||||||
unique_fields=["concept"],
|
|
||||||
),
|
|
||||||
# 图边迁移配置
|
|
||||||
MigrationConfig(
|
|
||||||
mongo_collection="graph_data.edges",
|
|
||||||
target_model=GraphEdges,
|
|
||||||
field_mapping={
|
|
||||||
"source": "source",
|
|
||||||
"target": "target",
|
|
||||||
"strength": "strength",
|
|
||||||
"hash": "hash",
|
|
||||||
"created_time": "created_time",
|
|
||||||
"last_modified": "last_modified",
|
|
||||||
},
|
|
||||||
unique_fields=["source", "target"], # 组合唯一性
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
def _initialize_validation_rules(self) -> dict[str, Any]:
|
|
||||||
"""数据验证已禁用 - 返回空字典"""
|
|
||||||
return {}
|
|
||||||
|
|
||||||
def connect_mongodb(self) -> bool:
|
|
||||||
"""连接到MongoDB"""
|
|
||||||
try:
|
|
||||||
self.mongo_client = MongoClient(
|
|
||||||
self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10
|
|
||||||
)
|
|
||||||
|
|
||||||
# 测试连接
|
|
||||||
self.mongo_client.admin.command("ping")
|
|
||||||
self.mongo_db = self.mongo_client[self.database_name]
|
|
||||||
|
|
||||||
logger.info(f"成功连接到MongoDB: {self.database_name}")
|
|
||||||
return True
|
|
||||||
|
|
||||||
except ConnectionFailure as e:
|
|
||||||
logger.error(f"MongoDB连接失败: {e}")
|
|
||||||
return False
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"MongoDB连接异常: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def disconnect_mongodb(self):
|
|
||||||
"""断开MongoDB连接"""
|
|
||||||
if self.mongo_client:
|
|
||||||
self.mongo_client.close()
|
|
||||||
logger.info("MongoDB连接已关闭")
|
|
||||||
|
|
||||||
def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any:
|
|
||||||
"""获取嵌套字段的值"""
|
|
||||||
if "." not in field_path:
|
|
||||||
return document.get(field_path)
|
|
||||||
|
|
||||||
parts = field_path.split(".")
|
|
||||||
value = document
|
|
||||||
|
|
||||||
for part in parts:
|
|
||||||
if isinstance(value, dict):
|
|
||||||
value = value.get(part)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if value is None:
|
|
||||||
break
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def _convert_field_value(self, value: Any, target_field: Field) -> Any:
|
|
||||||
"""根据目标字段类型转换值"""
|
|
||||||
if value is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
field_type = target_field.__class__.__name__
|
|
||||||
|
|
||||||
try:
|
|
||||||
if target_field.name == "record_time" and field_type == "DateTimeField":
|
|
||||||
return datetime.now()
|
|
||||||
|
|
||||||
if field_type in ["CharField", "TextField"]:
|
|
||||||
if isinstance(value, list | dict):
|
|
||||||
return orjson.dumps(value, ensure_ascii=False)
|
|
||||||
return str(value) if value is not None else ""
|
|
||||||
|
|
||||||
elif field_type == "IntegerField":
|
|
||||||
if isinstance(value, str):
|
|
||||||
# 处理字符串数字
|
|
||||||
clean_value = value.strip()
|
|
||||||
if clean_value.replace(".", "").replace("-", "").isdigit():
|
|
||||||
return int(float(clean_value))
|
|
||||||
return 0
|
|
||||||
return int(value) if value is not None else 0
|
|
||||||
|
|
||||||
elif field_type in ["FloatField", "DoubleField"]:
|
|
||||||
return float(value) if value is not None else 0.0
|
|
||||||
|
|
||||||
elif field_type == "BooleanField":
|
|
||||||
if isinstance(value, str):
|
|
||||||
return value.lower() in ("true", "1", "yes", "on")
|
|
||||||
return bool(value)
|
|
||||||
|
|
||||||
elif field_type == "DateTimeField":
|
|
||||||
if isinstance(value, int | float):
|
|
||||||
return datetime.fromtimestamp(value)
|
|
||||||
elif isinstance(value, str):
|
|
||||||
try:
|
|
||||||
# 尝试解析ISO格式日期
|
|
||||||
return datetime.fromisoformat(value.replace("Z", "+00:00"))
|
|
||||||
except ValueError:
|
|
||||||
try:
|
|
||||||
# 尝试解析时间戳字符串
|
|
||||||
return datetime.fromtimestamp(float(value))
|
|
||||||
except ValueError:
|
|
||||||
return datetime.now()
|
|
||||||
return datetime.now()
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
except (ValueError, TypeError) as e:
|
|
||||||
logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}")
|
|
||||||
return self._get_default_value_for_field(target_field)
|
|
||||||
|
|
||||||
def _get_default_value_for_field(self, field: Field) -> Any:
|
|
||||||
"""获取字段的默认值"""
|
|
||||||
field_type = field.__class__.__name__
|
|
||||||
|
|
||||||
if hasattr(field, "default") and field.default is not None:
|
|
||||||
return field.default
|
|
||||||
|
|
||||||
if field.null:
|
|
||||||
return None
|
|
||||||
|
|
||||||
# 根据字段类型返回默认值
|
|
||||||
if field_type in ["CharField", "TextField"]:
|
|
||||||
return ""
|
|
||||||
elif field_type == "IntegerField":
|
|
||||||
return 0
|
|
||||||
elif field_type in ["FloatField", "DoubleField"]:
|
|
||||||
return 0.0
|
|
||||||
elif field_type == "BooleanField":
|
|
||||||
return False
|
|
||||||
elif field_type == "DateTimeField":
|
|
||||||
return datetime.now()
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool:
|
|
||||||
"""数据验证已禁用 - 始终返回True"""
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any):
|
|
||||||
"""保存迁移断点"""
|
|
||||||
checkpoint = MigrationCheckpoint(
|
|
||||||
collection_name=collection_name,
|
|
||||||
processed_count=processed_count,
|
|
||||||
last_processed_id=last_id,
|
|
||||||
timestamp=datetime.now(),
|
|
||||||
)
|
|
||||||
|
|
||||||
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
|
|
||||||
try:
|
|
||||||
with open(checkpoint_file, "wb") as f:
|
|
||||||
pickle.dump(checkpoint, f)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"保存断点失败: {e}")
|
|
||||||
|
|
||||||
def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None:
|
|
||||||
"""加载迁移断点"""
|
|
||||||
checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl"
|
|
||||||
if not checkpoint_file.exists():
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(checkpoint_file, "rb") as f:
|
|
||||||
return pickle.load(f)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"加载断点失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int:
|
|
||||||
"""批量插入数据"""
|
|
||||||
if not data_list:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
success_count = 0
|
|
||||||
try:
|
|
||||||
with db.atomic():
|
|
||||||
# 分批插入,避免SQL语句过长
|
|
||||||
batch_size = 100
|
|
||||||
for i in range(0, len(data_list), batch_size):
|
|
||||||
batch = data_list[i : i + batch_size]
|
|
||||||
model.insert_many(batch).execute()
|
|
||||||
success_count += len(batch)
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"批量插入失败: {e}")
|
|
||||||
# 如果批量插入失败,尝试逐个插入
|
|
||||||
for data in data_list:
|
|
||||||
try:
|
|
||||||
model.create(**data)
|
|
||||||
success_count += 1
|
|
||||||
except Exception:
|
|
||||||
pass # 忽略单个插入失败
|
|
||||||
|
|
||||||
return success_count
|
|
||||||
|
|
||||||
def _check_duplicate_by_unique_fields(
|
|
||||||
self, model: type[Model], data: dict[str, Any], unique_fields: list[str]
|
|
||||||
) -> bool:
|
|
||||||
"""根据唯一字段检查重复"""
|
|
||||||
if not unique_fields:
|
|
||||||
return False
|
|
||||||
|
|
||||||
try:
|
|
||||||
query = model.select()
|
|
||||||
for field_name in unique_fields:
|
|
||||||
if field_name in data and data[field_name] is not None:
|
|
||||||
field_obj = getattr(model, field_name)
|
|
||||||
query = query.where(field_obj == data[field_name])
|
|
||||||
|
|
||||||
return query.exists()
|
|
||||||
except Exception as e:
|
|
||||||
logger.debug(f"重复检查失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None:
|
|
||||||
"""使用ORM创建模型实例"""
|
|
||||||
try:
|
|
||||||
# 过滤掉不存在的字段
|
|
||||||
valid_data = {}
|
|
||||||
for field_name, value in data.items():
|
|
||||||
if hasattr(model, field_name):
|
|
||||||
valid_data[field_name] = value
|
|
||||||
else:
|
|
||||||
logger.debug(f"跳过未知字段: {field_name}")
|
|
||||||
|
|
||||||
# 创建实例
|
|
||||||
instance = model.create(**valid_data)
|
|
||||||
return instance
|
|
||||||
|
|
||||||
except IntegrityError as e:
|
|
||||||
# 处理唯一约束冲突等完整性错误
|
|
||||||
logger.debug(f"完整性约束冲突: {e}")
|
|
||||||
return None
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"创建模型实例失败: {e}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def migrate_collection(self, config: MigrationConfig) -> MigrationStats:
|
|
||||||
"""迁移单个集合 - 使用优化的批量插入和进度条"""
|
|
||||||
stats = MigrationStats()
|
|
||||||
stats.start_time = datetime.now()
|
|
||||||
|
|
||||||
# 检查是否有断点
|
|
||||||
checkpoint = self._load_checkpoint(config.mongo_collection)
|
|
||||||
start_from_id = checkpoint.last_processed_id if checkpoint else None
|
|
||||||
if checkpoint:
|
|
||||||
stats.processed_count = checkpoint.processed_count
|
|
||||||
logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录")
|
|
||||||
|
|
||||||
logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取MongoDB集合
|
|
||||||
mongo_collection = self.mongo_db[config.mongo_collection]
|
|
||||||
|
|
||||||
# 构建查询条件(用于断点恢复)
|
|
||||||
query = {}
|
|
||||||
if start_from_id:
|
|
||||||
query = {"_id": {"$gt": start_from_id}}
|
|
||||||
|
|
||||||
stats.total_documents = mongo_collection.count_documents(query)
|
|
||||||
|
|
||||||
if stats.total_documents == 0:
|
|
||||||
logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移")
|
|
||||||
return stats
|
|
||||||
|
|
||||||
logger.info(f"待迁移文档数量: {stats.total_documents}")
|
|
||||||
|
|
||||||
# 创建Rich进度条
|
|
||||||
with Progress(
|
|
||||||
SpinnerColumn(),
|
|
||||||
TextColumn("[progress.description]{task.description}"),
|
|
||||||
BarColumn(),
|
|
||||||
TaskProgressColumn(),
|
|
||||||
TimeElapsedColumn(),
|
|
||||||
TimeRemainingColumn(),
|
|
||||||
console=self.console,
|
|
||||||
refresh_per_second=10,
|
|
||||||
) as progress:
|
|
||||||
task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents)
|
|
||||||
# 批量处理数据
|
|
||||||
batch_data = []
|
|
||||||
batch_count = 0
|
|
||||||
last_processed_id = None
|
|
||||||
|
|
||||||
for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size):
|
|
||||||
try:
|
|
||||||
doc_id = mongo_doc.get("_id", "unknown")
|
|
||||||
last_processed_id = doc_id
|
|
||||||
|
|
||||||
# 构建目标数据
|
|
||||||
target_data = {}
|
|
||||||
for mongo_field, sqlite_field in config.field_mapping.items():
|
|
||||||
value = self._get_nested_value(mongo_doc, mongo_field)
|
|
||||||
|
|
||||||
# 获取目标字段对象并转换类型
|
|
||||||
if hasattr(config.target_model, sqlite_field):
|
|
||||||
field_obj = getattr(config.target_model, sqlite_field)
|
|
||||||
converted_value = self._convert_field_value(value, field_obj)
|
|
||||||
target_data[sqlite_field] = converted_value
|
|
||||||
|
|
||||||
# 数据验证已禁用
|
|
||||||
# if config.enable_validation:
|
|
||||||
# if not self._validate_data(config.mongo_collection, target_data, doc_id, stats):
|
|
||||||
# stats.skipped_count += 1
|
|
||||||
# continue
|
|
||||||
|
|
||||||
# 重复检查
|
|
||||||
if config.skip_duplicates and self._check_duplicate_by_unique_fields(
|
|
||||||
config.target_model, target_data, config.unique_fields
|
|
||||||
):
|
|
||||||
stats.duplicate_count += 1
|
|
||||||
stats.skipped_count += 1
|
|
||||||
logger.debug(f"跳过重复记录: {doc_id}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 添加到批量数据
|
|
||||||
batch_data.append(target_data)
|
|
||||||
stats.processed_count += 1
|
|
||||||
|
|
||||||
# 执行批量插入
|
|
||||||
if len(batch_data) >= config.batch_size:
|
|
||||||
success_count = self._batch_insert(config.target_model, batch_data)
|
|
||||||
stats.success_count += success_count
|
|
||||||
stats.batch_insert_count += 1
|
|
||||||
|
|
||||||
# 保存断点
|
|
||||||
self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id)
|
|
||||||
|
|
||||||
batch_data.clear()
|
|
||||||
batch_count += 1
|
|
||||||
|
|
||||||
# 更新进度条
|
|
||||||
progress.update(task, advance=config.batch_size)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
doc_id = mongo_doc.get("_id", "unknown")
|
|
||||||
stats.add_error(doc_id, f"处理文档异常: {e}", mongo_doc)
|
|
||||||
logger.error(f"处理文档失败 (ID: {doc_id}): {e}")
|
|
||||||
|
|
||||||
# 处理剩余的批量数据
|
|
||||||
if batch_data:
|
|
||||||
success_count = self._batch_insert(config.target_model, batch_data)
|
|
||||||
stats.success_count += success_count
|
|
||||||
stats.batch_insert_count += 1
|
|
||||||
progress.update(task, advance=len(batch_data))
|
|
||||||
|
|
||||||
# 完成进度条
|
|
||||||
progress.update(task, completed=stats.total_documents)
|
|
||||||
|
|
||||||
stats.end_time = datetime.now()
|
|
||||||
duration = stats.end_time - stats.start_time
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n"
|
|
||||||
f"总计: {stats.total_documents}, 成功: {stats.success_count}, "
|
|
||||||
f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n"
|
|
||||||
f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 清理断点文件
|
|
||||||
checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl"
|
|
||||||
if checkpoint_file.exists():
|
|
||||||
checkpoint_file.unlink()
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}")
|
|
||||||
stats.add_error("collection_error", str(e))
|
|
||||||
|
|
||||||
return stats
|
|
||||||
|
|
||||||
def migrate_all(self) -> dict[str, MigrationStats]:
|
|
||||||
"""执行所有迁移任务"""
|
|
||||||
logger.info("开始执行数据库迁移...")
|
|
||||||
|
|
||||||
if not self.connect_mongodb():
|
|
||||||
logger.error("无法连接到MongoDB,迁移终止")
|
|
||||||
return {}
|
|
||||||
|
|
||||||
all_stats = {}
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 创建总体进度表格
|
|
||||||
total_collections = len(self.migration_configs)
|
|
||||||
self.console.print(
|
|
||||||
Panel(
|
|
||||||
f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n"
|
|
||||||
f"[yellow]总集合数: {total_collections}[/yellow]",
|
|
||||||
title="迁移开始",
|
|
||||||
expand=False,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
for idx, config in enumerate(self.migration_configs, 1):
|
|
||||||
self.console.print(
|
|
||||||
f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]"
|
|
||||||
)
|
|
||||||
stats = self.migrate_collection(config)
|
|
||||||
all_stats[config.mongo_collection] = stats
|
|
||||||
|
|
||||||
# 显示单个集合的快速统计
|
|
||||||
if stats.processed_count > 0:
|
|
||||||
success_rate = stats.success_count / stats.processed_count * 100
|
|
||||||
if success_rate >= 95:
|
|
||||||
status_emoji = "✅"
|
|
||||||
status_color = "bright_green"
|
|
||||||
elif success_rate >= 80:
|
|
||||||
status_emoji = "⚠️"
|
|
||||||
status_color = "yellow"
|
|
||||||
else:
|
|
||||||
status_emoji = "❌"
|
|
||||||
status_color = "red"
|
|
||||||
|
|
||||||
self.console.print(
|
|
||||||
f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} "
|
|
||||||
f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 错误率检查
|
|
||||||
if stats.processed_count > 0:
|
|
||||||
error_rate = stats.error_count / stats.processed_count
|
|
||||||
if error_rate > 0.1: # 错误率超过10%
|
|
||||||
self.console.print(
|
|
||||||
f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} "
|
|
||||||
f"({stats.error_count}/{stats.processed_count})[/red]"
|
|
||||||
)
|
|
||||||
|
|
||||||
finally:
|
|
||||||
self.disconnect_mongodb()
|
|
||||||
|
|
||||||
self._print_migration_summary(all_stats)
|
|
||||||
return all_stats
|
|
||||||
|
|
||||||
def _print_migration_summary(self, all_stats: dict[str, MigrationStats]):
|
|
||||||
"""使用Rich打印美观的迁移汇总信息"""
|
|
||||||
# 计算总体统计
|
|
||||||
total_processed = sum(stats.processed_count for stats in all_stats.values())
|
|
||||||
total_success = sum(stats.success_count for stats in all_stats.values())
|
|
||||||
total_errors = sum(stats.error_count for stats in all_stats.values())
|
|
||||||
total_skipped = sum(stats.skipped_count for stats in all_stats.values())
|
|
||||||
total_duplicates = sum(stats.duplicate_count for stats in all_stats.values())
|
|
||||||
total_validation_errors = sum(stats.validation_errors for stats in all_stats.values())
|
|
||||||
total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values())
|
|
||||||
|
|
||||||
# 计算总耗时
|
|
||||||
total_duration_seconds = 0
|
|
||||||
for stats in all_stats.values():
|
|
||||||
if stats.start_time and stats.end_time:
|
|
||||||
duration = stats.end_time - stats.start_time
|
|
||||||
total_duration_seconds += duration.total_seconds()
|
|
||||||
|
|
||||||
# 创建详细统计表格
|
|
||||||
table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta")
|
|
||||||
table.add_column("集合名称", style="cyan", width=20)
|
|
||||||
table.add_column("文档总数", justify="right", style="blue")
|
|
||||||
table.add_column("处理数量", justify="right", style="green")
|
|
||||||
table.add_column("成功数量", justify="right", style="green")
|
|
||||||
table.add_column("错误数量", justify="right", style="red")
|
|
||||||
table.add_column("跳过数量", justify="right", style="yellow")
|
|
||||||
table.add_column("重复数量", justify="right", style="bright_yellow")
|
|
||||||
table.add_column("验证错误", justify="right", style="red")
|
|
||||||
table.add_column("批次数", justify="right", style="purple")
|
|
||||||
table.add_column("成功率", justify="right", style="bright_green")
|
|
||||||
table.add_column("耗时(秒)", justify="right", style="blue")
|
|
||||||
|
|
||||||
for collection_name, stats in all_stats.items():
|
|
||||||
success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0
|
|
||||||
duration = 0
|
|
||||||
if stats.start_time and stats.end_time:
|
|
||||||
duration = (stats.end_time - stats.start_time).total_seconds()
|
|
||||||
|
|
||||||
# 根据成功率设置颜色
|
|
||||||
if success_rate >= 95:
|
|
||||||
success_rate_style = "[bright_green]"
|
|
||||||
elif success_rate >= 80:
|
|
||||||
success_rate_style = "[yellow]"
|
|
||||||
else:
|
|
||||||
success_rate_style = "[red]"
|
|
||||||
|
|
||||||
table.add_row(
|
|
||||||
collection_name,
|
|
||||||
str(stats.total_documents),
|
|
||||||
str(stats.processed_count),
|
|
||||||
str(stats.success_count),
|
|
||||||
f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0",
|
|
||||||
f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0",
|
|
||||||
f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0",
|
|
||||||
f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0",
|
|
||||||
str(stats.batch_insert_count),
|
|
||||||
f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}",
|
|
||||||
f"{duration:.2f}",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 添加总计行
|
|
||||||
total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0
|
|
||||||
if total_success_rate >= 95:
|
|
||||||
total_rate_style = "[bright_green]"
|
|
||||||
elif total_success_rate >= 80:
|
|
||||||
total_rate_style = "[yellow]"
|
|
||||||
else:
|
|
||||||
total_rate_style = "[red]"
|
|
||||||
|
|
||||||
table.add_section()
|
|
||||||
table.add_row(
|
|
||||||
"[bold]总计[/bold]",
|
|
||||||
f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]",
|
|
||||||
f"[bold]{total_processed}[/bold]",
|
|
||||||
f"[bold]{total_success}[/bold]",
|
|
||||||
f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]",
|
|
||||||
f"[bold yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]",
|
|
||||||
f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]"
|
|
||||||
if total_duplicates > 0
|
|
||||||
else "[bold]0[/bold]",
|
|
||||||
f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]",
|
|
||||||
f"[bold]{total_batch_inserts}[/bold]",
|
|
||||||
f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]",
|
|
||||||
f"[bold]{total_duration_seconds:.2f}[/bold]",
|
|
||||||
)
|
|
||||||
|
|
||||||
self.console.print(table)
|
|
||||||
|
|
||||||
# 创建状态面板
|
|
||||||
status_items = []
|
|
||||||
if total_errors > 0:
|
|
||||||
status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]")
|
|
||||||
|
|
||||||
if total_validation_errors > 0:
|
|
||||||
status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]")
|
|
||||||
|
|
||||||
if total_duplicates > 0:
|
|
||||||
status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]")
|
|
||||||
|
|
||||||
if total_success_rate >= 95:
|
|
||||||
status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]")
|
|
||||||
elif total_success_rate >= 80:
|
|
||||||
status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]")
|
|
||||||
else:
|
|
||||||
status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]")
|
|
||||||
|
|
||||||
if status_items:
|
|
||||||
status_panel = Panel(
|
|
||||||
"\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow"
|
|
||||||
)
|
|
||||||
self.console.print(status_panel)
|
|
||||||
|
|
||||||
# 性能统计面板
|
|
||||||
avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0
|
|
||||||
performance_info = (
|
|
||||||
f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n"
|
|
||||||
f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n"
|
|
||||||
f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作"
|
|
||||||
)
|
|
||||||
|
|
||||||
performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green")
|
|
||||||
self.console.print(performance_panel)
|
|
||||||
|
|
||||||
def add_migration_config(self, config: MigrationConfig):
|
|
||||||
"""添加新的迁移配置"""
|
|
||||||
self.migration_configs.append(config)
|
|
||||||
|
|
||||||
def migrate_single_collection(self, collection_name: str) -> MigrationStats | None:
|
|
||||||
"""迁移单个指定的集合"""
|
|
||||||
config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None)
|
|
||||||
if not config:
|
|
||||||
logger.error(f"未找到集合 {collection_name} 的迁移配置")
|
|
||||||
return None
|
|
||||||
|
|
||||||
if not self.connect_mongodb():
|
|
||||||
logger.error("无法连接到MongoDB")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
stats = self.migrate_collection(config)
|
|
||||||
self._print_migration_summary({collection_name: stats})
|
|
||||||
return stats
|
|
||||||
finally:
|
|
||||||
self.disconnect_mongodb()
|
|
||||||
|
|
||||||
def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str):
|
|
||||||
"""导出错误报告"""
|
|
||||||
error_report = {
|
|
||||||
"timestamp": datetime.now().isoformat(),
|
|
||||||
"summary": {
|
|
||||||
collection: {
|
|
||||||
"total": stats.total_documents,
|
|
||||||
"processed": stats.processed_count,
|
|
||||||
"success": stats.success_count,
|
|
||||||
"errors": stats.error_count,
|
|
||||||
"skipped": stats.skipped_count,
|
|
||||||
"duplicates": stats.duplicate_count,
|
|
||||||
}
|
|
||||||
for collection, stats in all_stats.items()
|
|
||||||
},
|
|
||||||
"errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors},
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(filepath, "w", encoding="utf-8") as f:
|
|
||||||
orjson.dumps(error_report, f, ensure_ascii=False, indent=2)
|
|
||||||
logger.info(f"错误报告已导出到: {filepath}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"导出错误报告失败: {e}")
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
"""主程序入口"""
|
|
||||||
migrator = MongoToSQLiteMigrator()
|
|
||||||
|
|
||||||
# 执行迁移
|
|
||||||
migration_results = migrator.migrate_all()
|
|
||||||
|
|
||||||
# 导出错误报告(如果有错误)
|
|
||||||
if any(stats.error_count > 0 for stats in migration_results.values()):
|
|
||||||
error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
|
|
||||||
migrator.export_error_report(migration_results, error_report_path)
|
|
||||||
|
|
||||||
logger.info("数据迁移完成!")
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
556
scripts/run.sh
556
scripts/run.sh
@@ -1,556 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
|
|
||||||
# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987
|
|
||||||
# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9
|
|
||||||
# 请小心使用任何一键脚本!
|
|
||||||
|
|
||||||
INSTALLER_VERSION="0.0.5-refactor"
|
|
||||||
LANG=C.UTF-8
|
|
||||||
|
|
||||||
# 如无法访问GitHub请修改此处镜像地址
|
|
||||||
GITHUB_REPO="https://ghfast.top/https://github.com"
|
|
||||||
|
|
||||||
# 颜色输出
|
|
||||||
GREEN="\e[32m"
|
|
||||||
RED="\e[31m"
|
|
||||||
RESET="\e[0m"
|
|
||||||
|
|
||||||
# 需要的基本软件包
|
|
||||||
|
|
||||||
declare -A REQUIRED_PACKAGES=(
|
|
||||||
["common"]="git sudo python3 curl gnupg"
|
|
||||||
["debian"]="python3-venv python3-pip build-essential"
|
|
||||||
["ubuntu"]="python3-venv python3-pip build-essential"
|
|
||||||
["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make"
|
|
||||||
["arch"]="python-virtualenv python-pip base-devel"
|
|
||||||
)
|
|
||||||
|
|
||||||
# 默认项目目录
|
|
||||||
DEFAULT_INSTALL_DIR="/opt/maicore"
|
|
||||||
|
|
||||||
# 服务名称
|
|
||||||
SERVICE_NAME="maicore"
|
|
||||||
SERVICE_NAME_WEB="maicore-web"
|
|
||||||
SERVICE_NAME_NBADAPTER="maibot-napcat-adapter"
|
|
||||||
|
|
||||||
IS_INSTALL_NAPCAT=false
|
|
||||||
IS_INSTALL_DEPENDENCIES=false
|
|
||||||
|
|
||||||
# 检查是否已安装
|
|
||||||
check_installed() {
|
|
||||||
[[ -f /etc/systemd/system/${SERVICE_NAME}.service ]]
|
|
||||||
}
|
|
||||||
|
|
||||||
# 加载安装信息
|
|
||||||
load_install_info() {
|
|
||||||
if [[ -f /etc/maicore_install.conf ]]; then
|
|
||||||
source /etc/maicore_install.conf
|
|
||||||
else
|
|
||||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
|
||||||
BRANCH="refactor"
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# 显示管理菜单
|
|
||||||
show_menu() {
|
|
||||||
while true; do
|
|
||||||
choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \
|
|
||||||
"1" "启动MaiCore" \
|
|
||||||
"2" "停止MaiCore" \
|
|
||||||
"3" "重启MaiCore" \
|
|
||||||
"4" "启动NapCat Adapter" \
|
|
||||||
"5" "停止NapCat Adapter" \
|
|
||||||
"6" "重启NapCat Adapter" \
|
|
||||||
"7" "拉取最新MaiCore仓库" \
|
|
||||||
"8" "切换分支" \
|
|
||||||
"9" "退出" 3>&1 1>&2 2>&3)
|
|
||||||
|
|
||||||
[[ $? -ne 0 ]] && exit 0
|
|
||||||
|
|
||||||
case "$choice" in
|
|
||||||
1)
|
|
||||||
systemctl start ${SERVICE_NAME}
|
|
||||||
whiptail --msgbox "✅MaiCore已启动" 10 60
|
|
||||||
;;
|
|
||||||
2)
|
|
||||||
systemctl stop ${SERVICE_NAME}
|
|
||||||
whiptail --msgbox "🛑MaiCore已停止" 10 60
|
|
||||||
;;
|
|
||||||
3)
|
|
||||||
systemctl restart ${SERVICE_NAME}
|
|
||||||
whiptail --msgbox "🔄MaiCore已重启" 10 60
|
|
||||||
;;
|
|
||||||
4)
|
|
||||||
systemctl start ${SERVICE_NAME_NBADAPTER}
|
|
||||||
whiptail --msgbox "✅NapCat Adapter已启动" 10 60
|
|
||||||
;;
|
|
||||||
5)
|
|
||||||
systemctl stop ${SERVICE_NAME_NBADAPTER}
|
|
||||||
whiptail --msgbox "🛑NapCat Adapter已停止" 10 60
|
|
||||||
;;
|
|
||||||
6)
|
|
||||||
systemctl restart ${SERVICE_NAME_NBADAPTER}
|
|
||||||
whiptail --msgbox "🔄NapCat Adapter已重启" 10 60
|
|
||||||
;;
|
|
||||||
7)
|
|
||||||
update_dependencies
|
|
||||||
;;
|
|
||||||
8)
|
|
||||||
switch_branch
|
|
||||||
;;
|
|
||||||
9)
|
|
||||||
exit 0
|
|
||||||
;;
|
|
||||||
*)
|
|
||||||
whiptail --msgbox "无效选项!" 10 60
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
}
|
|
||||||
|
|
||||||
# 更新依赖
|
|
||||||
update_dependencies() {
|
|
||||||
whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60
|
|
||||||
systemctl stop ${SERVICE_NAME}
|
|
||||||
cd "${INSTALL_DIR}/MaiBot" || {
|
|
||||||
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
if ! git pull origin "${BRANCH}"; then
|
|
||||||
whiptail --msgbox "🚫 代码更新失败!" 10 60
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
source "${INSTALL_DIR}/venv/bin/activate"
|
|
||||||
if ! pip install -r requirements.txt; then
|
|
||||||
whiptail --msgbox "🚫 依赖安装失败!" 10 60
|
|
||||||
deactivate
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
deactivate
|
|
||||||
whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60
|
|
||||||
}
|
|
||||||
|
|
||||||
# 切换分支
|
|
||||||
switch_branch() {
|
|
||||||
new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3)
|
|
||||||
[[ -z "$new_branch" ]] && {
|
|
||||||
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
cd "${INSTALL_DIR}/MaiBot" || {
|
|
||||||
whiptail --msgbox "🚫 无法进入安装目录!" 10 60
|
|
||||||
return 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then
|
|
||||||
whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if ! git checkout "${new_branch}"; then
|
|
||||||
whiptail --msgbox "🚫 分支切换失败!" 10 60
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if ! git pull origin "${new_branch}"; then
|
|
||||||
whiptail --msgbox "🚫 代码拉取失败!" 10 60
|
|
||||||
return 1
|
|
||||||
fi
|
|
||||||
systemctl stop ${SERVICE_NAME}
|
|
||||||
source "${INSTALL_DIR}/venv/bin/activate"
|
|
||||||
pip install -r requirements.txt
|
|
||||||
deactivate
|
|
||||||
|
|
||||||
sed -i "s/^BRANCH=.*/BRANCH=${new_branch}/" /etc/maicore_install.conf
|
|
||||||
BRANCH="${new_branch}"
|
|
||||||
check_eula
|
|
||||||
whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} !" 10 60
|
|
||||||
}
|
|
||||||
|
|
||||||
check_eula() {
|
|
||||||
# 首先计算当前EULA的MD5值
|
|
||||||
current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}')
|
|
||||||
|
|
||||||
# 首先计算当前隐私条款文件的哈希值
|
|
||||||
current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}')
|
|
||||||
|
|
||||||
# 如果当前的md5值为空,则直接返回
|
|
||||||
if [[ -z $current_md5 || -z $current_md5_privacy ]]; then
|
|
||||||
whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 检查eula.confirmed文件是否存在
|
|
||||||
if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then
|
|
||||||
# 如果存在则检查其中包含的md5与current_md5是否一致
|
|
||||||
confirmed_md5=$(cat ${INSTALL_DIR}/MaiBot/eula.confirmed)
|
|
||||||
else
|
|
||||||
confirmed_md5=""
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 检查privacy.confirmed文件是否存在
|
|
||||||
if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then
|
|
||||||
# 如果存在则检查其中包含的md5与current_md5是否一致
|
|
||||||
confirmed_md5_privacy=$(cat ${INSTALL_DIR}/MaiBot/privacy.confirmed)
|
|
||||||
else
|
|
||||||
confirmed_md5_privacy=""
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 如果EULA或隐私条款有更新,提示用户重新确认
|
|
||||||
if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then
|
|
||||||
whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议? \n\n " 12 70
|
|
||||||
if [[ $? -eq 0 ]]; then
|
|
||||||
echo -n $current_md5 > ${INSTALL_DIR}/MaiBot/eula.confirmed
|
|
||||||
echo -n $current_md5_privacy > ${INSTALL_DIR}/MaiBot/privacy.confirmed
|
|
||||||
else
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
# ----------- 主安装流程 -----------
|
|
||||||
run_installation() {
|
|
||||||
# 1/6: 检测是否安装 whiptail
|
|
||||||
if ! command -v whiptail &>/dev/null; then
|
|
||||||
echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}"
|
|
||||||
|
|
||||||
if command -v apt-get &>/dev/null; then
|
|
||||||
apt-get update && apt-get install -y whiptail
|
|
||||||
elif command -v pacman &>/dev/null; then
|
|
||||||
pacman -Syu --noconfirm whiptail
|
|
||||||
elif command -v yum &>/dev/null; then
|
|
||||||
yum install -y whiptail
|
|
||||||
else
|
|
||||||
echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}"
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
whiptail --title "ℹ️ 提示" --msgbox "如果您没有特殊需求,请优先使用docker方式部署。" 10 60
|
|
||||||
|
|
||||||
# 协议确认
|
|
||||||
if ! (whiptail --title "ℹ️ [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议?" 12 70); then
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
# 欢迎信息
|
|
||||||
whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore,将自动进入安装流程,安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段,代码可能随时更改\n文档未完善,有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险,请自行了解,谨慎使用\n由于持续迭代,可能存在一些已知或未知的bug\n由于开发中,可能消耗较多token\n\n本脚本可能更新不及时,如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60
|
|
||||||
|
|
||||||
# 系统检查
|
|
||||||
check_system() {
|
|
||||||
if [[ "$(id -u)" -ne 0 ]]; then
|
|
||||||
whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ -f /etc/os-release ]]; then
|
|
||||||
source /etc/os-release
|
|
||||||
if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then
|
|
||||||
return
|
|
||||||
elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then
|
|
||||||
return
|
|
||||||
elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then
|
|
||||||
return
|
|
||||||
elif [[ "$ID" == "arch" ]]; then
|
|
||||||
whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60
|
|
||||||
return
|
|
||||||
else
|
|
||||||
whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
else
|
|
||||||
whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
check_system
|
|
||||||
|
|
||||||
# 设置包管理器
|
|
||||||
case "$ID" in
|
|
||||||
debian|ubuntu)
|
|
||||||
PKG_MANAGER="apt"
|
|
||||||
;;
|
|
||||||
centos)
|
|
||||||
PKG_MANAGER="yum"
|
|
||||||
;;
|
|
||||||
arch)
|
|
||||||
# 添加arch包管理器
|
|
||||||
PKG_MANAGER="pacman"
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
|
|
||||||
# 检查NapCat
|
|
||||||
check_napcat() {
|
|
||||||
if command -v napcat &>/dev/null; then
|
|
||||||
NAPCAT_INSTALLED=true
|
|
||||||
else
|
|
||||||
NAPCAT_INSTALLED=false
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
check_napcat
|
|
||||||
|
|
||||||
# 安装必要软件包
|
|
||||||
install_packages() {
|
|
||||||
missing_packages=()
|
|
||||||
# 检查 common 及当前系统专属依赖
|
|
||||||
for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do
|
|
||||||
case "$PKG_MANAGER" in
|
|
||||||
apt)
|
|
||||||
dpkg -s "$package" &>/dev/null || missing_packages+=("$package")
|
|
||||||
;;
|
|
||||||
yum)
|
|
||||||
rpm -q "$package" &>/dev/null || missing_packages+=("$package")
|
|
||||||
;;
|
|
||||||
pacman)
|
|
||||||
pacman -Qi "$package" &>/dev/null || missing_packages+=("$package")
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
done
|
|
||||||
|
|
||||||
if [[ ${#missing_packages[@]} -gt 0 ]]; then
|
|
||||||
whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装?" 10 60
|
|
||||||
if [[ $? -eq 0 ]]; then
|
|
||||||
IS_INSTALL_DEPENDENCIES=true
|
|
||||||
else
|
|
||||||
whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续?" 10 60 || exit 1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
install_packages
|
|
||||||
|
|
||||||
# 安装NapCat
|
|
||||||
install_napcat() {
|
|
||||||
[[ $NAPCAT_INSTALLED == true ]] && return
|
|
||||||
whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat,是否安装?\n如果您想使用远程NapCat,请跳过此步。" 10 60 && {
|
|
||||||
IS_INSTALL_NAPCAT=true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
# 仅在非Arch系统上安装NapCat
|
|
||||||
[[ "$ID" != "arch" ]] && install_napcat
|
|
||||||
|
|
||||||
# Python版本检查
|
|
||||||
check_python() {
|
|
||||||
PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")')
|
|
||||||
if ! python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then
|
|
||||||
whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果没安装python则不检查python版本
|
|
||||||
if command -v python3 &>/dev/null; then
|
|
||||||
check_python
|
|
||||||
fi
|
|
||||||
|
|
||||||
|
|
||||||
# 选择分支
|
|
||||||
choose_branch() {
|
|
||||||
BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \
|
|
||||||
"main" "稳定版本(推荐)" ON \
|
|
||||||
"dev" "开发版(不知道什么意思就别选)" OFF \
|
|
||||||
"classical" "经典版(0.6.0以前的版本)" OFF \
|
|
||||||
"custom" "自定义分支" OFF 3>&1 1>&2 2>&3)
|
|
||||||
RETVAL=$?
|
|
||||||
if [ $RETVAL -ne 0 ]; then
|
|
||||||
whiptail --msgbox "🚫 操作取消!" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ "$BRANCH" == "custom" ]]; then
|
|
||||||
BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3)
|
|
||||||
RETVAL=$?
|
|
||||||
if [ $RETVAL -ne 0 ]; then
|
|
||||||
whiptail --msgbox "🚫 输入取消!" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
if [[ -z "$BRANCH" ]]; then
|
|
||||||
whiptail --msgbox "🚫 分支名称不能为空!" 10 60
|
|
||||||
exit 1
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
}
|
|
||||||
choose_branch
|
|
||||||
|
|
||||||
# 选择安装路径
|
|
||||||
choose_install_dir() {
|
|
||||||
INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录:" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3)
|
|
||||||
[[ -z "$INSTALL_DIR" ]] && {
|
|
||||||
whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 10 60 && exit 1
|
|
||||||
INSTALL_DIR="$DEFAULT_INSTALL_DIR"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
choose_install_dir
|
|
||||||
|
|
||||||
# 确认安装
|
|
||||||
confirm_install() {
|
|
||||||
local confirm_msg="请确认以下更改:\n\n"
|
|
||||||
confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n"
|
|
||||||
confirm_msg+="🔀 分支: $BRANCH\n"
|
|
||||||
[[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n"
|
|
||||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n"
|
|
||||||
|
|
||||||
[[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n"
|
|
||||||
confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。"
|
|
||||||
|
|
||||||
whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1
|
|
||||||
}
|
|
||||||
confirm_install
|
|
||||||
|
|
||||||
# 开始安装
|
|
||||||
echo -e "${GREEN}安装${missing_packages[@]}...${RESET}"
|
|
||||||
|
|
||||||
if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then
|
|
||||||
case "$PKG_MANAGER" in
|
|
||||||
apt)
|
|
||||||
apt update && apt install -y "${missing_packages[@]}"
|
|
||||||
;;
|
|
||||||
yum)
|
|
||||||
yum install -y "${missing_packages[@]}" --nobest
|
|
||||||
;;
|
|
||||||
pacman)
|
|
||||||
pacman -S --noconfirm "${missing_packages[@]}"
|
|
||||||
;;
|
|
||||||
esac
|
|
||||||
fi
|
|
||||||
|
|
||||||
if [[ $IS_INSTALL_NAPCAT == true ]]; then
|
|
||||||
echo -e "${GREEN}安装 NapCat...${RESET}"
|
|
||||||
curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -e "${GREEN}创建安装目录...${RESET}"
|
|
||||||
mkdir -p "$INSTALL_DIR"
|
|
||||||
cd "$INSTALL_DIR" || exit 1
|
|
||||||
|
|
||||||
echo -e "${GREEN}设置Python虚拟环境...${RESET}"
|
|
||||||
python3 -m venv venv
|
|
||||||
source venv/bin/activate
|
|
||||||
|
|
||||||
echo -e "${GREEN}克隆MaiCore仓库...${RESET}"
|
|
||||||
git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || {
|
|
||||||
echo -e "${RED}克隆MaiCore仓库失败!${RESET}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}"
|
|
||||||
git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || {
|
|
||||||
echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}"
|
|
||||||
git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || {
|
|
||||||
echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
echo -e "${GREEN}安装Python依赖...${RESET}"
|
|
||||||
pip install -r MaiBot/requirements.txt
|
|
||||||
cd MaiBot
|
|
||||||
pip install uv
|
|
||||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
echo -e "${GREEN}安装maim_message依赖...${RESET}"
|
|
||||||
cd maim_message
|
|
||||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -e .
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}"
|
|
||||||
cd MaiBot-Napcat-Adapter
|
|
||||||
uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt
|
|
||||||
cd ..
|
|
||||||
|
|
||||||
echo -e "${GREEN}同意协议...${RESET}"
|
|
||||||
|
|
||||||
# 首先计算当前EULA的MD5值
|
|
||||||
current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}')
|
|
||||||
|
|
||||||
# 首先计算当前隐私条款文件的哈希值
|
|
||||||
current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}')
|
|
||||||
|
|
||||||
echo -n $current_md5 > MaiBot/eula.confirmed
|
|
||||||
echo -n $current_md5_privacy > MaiBot/privacy.confirmed
|
|
||||||
|
|
||||||
echo -e "${GREEN}创建系统服务...${RESET}"
|
|
||||||
cat > /etc/systemd/system/${SERVICE_NAME}.service <<EOF
|
|
||||||
[Unit]
|
|
||||||
Description=MaiCore
|
|
||||||
After=network.target ${SERVICE_NAME_NBADAPTER}.service
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
WorkingDirectory=${INSTALL_DIR}/MaiBot
|
|
||||||
ExecStart=$INSTALL_DIR/venv/bin/python3 bot.py
|
|
||||||
Restart=always
|
|
||||||
RestartSec=10s
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
EOF
|
|
||||||
|
|
||||||
# cat > /etc/systemd/system/${SERVICE_NAME_WEB}.service <<EOF
|
|
||||||
# [Unit]
|
|
||||||
# Description=MaiCore WebUI
|
|
||||||
# After=network.target ${SERVICE_NAME}.service
|
|
||||||
|
|
||||||
# [Service]
|
|
||||||
# Type=simple
|
|
||||||
# WorkingDirectory=${INSTALL_DIR}/MaiBot
|
|
||||||
# ExecStart=$INSTALL_DIR/venv/bin/python3 webui.py
|
|
||||||
# Restart=always
|
|
||||||
# RestartSec=10s
|
|
||||||
|
|
||||||
# [Install]
|
|
||||||
# WantedBy=multi-user.target
|
|
||||||
# EOF
|
|
||||||
|
|
||||||
cat > /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service <<EOF
|
|
||||||
[Unit]
|
|
||||||
Description=MaiBot Napcat Adapter
|
|
||||||
After=network.target mongod.service ${SERVICE_NAME}.service
|
|
||||||
|
|
||||||
[Service]
|
|
||||||
Type=simple
|
|
||||||
WorkingDirectory=${INSTALL_DIR}/MaiBot-Napcat-Adapter
|
|
||||||
ExecStart=$INSTALL_DIR/venv/bin/python3 main.py
|
|
||||||
Restart=always
|
|
||||||
RestartSec=10s
|
|
||||||
|
|
||||||
[Install]
|
|
||||||
WantedBy=multi-user.target
|
|
||||||
EOF
|
|
||||||
|
|
||||||
systemctl daemon-reload
|
|
||||||
|
|
||||||
# 保存安装信息
|
|
||||||
echo "INSTALLER_VERSION=${INSTALLER_VERSION}" > /etc/maicore_install.conf
|
|
||||||
echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf
|
|
||||||
echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf
|
|
||||||
|
|
||||||
whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建系统服务:${SERVICE_NAME}、${SERVICE_NAME_WEB}、${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务:\n启动服务:systemctl start ${SERVICE_NAME}\n查看状态:systemctl status ${SERVICE_NAME}" 14 60
|
|
||||||
}
|
|
||||||
|
|
||||||
# ----------- 主执行流程 -----------
|
|
||||||
# 检查root权限
|
|
||||||
[[ $(id -u) -ne 0 ]] && {
|
|
||||||
echo -e "${RED}请使用root用户运行此脚本!${RESET}"
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果已安装显示菜单,并检查协议是否更新
|
|
||||||
if check_installed; then
|
|
||||||
load_install_info
|
|
||||||
check_eula
|
|
||||||
show_menu
|
|
||||||
else
|
|
||||||
run_installation
|
|
||||||
# 安装完成后询问是否启动
|
|
||||||
if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务?" 10 60; then
|
|
||||||
systemctl start ${SERVICE_NAME}
|
|
||||||
whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
@@ -265,7 +265,8 @@ class AntiPromptInjector:
|
|||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 删除对应的消息记录
|
# 删除对应的消息记录
|
||||||
stmt = delete(Messages).where(Messages.message_id == message_id)
|
stmt = delete(Messages).where(Messages.message_id == message_id)
|
||||||
result = session.execute(stmt)
|
# 注意: 异步会话需要 await 执行,否则 result 是 coroutine,无法获取 rowcount
|
||||||
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
if result.rowcount > 0:
|
if result.rowcount > 0:
|
||||||
@@ -296,7 +297,7 @@ class AntiPromptInjector:
|
|||||||
.where(Messages.message_id == message_id)
|
.where(Messages.message_id == message_id)
|
||||||
.values(processed_plain_text=new_content, display_message=new_content)
|
.values(processed_plain_text=new_content, display_message=new_content)
|
||||||
)
|
)
|
||||||
result = session.execute(stmt)
|
result = await session.execute(stmt)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
if result.rowcount > 0:
|
if result.rowcount > 0:
|
||||||
|
|||||||
@@ -5,9 +5,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Any
|
from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select, delete
|
||||||
|
|
||||||
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -16,8 +16,30 @@ from src.config.config import global_config
|
|||||||
logger = get_logger("anti_injector.statistics")
|
logger = get_logger("anti_injector.statistics")
|
||||||
|
|
||||||
|
|
||||||
|
TNum = TypeVar("TNum", int, float)
|
||||||
|
|
||||||
|
|
||||||
|
def _add_optional(a: Optional[TNum], b: TNum) -> TNum:
|
||||||
|
"""安全相加:左值可能为 None。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
a: 可能为 None 的当前值
|
||||||
|
b: 要累加的增量(非 None)
|
||||||
|
Returns:
|
||||||
|
新的累加结果(与 b 同类型)
|
||||||
|
"""
|
||||||
|
if a is None:
|
||||||
|
return b
|
||||||
|
return cast(TNum, a + b) # a 不为 None,此处显式 cast 便于类型检查
|
||||||
|
|
||||||
|
|
||||||
class AntiInjectionStatistics:
|
class AntiInjectionStatistics:
|
||||||
"""反注入系统统计管理类"""
|
"""反注入系统统计管理类
|
||||||
|
|
||||||
|
主要改进:
|
||||||
|
- 对 "可能为 None" 的数值字段做集中安全处理,减少在业务逻辑里反复判空。
|
||||||
|
- 补充类型注解,便于静态检查器(Pylance/Pyright)识别。
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
"""初始化统计管理器"""
|
"""初始化统计管理器"""
|
||||||
@@ -25,8 +47,12 @@ class AntiInjectionStatistics:
|
|||||||
"""当前会话开始时间"""
|
"""当前会话开始时间"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def get_or_create_stats():
|
async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined]
|
||||||
"""获取或创建统计记录"""
|
"""获取或创建统计记录
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
AntiInjectionStats | None: 成功返回模型实例,否则 None
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 获取最新的统计记录,如果没有则创建
|
# 获取最新的统计记录,如果没有则创建
|
||||||
@@ -46,8 +72,15 @@ class AntiInjectionStatistics:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def update_stats(**kwargs):
|
async def update_stats(**kwargs: Any) -> None:
|
||||||
"""更新统计数据"""
|
"""更新统计数据(批量可选字段)
|
||||||
|
|
||||||
|
支持字段:
|
||||||
|
- processing_time_delta: float 累加到 processing_time_total
|
||||||
|
- last_processing_time: float 设置 last_process_time
|
||||||
|
- total_messages / detected_injections / blocked_messages / shielded_messages / error_count: 累加
|
||||||
|
- 其他任意字段:直接赋值(若模型存在该属性)
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
stats = (
|
stats = (
|
||||||
@@ -62,14 +95,13 @@ class AntiInjectionStatistics:
|
|||||||
# 更新统计字段
|
# 更新统计字段
|
||||||
for key, value in kwargs.items():
|
for key, value in kwargs.items():
|
||||||
if key == "processing_time_delta":
|
if key == "processing_time_delta":
|
||||||
# 处理 时间累加 - 确保不为None
|
# 处理时间累加 - 确保不为 None
|
||||||
if stats.processing_time_total is None:
|
delta = float(value)
|
||||||
stats.processing_time_total = 0.0
|
stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined]
|
||||||
stats.processing_time_total += value
|
|
||||||
continue
|
continue
|
||||||
elif key == "last_processing_time":
|
elif key == "last_processing_time":
|
||||||
# 直接设置最后处理时间
|
# 直接设置最后处理时间
|
||||||
stats.last_process_time = value
|
stats.last_process_time = float(value)
|
||||||
continue
|
continue
|
||||||
elif hasattr(stats, key):
|
elif hasattr(stats, key):
|
||||||
if key in [
|
if key in [
|
||||||
@@ -79,12 +111,10 @@ class AntiInjectionStatistics:
|
|||||||
"shielded_messages",
|
"shielded_messages",
|
||||||
"error_count",
|
"error_count",
|
||||||
]:
|
]:
|
||||||
# 累加类型的字段 - 确保不为None
|
# 累加类型的字段 - 统一用辅助函数
|
||||||
current_value = getattr(stats, key)
|
current_value = cast(Optional[int], getattr(stats, key))
|
||||||
if current_value is None:
|
increment = int(value)
|
||||||
setattr(stats, key, value)
|
setattr(stats, key, _add_optional(current_value, increment))
|
||||||
else:
|
|
||||||
setattr(stats, key, current_value + value)
|
|
||||||
else:
|
else:
|
||||||
# 直接设置的字段
|
# 直接设置的字段
|
||||||
setattr(stats, key, value)
|
setattr(stats, key, value)
|
||||||
@@ -114,10 +144,11 @@ class AntiInjectionStatistics:
|
|||||||
|
|
||||||
stats = await self.get_or_create_stats()
|
stats = await self.get_or_create_stats()
|
||||||
|
|
||||||
# 计算派生统计信息 - 处理None值
|
|
||||||
total_messages = stats.total_messages or 0
|
# 计算派生统计信息 - 处理 None 值
|
||||||
detected_injections = stats.detected_injections or 0
|
total_messages = stats.total_messages or 0 # type: ignore[attr-defined]
|
||||||
processing_time_total = stats.processing_time_total or 0.0
|
detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined]
|
||||||
|
processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined]
|
||||||
|
|
||||||
detection_rate = (detected_injections / total_messages * 100) if total_messages > 0 else 0
|
detection_rate = (detected_injections / total_messages * 100) if total_messages > 0 else 0
|
||||||
avg_processing_time = (processing_time_total / total_messages) if total_messages > 0 else 0
|
avg_processing_time = (processing_time_total / total_messages) if total_messages > 0 else 0
|
||||||
@@ -127,17 +158,22 @@ class AntiInjectionStatistics:
|
|||||||
current_time = datetime.datetime.now()
|
current_time = datetime.datetime.now()
|
||||||
uptime = current_time - self.session_start_time
|
uptime = current_time - self.session_start_time
|
||||||
|
|
||||||
|
last_proc = stats.last_process_time # type: ignore[attr-defined]
|
||||||
|
blocked_messages = stats.blocked_messages or 0 # type: ignore[attr-defined]
|
||||||
|
shielded_messages = stats.shielded_messages or 0 # type: ignore[attr-defined]
|
||||||
|
error_count = stats.error_count or 0 # type: ignore[attr-defined]
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"status": "enabled",
|
"status": "enabled",
|
||||||
"uptime": str(uptime),
|
"uptime": str(uptime),
|
||||||
"total_messages": total_messages,
|
"total_messages": total_messages,
|
||||||
"detected_injections": detected_injections,
|
"detected_injections": detected_injections,
|
||||||
"blocked_messages": stats.blocked_messages or 0,
|
"blocked_messages": blocked_messages,
|
||||||
"shielded_messages": stats.shielded_messages or 0,
|
"shielded_messages": shielded_messages,
|
||||||
"detection_rate": f"{detection_rate:.2f}%",
|
"detection_rate": f"{detection_rate:.2f}%",
|
||||||
"average_processing_time": f"{avg_processing_time:.3f}s",
|
"average_processing_time": f"{avg_processing_time:.3f}s",
|
||||||
"last_processing_time": f"{stats.last_process_time:.3f}s" if stats.last_process_time else "0.000s",
|
"last_processing_time": f"{last_proc:.3f}s" if last_proc else "0.000s",
|
||||||
"error_count": stats.error_count or 0,
|
"error_count": error_count,
|
||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"获取统计信息失败: {e}")
|
logger.error(f"获取统计信息失败: {e}")
|
||||||
@@ -149,7 +185,7 @@ class AntiInjectionStatistics:
|
|||||||
try:
|
try:
|
||||||
async with get_db_session() as session:
|
async with get_db_session() as session:
|
||||||
# 删除现有统计记录
|
# 删除现有统计记录
|
||||||
await session.execute(select(AntiInjectionStats).delete())
|
await session.execute(delete(AntiInjectionStats))
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.info("统计信息已重置")
|
logger.info("统计信息已重置")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class UserBanManager:
|
|||||||
remaining_time = ban_duration - (datetime.datetime.now() - ban_record.created_at)
|
remaining_time = ban_duration - (datetime.datetime.now() - ban_record.created_at)
|
||||||
return False, None, f"用户被封禁中,剩余时间: {remaining_time}"
|
return False, None, f"用户被封禁中,剩余时间: {remaining_time}"
|
||||||
else:
|
else:
|
||||||
# 封禁已过期,重置违规次数
|
# 封禁已过期,重置违规次数与时间(模型已使用 Mapped 类型,可直接赋值)
|
||||||
ban_record.violation_num = 0
|
ban_record.violation_num = 0
|
||||||
ban_record.created_at = datetime.datetime.now()
|
ban_record.created_at = datetime.datetime.now()
|
||||||
await session.commit()
|
await session.commit()
|
||||||
@@ -92,7 +92,6 @@ class UserBanManager:
|
|||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
# 检查是否需要自动封禁
|
|
||||||
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
|
if ban_record.violation_num >= self.config.auto_ban_violation_threshold:
|
||||||
logger.warning(f"用户 {platform}:{user_id} 违规次数达到 {ban_record.violation_num},触发自动封禁")
|
logger.warning(f"用户 {platform}:{user_id} 违规次数达到 {ban_record.violation_num},触发自动封禁")
|
||||||
# 只有在首次达到阈值时才更新封禁开始时间
|
# 只有在首次达到阈值时才更新封禁开始时间
|
||||||
|
|||||||
@@ -377,11 +377,12 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], r
|
|||||||
|
|
||||||
class EmojiManager:
|
class EmojiManager:
|
||||||
_instance = None
|
_instance = None
|
||||||
|
_initialized: bool = False # 显式声明,避免属性未定义错误
|
||||||
|
|
||||||
def __new__(cls) -> "EmojiManager":
|
def __new__(cls) -> "EmojiManager":
|
||||||
if cls._instance is None:
|
if cls._instance is None:
|
||||||
cls._instance = super().__new__(cls)
|
cls._instance = super().__new__(cls)
|
||||||
cls._instance._initialized = False
|
# 类属性已声明,无需再次赋值
|
||||||
return cls._instance
|
return cls._instance
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -399,7 +400,8 @@ class EmojiManager:
|
|||||||
self.emoji_num_max = global_config.emoji.max_reg_num
|
self.emoji_num_max = global_config.emoji.max_reg_num
|
||||||
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
|
self.emoji_num_max_reach_deletion = global_config.emoji.do_replace
|
||||||
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型
|
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型
|
||||||
|
logger.info("启动表情包管理器")
|
||||||
|
self._initialized = True
|
||||||
logger.info("启动表情包管理器")
|
logger.info("启动表情包管理器")
|
||||||
|
|
||||||
def shutdown(self) -> None:
|
def shutdown(self) -> None:
|
||||||
@@ -752,8 +754,8 @@ class EmojiManager:
|
|||||||
try:
|
try:
|
||||||
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
emoji_record = await self.get_emoji_from_db(emoji_hash)
|
||||||
if emoji_record and emoji_record[0].emotion:
|
if emoji_record and emoji_record[0].emotion:
|
||||||
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...")
|
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") # type: ignore # type: ignore
|
||||||
return emoji_record.emotion
|
return emoji_record.emotion # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
logger.error(f"从数据库查询表情包描述时出错: {e}")
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from chat.message_manager.adaptive_stream_manager import StreamPriority
|
||||||
from src.chat.chatter_manager import ChatterManager
|
from src.chat.chatter_manager import ChatterManager
|
||||||
from src.chat.energy_system import energy_manager
|
from src.chat.energy_system import energy_manager
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
|||||||
@@ -1,6 +1,14 @@
|
|||||||
"""SQLAlchemy数据库模型定义
|
"""SQLAlchemy数据库模型定义
|
||||||
|
|
||||||
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||||
|
|
||||||
|
说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到
|
||||||
|
SQLAlchemy 2.0 推荐的带类型注解的声明式风格:
|
||||||
|
|
||||||
|
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
||||||
|
|
||||||
|
这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。
|
||||||
|
当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
@@ -103,31 +111,31 @@ class ChatStreams(Base):
|
|||||||
|
|
||||||
__tablename__ = "chat_streams"
|
__tablename__ = "chat_streams"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
stream_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||||
create_time = Column(Float, nullable=False)
|
create_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
group_platform = Column(Text, nullable=True)
|
group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
group_id = Column(get_string_field(100), nullable=True, index=True)
|
group_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
|
||||||
group_name = Column(Text, nullable=True)
|
group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
last_active_time = Column(Float, nullable=False)
|
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
platform = Column(Text, nullable=False)
|
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
user_platform = Column(Text, nullable=False)
|
user_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
user_id = Column(get_string_field(100), nullable=False, index=True)
|
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
user_nickname = Column(Text, nullable=False)
|
user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
user_cardname = Column(Text, nullable=True)
|
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
energy_value = Column(Float, nullable=True, default=5.0)
|
energy_value: Mapped[float | None] = mapped_column(Float, nullable=True, default=5.0)
|
||||||
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
sleep_pressure: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
|
||||||
focus_energy = Column(Float, nullable=True, default=0.5)
|
focus_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
|
||||||
# 动态兴趣度系统字段
|
# 动态兴趣度系统字段
|
||||||
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
base_interest_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5)
|
||||||
message_interest_total = Column(Float, nullable=True, default=0.0)
|
message_interest_total: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0)
|
||||||
message_count = Column(Integer, nullable=True, default=0)
|
message_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||||
action_count = Column(Integer, nullable=True, default=0)
|
action_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||||
reply_count = Column(Integer, nullable=True, default=0)
|
reply_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||||
last_interaction_time = Column(Float, nullable=True, default=None)
|
last_interaction_time: Mapped[float | None] = mapped_column(Float, nullable=True, default=None)
|
||||||
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||||
# 消息打断系统字段
|
# 消息打断系统字段
|
||||||
interruption_count = Column(Integer, nullable=True, default=0)
|
interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_chatstreams_stream_id", "stream_id"),
|
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||||
@@ -141,20 +149,20 @@ class LLMUsage(Base):
|
|||||||
|
|
||||||
__tablename__ = "llm_usage"
|
__tablename__ = "llm_usage"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
model_name = Column(get_string_field(100), nullable=False, index=True)
|
model_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
model_assign_name = Column(get_string_field(100), index=True) # 添加索引
|
model_assign_name: Mapped[str] = mapped_column(get_string_field(100), index=True)
|
||||||
model_api_provider = Column(get_string_field(100), index=True) # 添加索引
|
model_api_provider: Mapped[str] = mapped_column(get_string_field(100), index=True)
|
||||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
request_type = Column(get_string_field(50), nullable=False, index=True)
|
request_type: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
endpoint = Column(Text, nullable=False)
|
endpoint: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
prompt_tokens = Column(Integer, nullable=False)
|
prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
completion_tokens = Column(Integer, nullable=False)
|
completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
time_cost = Column(Float, nullable=True)
|
time_cost: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
total_tokens = Column(Integer, nullable=False)
|
total_tokens: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
cost = Column(Float, nullable=False)
|
cost: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
status = Column(Text, nullable=False)
|
status: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_llmusage_model_name", "model_name"),
|
Index("idx_llmusage_model_name", "model_name"),
|
||||||
@@ -172,19 +180,19 @@ class Emoji(Base):
|
|||||||
|
|
||||||
__tablename__ = "emoji"
|
__tablename__ = "emoji"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
full_path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||||
format = Column(Text, nullable=False)
|
format: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
description = Column(Text, nullable=False)
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
query_count = Column(Integer, nullable=False, default=0)
|
query_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
is_registered = Column(Boolean, nullable=False, default=False)
|
is_registered: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
is_banned = Column(Boolean, nullable=False, default=False)
|
is_banned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
emotion = Column(Text, nullable=True)
|
emotion: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
record_time = Column(Float, nullable=False)
|
record_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
register_time = Column(Float, nullable=True)
|
register_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
usage_count = Column(Integer, nullable=False, default=0)
|
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
last_used_time = Column(Float, nullable=True)
|
last_used_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_emoji_full_path", "full_path"),
|
Index("idx_emoji_full_path", "full_path"),
|
||||||
@@ -197,50 +205,50 @@ class Messages(Base):
|
|||||||
|
|
||||||
__tablename__ = "messages"
|
__tablename__ = "messages"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
message_id = Column(get_string_field(100), nullable=False, index=True)
|
message_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
time = Column(Float, nullable=False)
|
time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
reply_to = Column(Text, nullable=True)
|
reply_to: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
interest_value = Column(Float, nullable=True)
|
interest_value: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
key_words = Column(Text, nullable=True)
|
key_words: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
key_words_lite = Column(Text, nullable=True)
|
key_words_lite: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
is_mentioned = Column(Boolean, nullable=True)
|
is_mentioned: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||||
|
|
||||||
# 从 chat_info 扁平化而来的字段
|
# 从 chat_info 扁平化而来的字段
|
||||||
chat_info_stream_id = Column(Text, nullable=False)
|
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_platform = Column(Text, nullable=False)
|
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_user_platform = Column(Text, nullable=False)
|
chat_info_user_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_user_id = Column(Text, nullable=False)
|
chat_info_user_id: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_user_nickname = Column(Text, nullable=False)
|
chat_info_user_nickname: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_user_cardname = Column(Text, nullable=True)
|
chat_info_user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_info_group_platform = Column(Text, nullable=True)
|
chat_info_group_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_info_group_id = Column(Text, nullable=True)
|
chat_info_group_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_info_group_name = Column(Text, nullable=True)
|
chat_info_group_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_info_create_time = Column(Float, nullable=False)
|
chat_info_create_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
chat_info_last_active_time = Column(Float, nullable=False)
|
chat_info_last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
|
||||||
# 从顶层 user_info 扁平化而来的字段
|
# 从顶层 user_info 扁平化而来的字段
|
||||||
user_platform = Column(Text, nullable=True)
|
user_platform: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
user_id = Column(get_string_field(100), nullable=True, index=True)
|
user_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True)
|
||||||
user_nickname = Column(Text, nullable=True)
|
user_nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
user_cardname = Column(Text, nullable=True)
|
user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
|
||||||
processed_plain_text = Column(Text, nullable=True)
|
processed_plain_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
display_message = Column(Text, nullable=True)
|
display_message: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
memorized_times = Column(Integer, nullable=False, default=0)
|
memorized_times: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
priority_mode = Column(Text, nullable=True)
|
priority_mode: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
priority_info = Column(Text, nullable=True)
|
priority_info: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
additional_config = Column(Text, nullable=True)
|
additional_config: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
is_emoji = Column(Boolean, nullable=False, default=False)
|
is_emoji: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
is_picid = Column(Boolean, nullable=False, default=False)
|
is_picid: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
is_command = Column(Boolean, nullable=False, default=False)
|
is_command: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
is_notify = Column(Boolean, nullable=False, default=False)
|
is_notify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
# 兴趣度系统字段
|
# 兴趣度系统字段
|
||||||
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
actions: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
should_reply = Column(Boolean, nullable=True, default=False)
|
should_reply: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
|
||||||
should_act = Column(Boolean, nullable=True, default=False)
|
should_act: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_messages_message_id", "message_id"),
|
Index("idx_messages_message_id", "message_id"),
|
||||||
@@ -257,17 +265,17 @@ class ActionRecords(Base):
|
|||||||
|
|
||||||
__tablename__ = "action_records"
|
__tablename__ = "action_records"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
action_id = Column(get_string_field(100), nullable=False, index=True)
|
action_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
time = Column(Float, nullable=False)
|
time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
action_name = Column(Text, nullable=False)
|
action_name: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
action_data = Column(Text, nullable=False)
|
action_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
action_done = Column(Boolean, nullable=False, default=False)
|
action_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
action_build_into_prompt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
action_prompt_display = Column(Text, nullable=False)
|
action_prompt_display: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
chat_info_stream_id = Column(Text, nullable=False)
|
chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
chat_info_platform = Column(Text, nullable=False)
|
chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_actionrecords_action_id", "action_id"),
|
Index("idx_actionrecords_action_id", "action_id"),
|
||||||
@@ -281,15 +289,15 @@ class Images(Base):
|
|||||||
|
|
||||||
__tablename__ = "images"
|
__tablename__ = "images"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
image_id = Column(Text, nullable=False, default="")
|
image_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
description = Column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
path = Column(get_string_field(500), nullable=False, unique=True)
|
path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True)
|
||||||
count = Column(Integer, nullable=False, default=1)
|
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||||
timestamp = Column(Float, nullable=False)
|
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
type = Column(Text, nullable=False)
|
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_images_emoji_hash", "emoji_hash"),
|
Index("idx_images_emoji_hash", "emoji_hash"),
|
||||||
@@ -302,11 +310,11 @@ class ImageDescriptions(Base):
|
|||||||
|
|
||||||
__tablename__ = "image_descriptions"
|
__tablename__ = "image_descriptions"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
type = Column(Text, nullable=False)
|
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
image_description_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
description = Column(Text, nullable=False)
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
timestamp = Column(Float, nullable=False)
|
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
||||||
|
|
||||||
@@ -316,20 +324,20 @@ class Videos(Base):
|
|||||||
|
|
||||||
__tablename__ = "videos"
|
__tablename__ = "videos"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
video_id = Column(Text, nullable=False, default="")
|
video_id: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
|
video_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True, unique=True)
|
||||||
description = Column(Text, nullable=True)
|
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
count = Column(Integer, nullable=False, default=1)
|
count: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||||
timestamp = Column(Float, nullable=False)
|
timestamp: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
vlm_processed = Column(Boolean, nullable=False, default=False)
|
vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
# 视频特有属性
|
# 视频特有属性
|
||||||
duration = Column(Float, nullable=True) # 视频时长(秒)
|
duration: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
frame_count = Column(Integer, nullable=True) # 总帧数
|
frame_count: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
fps = Column(Float, nullable=True) # 帧率
|
fps: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
resolution = Column(Text, nullable=True) # 分辨率
|
resolution: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
file_size = Column(Integer, nullable=True) # 文件大小(字节)
|
file_size: Mapped[int | None] = mapped_column(Integer, nullable=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_videos_video_hash", "video_hash"),
|
Index("idx_videos_video_hash", "video_hash"),
|
||||||
@@ -342,11 +350,11 @@ class OnlineTime(Base):
|
|||||||
|
|
||||||
__tablename__ = "online_time"
|
__tablename__ = "online_time"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
timestamp: Mapped[str] = mapped_column(Text, nullable=False, default=str(datetime.datetime.now))
|
||||||
duration = Column(Integer, nullable=False)
|
duration: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
start_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
end_timestamp = Column(DateTime, nullable=False, index=True)
|
end_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
||||||
|
|
||||||
@@ -356,22 +364,22 @@ class PersonInfo(Base):
|
|||||||
|
|
||||||
__tablename__ = "person_info"
|
__tablename__ = "person_info"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
person_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||||
person_name = Column(Text, nullable=True)
|
person_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
name_reason = Column(Text, nullable=True)
|
name_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
platform = Column(Text, nullable=False)
|
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
nickname = Column(Text, nullable=True)
|
nickname: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
impression = Column(Text, nullable=True)
|
impression: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
short_impression = Column(Text, nullable=True)
|
short_impression: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
points = Column(Text, nullable=True)
|
points: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
forgotten_points = Column(Text, nullable=True)
|
forgotten_points: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
info_list = Column(Text, nullable=True)
|
info_list: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
know_times = Column(Float, nullable=True)
|
know_times: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
know_since = Column(Float, nullable=True)
|
know_since: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
last_know = Column(Float, nullable=True)
|
last_know: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
attitude = Column(Integer, nullable=True, default=50)
|
attitude: Mapped[int | None] = mapped_column(Integer, nullable=True, default=50)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_personinfo_person_id", "person_id"),
|
Index("idx_personinfo_person_id", "person_id"),
|
||||||
@@ -384,13 +392,13 @@ class BotPersonalityInterests(Base):
|
|||||||
|
|
||||||
__tablename__ = "bot_personality_interests"
|
__tablename__ = "bot_personality_interests"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
personality_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
personality_description = Column(Text, nullable=False)
|
personality_description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
interest_tags: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
embedding_model: Mapped[str] = mapped_column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
||||||
version = Column(Integer, nullable=False, default=1)
|
version: Mapped[int] = mapped_column(Integer, nullable=False, default=1)
|
||||||
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
last_updated: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_botpersonality_personality_id", "personality_id"),
|
Index("idx_botpersonality_personality_id", "personality_id"),
|
||||||
@@ -404,13 +412,13 @@ class Memory(Base):
|
|||||||
|
|
||||||
__tablename__ = "memory"
|
__tablename__ = "memory"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
memory_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
chat_id = Column(Text, nullable=True)
|
chat_id: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
memory_text = Column(Text, nullable=True)
|
memory_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
keywords = Column(Text, nullable=True)
|
keywords: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
create_time = Column(Float, nullable=True)
|
create_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
last_view_time = Column(Float, nullable=True)
|
last_view_time: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
||||||
|
|
||||||
@@ -437,19 +445,19 @@ class ThinkingLog(Base):
|
|||||||
|
|
||||||
__tablename__ = "thinking_logs"
|
__tablename__ = "thinking_logs"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
trigger_text = Column(Text, nullable=True)
|
trigger_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
response_text = Column(Text, nullable=True)
|
response_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
trigger_info_json = Column(Text, nullable=True)
|
trigger_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
response_info_json = Column(Text, nullable=True)
|
response_info_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
timing_results_json = Column(Text, nullable=True)
|
timing_results_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_history_json = Column(Text, nullable=True)
|
chat_history_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_history_in_thinking_json = Column(Text, nullable=True)
|
chat_history_in_thinking_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
chat_history_after_response_json = Column(Text, nullable=True)
|
chat_history_after_response_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
heartflow_data_json = Column(Text, nullable=True)
|
heartflow_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
reasoning_data_json = Column(Text, nullable=True)
|
reasoning_data_json: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
||||||
|
|
||||||
@@ -459,13 +467,13 @@ class GraphNodes(Base):
|
|||||||
|
|
||||||
__tablename__ = "graph_nodes"
|
__tablename__ = "graph_nodes"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
concept: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||||
memory_items = Column(Text, nullable=False)
|
memory_items: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
hash = Column(Text, nullable=False)
|
hash: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
weight = Column(Float, nullable=False, default=1.0)
|
weight: Mapped[float] = mapped_column(Float, nullable=False, default=1.0)
|
||||||
created_time = Column(Float, nullable=False)
|
created_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
last_modified = Column(Float, nullable=False)
|
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
||||||
|
|
||||||
@@ -475,13 +483,13 @@ class GraphEdges(Base):
|
|||||||
|
|
||||||
__tablename__ = "graph_edges"
|
__tablename__ = "graph_edges"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
source = Column(get_string_field(255), nullable=False, index=True)
|
source: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
||||||
target = Column(get_string_field(255), nullable=False, index=True)
|
target: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
||||||
strength = Column(Integer, nullable=False)
|
strength: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
hash = Column(Text, nullable=False)
|
hash: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
created_time = Column(Float, nullable=False)
|
created_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
last_modified = Column(Float, nullable=False)
|
last_modified: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_graphedges_source", "source"),
|
Index("idx_graphedges_source", "source"),
|
||||||
@@ -494,11 +502,11 @@ class Schedule(Base):
|
|||||||
|
|
||||||
__tablename__ = "schedule"
|
__tablename__ = "schedule"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式
|
date: Mapped[str] = mapped_column(get_string_field(10), nullable=False, unique=True, index=True)
|
||||||
schedule_data = Column(Text, nullable=False) # JSON格式的日程数据
|
schedule_data: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
|
|
||||||
__table_args__ = (Index("idx_schedule_date", "date"),)
|
__table_args__ = (Index("idx_schedule_date", "date"),)
|
||||||
|
|
||||||
@@ -508,17 +516,15 @@ class MaiZoneScheduleStatus(Base):
|
|||||||
|
|
||||||
__tablename__ = "maizone_schedule_status"
|
__tablename__ = "maizone_schedule_status"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
datetime_hour = Column(
|
datetime_hour: Mapped[str] = mapped_column(get_string_field(13), nullable=False, unique=True, index=True)
|
||||||
get_string_field(13), nullable=False, unique=True, index=True
|
activity: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
) # YYYY-MM-DD HH格式,精确到小时
|
is_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
activity = Column(Text, nullable=False) # 该小时的活动内容
|
processed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True)
|
||||||
is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理
|
story_content: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
processed_at = Column(DateTime, nullable=True) # 处理时间
|
send_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
story_content = Column(Text, nullable=True) # 生成的说说内容
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功
|
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
||||||
@@ -527,16 +533,20 @@ class MaiZoneScheduleStatus(Base):
|
|||||||
|
|
||||||
|
|
||||||
class BanUser(Base):
|
class BanUser(Base):
|
||||||
"""被禁用用户模型"""
|
"""被禁用用户模型
|
||||||
|
|
||||||
|
使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型,
|
||||||
|
避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "ban_users"
|
__tablename__ = "ban_users"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
platform = Column(Text, nullable=False)
|
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
user_id = Column(get_string_field(50), nullable=False, index=True)
|
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
violation_num = Column(Integer, nullable=False, default=0)
|
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||||
reason = Column(Text, nullable=False)
|
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_violation_num", "violation_num"),
|
Index("idx_violation_num", "violation_num"),
|
||||||
@@ -551,38 +561,38 @@ class AntiInjectionStats(Base):
|
|||||||
|
|
||||||
__tablename__ = "anti_injection_stats"
|
__tablename__ = "anti_injection_stats"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
total_messages = Column(Integer, nullable=False, default=0)
|
total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""总处理消息数"""
|
"""总处理消息数"""
|
||||||
|
|
||||||
detected_injections = Column(Integer, nullable=False, default=0)
|
detected_injections: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""检测到的注入攻击数"""
|
"""检测到的注入攻击数"""
|
||||||
|
|
||||||
blocked_messages = Column(Integer, nullable=False, default=0)
|
blocked_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""被阻止的消息数"""
|
"""被阻止的消息数"""
|
||||||
|
|
||||||
shielded_messages = Column(Integer, nullable=False, default=0)
|
shielded_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""被加盾的消息数"""
|
"""被加盾的消息数"""
|
||||||
|
|
||||||
processing_time_total = Column(Float, nullable=False, default=0.0)
|
processing_time_total: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
"""总处理时间"""
|
"""总处理时间"""
|
||||||
|
|
||||||
total_process_time = Column(Float, nullable=False, default=0.0)
|
total_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
"""累计总处理时间"""
|
"""累计总处理时间"""
|
||||||
|
|
||||||
last_process_time = Column(Float, nullable=False, default=0.0)
|
last_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
"""最近一次处理时间"""
|
"""最近一次处理时间"""
|
||||||
|
|
||||||
error_count = Column(Integer, nullable=False, default=0)
|
error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""错误计数"""
|
"""错误计数"""
|
||||||
|
|
||||||
start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
start_time: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
"""统计开始时间"""
|
"""统计开始时间"""
|
||||||
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
"""记录创建时间"""
|
"""记录创建时间"""
|
||||||
|
|
||||||
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
"""记录更新时间"""
|
"""记录更新时间"""
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
@@ -596,26 +606,26 @@ class CacheEntries(Base):
|
|||||||
|
|
||||||
__tablename__ = "cache_entries"
|
__tablename__ = "cache_entries"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
cache_key: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||||
"""缓存键,包含工具名、参数和代码哈希"""
|
"""缓存键,包含工具名、参数和代码哈希"""
|
||||||
|
|
||||||
cache_value = Column(Text, nullable=False)
|
cache_value: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
"""缓存的数据,JSON格式"""
|
"""缓存的数据,JSON格式"""
|
||||||
|
|
||||||
expires_at = Column(Float, nullable=False, index=True)
|
expires_at: Mapped[float] = mapped_column(Float, nullable=False, index=True)
|
||||||
"""过期时间戳"""
|
"""过期时间戳"""
|
||||||
|
|
||||||
tool_name = Column(get_string_field(100), nullable=False, index=True)
|
tool_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
"""工具名称"""
|
"""工具名称"""
|
||||||
|
|
||||||
created_at = Column(Float, nullable=False, default=lambda: time.time())
|
created_at: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
|
||||||
"""创建时间戳"""
|
"""创建时间戳"""
|
||||||
|
|
||||||
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
|
last_accessed: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time())
|
||||||
"""最后访问时间戳"""
|
"""最后访问时间戳"""
|
||||||
|
|
||||||
access_count = Column(Integer, nullable=False, default=0)
|
access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
"""访问次数"""
|
"""访问次数"""
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
@@ -631,18 +641,16 @@ class MonthlyPlan(Base):
|
|||||||
|
|
||||||
__tablename__ = "monthly_plans"
|
__tablename__ = "monthly_plans"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
plan_text = Column(Text, nullable=False)
|
plan_text: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM"
|
target_month: Mapped[str] = mapped_column(String(7), nullable=False, index=True)
|
||||||
status = Column(
|
status: Mapped[str] = mapped_column(get_string_field(20), nullable=False, default="active", index=True)
|
||||||
get_string_field(20), nullable=False, default="active", index=True
|
usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
) # 'active', 'completed', 'archived'
|
last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True)
|
||||||
usage_count = Column(Integer, nullable=False, default=0)
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format
|
|
||||||
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
|
||||||
|
|
||||||
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
||||||
is_deleted = Column(Boolean, nullable=False, default=False)
|
is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
||||||
@@ -807,12 +815,12 @@ class PermissionNodes(Base):
|
|||||||
|
|
||||||
__tablename__ = "permission_nodes"
|
__tablename__ = "permission_nodes"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
|
node_name: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||||
description = Column(Text, nullable=False) # 权限描述
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
|
plugin_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
|
default_granted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False)
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_permission_plugin", "plugin_name"),
|
Index("idx_permission_plugin", "plugin_name"),
|
||||||
@@ -825,13 +833,13 @@ class UserPermissions(Base):
|
|||||||
|
|
||||||
__tablename__ = "user_permissions"
|
__tablename__ = "user_permissions"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
|
platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID
|
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True)
|
||||||
permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称
|
permission_node: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True)
|
||||||
granted = Column(Boolean, default=True, nullable=False) # 是否授权
|
granted: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False)
|
||||||
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
|
granted_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
|
granted_by: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_user_platform_id", "platform", "user_id"),
|
Index("idx_user_platform_id", "platform", "user_id"),
|
||||||
@@ -845,13 +853,13 @@ class UserRelationships(Base):
|
|||||||
|
|
||||||
__tablename__ = "user_relationships"
|
__tablename__ = "user_relationships"
|
||||||
|
|
||||||
id = Column(Integer, primary_key=True, autoincrement=True)
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||||
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True)
|
||||||
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||||
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time)
|
||||||
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False)
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("idx_user_relationship_id", "user_id"),
|
Index("idx_user_relationship_id", "user_id"),
|
||||||
|
|||||||
872
src/common/database/sqlalchemy_models.py.bak
Normal file
872
src/common/database/sqlalchemy_models.py.bak
Normal file
@@ -0,0 +1,872 @@
|
|||||||
|
"""SQLAlchemy数据库模型定义
|
||||||
|
|
||||||
|
替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力
|
||||||
|
|
||||||
|
说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到
|
||||||
|
SQLAlchemy 2.0 推荐的带类型注解的声明式风格:
|
||||||
|
|
||||||
|
field_name: Mapped[PyType] = mapped_column(Type, ...)
|
||||||
|
|
||||||
|
这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。
|
||||||
|
当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import datetime
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column
|
||||||
|
|
||||||
|
from src.common.database.connection_pool_manager import get_connection_pool_manager
|
||||||
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("sqlalchemy_models")
|
||||||
|
|
||||||
|
# 创建基类
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
async def enable_sqlite_wal_mode(engine):
|
||||||
|
"""为 SQLite 启用 WAL 模式以提高并发性能"""
|
||||||
|
try:
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# 启用 WAL 模式
|
||||||
|
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
||||||
|
# 设置适中的同步级别,平衡性能和安全性
|
||||||
|
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
||||||
|
# 启用外键约束
|
||||||
|
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
|
# 设置 busy_timeout,避免锁定错误
|
||||||
|
await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒
|
||||||
|
|
||||||
|
logger.info("[SQLite] WAL 模式已启用,并发性能已优化")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置")
|
||||||
|
|
||||||
|
|
||||||
|
async def maintain_sqlite_database():
|
||||||
|
"""定期维护 SQLite 数据库性能"""
|
||||||
|
try:
|
||||||
|
engine, SessionLocal = await initialize_database()
|
||||||
|
if not engine:
|
||||||
|
return
|
||||||
|
|
||||||
|
async with engine.begin() as conn:
|
||||||
|
# 检查并确保 WAL 模式仍然启用
|
||||||
|
result = await conn.execute(text("PRAGMA journal_mode"))
|
||||||
|
journal_mode = result.scalar()
|
||||||
|
|
||||||
|
if journal_mode != "wal":
|
||||||
|
await conn.execute(text("PRAGMA journal_mode = WAL"))
|
||||||
|
logger.info("[SQLite] WAL 模式已重新启用")
|
||||||
|
|
||||||
|
# 优化数据库性能
|
||||||
|
await conn.execute(text("PRAGMA synchronous = NORMAL"))
|
||||||
|
await conn.execute(text("PRAGMA busy_timeout = 60000"))
|
||||||
|
await conn.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
|
|
||||||
|
# 定期清理(可选,根据需要启用)
|
||||||
|
# await conn.execute(text("PRAGMA optimize"))
|
||||||
|
|
||||||
|
logger.info("[SQLite] 数据库维护完成")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[SQLite] 数据库维护失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_sqlite_performance_config():
|
||||||
|
"""获取 SQLite 性能优化配置"""
|
||||||
|
return {
|
||||||
|
"journal_mode": "WAL", # 提高并发性能
|
||||||
|
"synchronous": "NORMAL", # 平衡性能和安全性
|
||||||
|
"busy_timeout": 60000, # 60秒超时
|
||||||
|
"foreign_keys": "ON", # 启用外键约束
|
||||||
|
"cache_size": -10000, # 10MB 缓存
|
||||||
|
"temp_store": "MEMORY", # 临时存储使用内存
|
||||||
|
"mmap_size": 268435456, # 256MB 内存映射
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# MySQL兼容的字段类型辅助函数
|
||||||
|
def get_string_field(max_length=255, **kwargs):
|
||||||
|
"""
|
||||||
|
根据数据库类型返回合适的字符串字段
|
||||||
|
MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
|
||||||
|
"""
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
if global_config.database.database_type == "mysql":
|
||||||
|
return String(max_length, **kwargs)
|
||||||
|
else:
|
||||||
|
return Text(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class ChatStreams(Base):
|
||||||
|
"""聊天流模型"""
|
||||||
|
|
||||||
|
__tablename__ = "chat_streams"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True)
|
||||||
|
create_time = Column(Float, nullable=False)
|
||||||
|
group_platform = Column(Text, nullable=True)
|
||||||
|
group_id = Column(get_string_field(100), nullable=True, index=True)
|
||||||
|
group_name = Column(Text, nullable=True)
|
||||||
|
last_active_time = Column(Float, nullable=False)
|
||||||
|
platform = Column(Text, nullable=False)
|
||||||
|
user_platform = Column(Text, nullable=False)
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
user_nickname = Column(Text, nullable=False)
|
||||||
|
user_cardname = Column(Text, nullable=True)
|
||||||
|
energy_value = Column(Float, nullable=True, default=5.0)
|
||||||
|
sleep_pressure = Column(Float, nullable=True, default=0.0)
|
||||||
|
focus_energy = Column(Float, nullable=True, default=0.5)
|
||||||
|
# 动态兴趣度系统字段
|
||||||
|
base_interest_energy = Column(Float, nullable=True, default=0.5)
|
||||||
|
message_interest_total = Column(Float, nullable=True, default=0.0)
|
||||||
|
message_count = Column(Integer, nullable=True, default=0)
|
||||||
|
action_count = Column(Integer, nullable=True, default=0)
|
||||||
|
reply_count = Column(Integer, nullable=True, default=0)
|
||||||
|
last_interaction_time = Column(Float, nullable=True, default=None)
|
||||||
|
consecutive_no_reply = Column(Integer, nullable=True, default=0)
|
||||||
|
# 消息打断系统字段
|
||||||
|
interruption_count = Column(Integer, nullable=True, default=0)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_chatstreams_stream_id", "stream_id"),
|
||||||
|
Index("idx_chatstreams_user_id", "user_id"),
|
||||||
|
Index("idx_chatstreams_group_id", "group_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class LLMUsage(Base):
|
||||||
|
"""LLM使用记录模型"""
|
||||||
|
|
||||||
|
__tablename__ = "llm_usage"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
model_name = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
model_assign_name = Column(get_string_field(100), index=True) # 添加索引
|
||||||
|
model_api_provider = Column(get_string_field(100), index=True) # 添加索引
|
||||||
|
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||||
|
request_type = Column(get_string_field(50), nullable=False, index=True)
|
||||||
|
endpoint = Column(Text, nullable=False)
|
||||||
|
prompt_tokens = Column(Integer, nullable=False)
|
||||||
|
completion_tokens = Column(Integer, nullable=False)
|
||||||
|
time_cost = Column(Float, nullable=True)
|
||||||
|
total_tokens = Column(Integer, nullable=False)
|
||||||
|
cost = Column(Float, nullable=False)
|
||||||
|
status = Column(Text, nullable=False)
|
||||||
|
timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_llmusage_model_name", "model_name"),
|
||||||
|
Index("idx_llmusage_model_assign_name", "model_assign_name"),
|
||||||
|
Index("idx_llmusage_model_api_provider", "model_api_provider"),
|
||||||
|
Index("idx_llmusage_time_cost", "time_cost"),
|
||||||
|
Index("idx_llmusage_user_id", "user_id"),
|
||||||
|
Index("idx_llmusage_request_type", "request_type"),
|
||||||
|
Index("idx_llmusage_timestamp", "timestamp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Emoji(Base):
|
||||||
|
"""表情包模型"""
|
||||||
|
|
||||||
|
__tablename__ = "emoji"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
full_path = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||||
|
format = Column(Text, nullable=False)
|
||||||
|
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
description = Column(Text, nullable=False)
|
||||||
|
query_count = Column(Integer, nullable=False, default=0)
|
||||||
|
is_registered = Column(Boolean, nullable=False, default=False)
|
||||||
|
is_banned = Column(Boolean, nullable=False, default=False)
|
||||||
|
emotion = Column(Text, nullable=True)
|
||||||
|
record_time = Column(Float, nullable=False)
|
||||||
|
register_time = Column(Float, nullable=True)
|
||||||
|
usage_count = Column(Integer, nullable=False, default=0)
|
||||||
|
last_used_time = Column(Float, nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_emoji_full_path", "full_path"),
|
||||||
|
Index("idx_emoji_hash", "emoji_hash"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Messages(Base):
|
||||||
|
"""消息模型"""
|
||||||
|
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
message_id = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
time = Column(Float, nullable=False)
|
||||||
|
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
reply_to = Column(Text, nullable=True)
|
||||||
|
interest_value = Column(Float, nullable=True)
|
||||||
|
key_words = Column(Text, nullable=True)
|
||||||
|
key_words_lite = Column(Text, nullable=True)
|
||||||
|
is_mentioned = Column(Boolean, nullable=True)
|
||||||
|
|
||||||
|
# 从 chat_info 扁平化而来的字段
|
||||||
|
chat_info_stream_id = Column(Text, nullable=False)
|
||||||
|
chat_info_platform = Column(Text, nullable=False)
|
||||||
|
chat_info_user_platform = Column(Text, nullable=False)
|
||||||
|
chat_info_user_id = Column(Text, nullable=False)
|
||||||
|
chat_info_user_nickname = Column(Text, nullable=False)
|
||||||
|
chat_info_user_cardname = Column(Text, nullable=True)
|
||||||
|
chat_info_group_platform = Column(Text, nullable=True)
|
||||||
|
chat_info_group_id = Column(Text, nullable=True)
|
||||||
|
chat_info_group_name = Column(Text, nullable=True)
|
||||||
|
chat_info_create_time = Column(Float, nullable=False)
|
||||||
|
chat_info_last_active_time = Column(Float, nullable=False)
|
||||||
|
|
||||||
|
# 从顶层 user_info 扁平化而来的字段
|
||||||
|
user_platform = Column(Text, nullable=True)
|
||||||
|
user_id = Column(get_string_field(100), nullable=True, index=True)
|
||||||
|
user_nickname = Column(Text, nullable=True)
|
||||||
|
user_cardname = Column(Text, nullable=True)
|
||||||
|
|
||||||
|
processed_plain_text = Column(Text, nullable=True)
|
||||||
|
display_message = Column(Text, nullable=True)
|
||||||
|
memorized_times = Column(Integer, nullable=False, default=0)
|
||||||
|
priority_mode = Column(Text, nullable=True)
|
||||||
|
priority_info = Column(Text, nullable=True)
|
||||||
|
additional_config = Column(Text, nullable=True)
|
||||||
|
is_emoji = Column(Boolean, nullable=False, default=False)
|
||||||
|
is_picid = Column(Boolean, nullable=False, default=False)
|
||||||
|
is_command = Column(Boolean, nullable=False, default=False)
|
||||||
|
is_notify = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
# 兴趣度系统字段
|
||||||
|
actions = Column(Text, nullable=True) # JSON格式存储动作列表
|
||||||
|
should_reply = Column(Boolean, nullable=True, default=False)
|
||||||
|
should_act = Column(Boolean, nullable=True, default=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_messages_message_id", "message_id"),
|
||||||
|
Index("idx_messages_chat_id", "chat_id"),
|
||||||
|
Index("idx_messages_time", "time"),
|
||||||
|
Index("idx_messages_user_id", "user_id"),
|
||||||
|
Index("idx_messages_should_reply", "should_reply"),
|
||||||
|
Index("idx_messages_should_act", "should_act"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionRecords(Base):
|
||||||
|
"""动作记录模型"""
|
||||||
|
|
||||||
|
__tablename__ = "action_records"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
action_id = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
time = Column(Float, nullable=False)
|
||||||
|
action_name = Column(Text, nullable=False)
|
||||||
|
action_data = Column(Text, nullable=False)
|
||||||
|
action_done = Column(Boolean, nullable=False, default=False)
|
||||||
|
action_build_into_prompt = Column(Boolean, nullable=False, default=False)
|
||||||
|
action_prompt_display = Column(Text, nullable=False)
|
||||||
|
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
chat_info_stream_id = Column(Text, nullable=False)
|
||||||
|
chat_info_platform = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_actionrecords_action_id", "action_id"),
|
||||||
|
Index("idx_actionrecords_chat_id", "chat_id"),
|
||||||
|
Index("idx_actionrecords_time", "time"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Images(Base):
|
||||||
|
"""图像信息模型"""
|
||||||
|
|
||||||
|
__tablename__ = "images"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
image_id = Column(Text, nullable=False, default="")
|
||||||
|
emoji_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
description = Column(Text, nullable=True)
|
||||||
|
path = Column(get_string_field(500), nullable=False, unique=True)
|
||||||
|
count = Column(Integer, nullable=False, default=1)
|
||||||
|
timestamp = Column(Float, nullable=False)
|
||||||
|
type = Column(Text, nullable=False)
|
||||||
|
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_images_emoji_hash", "emoji_hash"),
|
||||||
|
Index("idx_images_path", "path"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ImageDescriptions(Base):
|
||||||
|
"""图像描述信息模型"""
|
||||||
|
|
||||||
|
__tablename__ = "image_descriptions"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
type = Column(Text, nullable=False)
|
||||||
|
image_description_hash = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
description = Column(Text, nullable=False)
|
||||||
|
timestamp = Column(Float, nullable=False)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)
|
||||||
|
|
||||||
|
|
||||||
|
class Videos(Base):
|
||||||
|
"""视频信息模型"""
|
||||||
|
|
||||||
|
__tablename__ = "videos"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
video_id = Column(Text, nullable=False, default="")
|
||||||
|
video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
|
||||||
|
description = Column(Text, nullable=True)
|
||||||
|
count = Column(Integer, nullable=False, default=1)
|
||||||
|
timestamp = Column(Float, nullable=False)
|
||||||
|
vlm_processed = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
# 视频特有属性
|
||||||
|
duration = Column(Float, nullable=True) # 视频时长(秒)
|
||||||
|
frame_count = Column(Integer, nullable=True) # 总帧数
|
||||||
|
fps = Column(Float, nullable=True) # 帧率
|
||||||
|
resolution = Column(Text, nullable=True) # 分辨率
|
||||||
|
file_size = Column(Integer, nullable=True) # 文件大小(字节)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_videos_video_hash", "video_hash"),
|
||||||
|
Index("idx_videos_timestamp", "timestamp"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class OnlineTime(Base):
|
||||||
|
"""在线时长记录模型"""
|
||||||
|
|
||||||
|
__tablename__ = "online_time"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now))
|
||||||
|
duration = Column(Integer, nullable=False)
|
||||||
|
start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
end_timestamp = Column(DateTime, nullable=False, index=True)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)
|
||||||
|
|
||||||
|
|
||||||
|
class PersonInfo(Base):
|
||||||
|
"""人物信息模型"""
|
||||||
|
|
||||||
|
__tablename__ = "person_info"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
|
||||||
|
person_name = Column(Text, nullable=True)
|
||||||
|
name_reason = Column(Text, nullable=True)
|
||||||
|
platform = Column(Text, nullable=False)
|
||||||
|
user_id = Column(get_string_field(50), nullable=False, index=True)
|
||||||
|
nickname = Column(Text, nullable=True)
|
||||||
|
impression = Column(Text, nullable=True)
|
||||||
|
short_impression = Column(Text, nullable=True)
|
||||||
|
points = Column(Text, nullable=True)
|
||||||
|
forgotten_points = Column(Text, nullable=True)
|
||||||
|
info_list = Column(Text, nullable=True)
|
||||||
|
know_times = Column(Float, nullable=True)
|
||||||
|
know_since = Column(Float, nullable=True)
|
||||||
|
last_know = Column(Float, nullable=True)
|
||||||
|
attitude = Column(Integer, nullable=True, default=50)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_personinfo_person_id", "person_id"),
|
||||||
|
Index("idx_personinfo_user_id", "user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BotPersonalityInterests(Base):
|
||||||
|
"""机器人人格兴趣标签模型"""
|
||||||
|
|
||||||
|
__tablename__ = "bot_personality_interests"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
personality_id = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
personality_description = Column(Text, nullable=False)
|
||||||
|
interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表
|
||||||
|
embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
|
||||||
|
version = Column(Integer, nullable=False, default=1)
|
||||||
|
last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_botpersonality_personality_id", "personality_id"),
|
||||||
|
Index("idx_botpersonality_version", "version"),
|
||||||
|
Index("idx_botpersonality_last_updated", "last_updated"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Memory(Base):
|
||||||
|
"""记忆模型"""
|
||||||
|
|
||||||
|
__tablename__ = "memory"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
memory_id = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
chat_id = Column(Text, nullable=True)
|
||||||
|
memory_text = Column(Text, nullable=True)
|
||||||
|
keywords = Column(Text, nullable=True)
|
||||||
|
create_time = Column(Float, nullable=True)
|
||||||
|
last_view_time = Column(Float, nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_memory_memory_id", "memory_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class Expression(Base):
|
||||||
|
"""表达风格模型"""
|
||||||
|
|
||||||
|
__tablename__ = "expression"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
situation: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
style: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
count: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
|
||||||
|
chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
|
||||||
|
type: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
create_date: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_expression_chat_id", "chat_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class ThinkingLog(Base):
|
||||||
|
"""思考日志模型"""
|
||||||
|
|
||||||
|
__tablename__ = "thinking_logs"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
chat_id = Column(get_string_field(64), nullable=False, index=True)
|
||||||
|
trigger_text = Column(Text, nullable=True)
|
||||||
|
response_text = Column(Text, nullable=True)
|
||||||
|
trigger_info_json = Column(Text, nullable=True)
|
||||||
|
response_info_json = Column(Text, nullable=True)
|
||||||
|
timing_results_json = Column(Text, nullable=True)
|
||||||
|
chat_history_json = Column(Text, nullable=True)
|
||||||
|
chat_history_in_thinking_json = Column(Text, nullable=True)
|
||||||
|
chat_history_after_response_json = Column(Text, nullable=True)
|
||||||
|
heartflow_data_json = Column(Text, nullable=True)
|
||||||
|
reasoning_data_json = Column(Text, nullable=True)
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphNodes(Base):
|
||||||
|
"""记忆图节点模型"""
|
||||||
|
|
||||||
|
__tablename__ = "graph_nodes"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
|
||||||
|
memory_items = Column(Text, nullable=False)
|
||||||
|
hash = Column(Text, nullable=False)
|
||||||
|
weight = Column(Float, nullable=False, default=1.0)
|
||||||
|
created_time = Column(Float, nullable=False)
|
||||||
|
last_modified = Column(Float, nullable=False)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_graphnodes_concept", "concept"),)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphEdges(Base):
|
||||||
|
"""记忆图边模型"""
|
||||||
|
|
||||||
|
__tablename__ = "graph_edges"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
source = Column(get_string_field(255), nullable=False, index=True)
|
||||||
|
target = Column(get_string_field(255), nullable=False, index=True)
|
||||||
|
strength = Column(Integer, nullable=False)
|
||||||
|
hash = Column(Text, nullable=False)
|
||||||
|
created_time = Column(Float, nullable=False)
|
||||||
|
last_modified = Column(Float, nullable=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_graphedges_source", "source"),
|
||||||
|
Index("idx_graphedges_target", "target"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Schedule(Base):
|
||||||
|
"""日程模型"""
|
||||||
|
|
||||||
|
__tablename__ = "schedule"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式
|
||||||
|
schedule_data = Column(Text, nullable=False) # JSON格式的日程数据
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
|
|
||||||
|
__table_args__ = (Index("idx_schedule_date", "date"),)
|
||||||
|
|
||||||
|
|
||||||
|
class MaiZoneScheduleStatus(Base):
|
||||||
|
"""麦麦空间日程处理状态模型"""
|
||||||
|
|
||||||
|
__tablename__ = "maizone_schedule_status"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
datetime_hour = Column(
|
||||||
|
get_string_field(13), nullable=False, unique=True, index=True
|
||||||
|
) # YYYY-MM-DD HH格式,精确到小时
|
||||||
|
activity = Column(Text, nullable=False) # 该小时的活动内容
|
||||||
|
is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理
|
||||||
|
processed_at = Column(DateTime, nullable=True) # 处理时间
|
||||||
|
story_content = Column(Text, nullable=True) # 生成的说说内容
|
||||||
|
send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_maizone_datetime_hour", "datetime_hour"),
|
||||||
|
Index("idx_maizone_is_processed", "is_processed"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BanUser(Base):
|
||||||
|
"""被禁用用户模型
|
||||||
|
|
||||||
|
使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型,
|
||||||
|
避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "ban_users"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
platform: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
|
||||||
|
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
|
||||||
|
reason: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_violation_num", "violation_num"),
|
||||||
|
Index("idx_banuser_user_id", "user_id"),
|
||||||
|
Index("idx_banuser_platform", "platform"),
|
||||||
|
Index("idx_banuser_platform_user_id", "platform", "user_id"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AntiInjectionStats(Base):
|
||||||
|
"""反注入系统统计模型"""
|
||||||
|
|
||||||
|
__tablename__ = "anti_injection_stats"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
total_messages = Column(Integer, nullable=False, default=0)
|
||||||
|
"""总处理消息数"""
|
||||||
|
|
||||||
|
detected_injections = Column(Integer, nullable=False, default=0)
|
||||||
|
"""检测到的注入攻击数"""
|
||||||
|
|
||||||
|
blocked_messages = Column(Integer, nullable=False, default=0)
|
||||||
|
"""被阻止的消息数"""
|
||||||
|
|
||||||
|
shielded_messages = Column(Integer, nullable=False, default=0)
|
||||||
|
"""被加盾的消息数"""
|
||||||
|
|
||||||
|
processing_time_total = Column(Float, nullable=False, default=0.0)
|
||||||
|
"""总处理时间"""
|
||||||
|
|
||||||
|
total_process_time = Column(Float, nullable=False, default=0.0)
|
||||||
|
"""累计总处理时间"""
|
||||||
|
|
||||||
|
last_process_time = Column(Float, nullable=False, default=0.0)
|
||||||
|
"""最近一次处理时间"""
|
||||||
|
|
||||||
|
error_count = Column(Integer, nullable=False, default=0)
|
||||||
|
"""错误计数"""
|
||||||
|
|
||||||
|
start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
"""统计开始时间"""
|
||||||
|
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
"""记录创建时间"""
|
||||||
|
|
||||||
|
updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
|
||||||
|
"""记录更新时间"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_anti_injection_stats_created_at", "created_at"),
|
||||||
|
Index("idx_anti_injection_stats_updated_at", "updated_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CacheEntries(Base):
|
||||||
|
"""工具缓存条目模型"""
|
||||||
|
|
||||||
|
__tablename__ = "cache_entries"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
|
||||||
|
"""缓存键,包含工具名、参数和代码哈希"""
|
||||||
|
|
||||||
|
cache_value = Column(Text, nullable=False)
|
||||||
|
"""缓存的数据,JSON格式"""
|
||||||
|
|
||||||
|
expires_at = Column(Float, nullable=False, index=True)
|
||||||
|
"""过期时间戳"""
|
||||||
|
|
||||||
|
tool_name = Column(get_string_field(100), nullable=False, index=True)
|
||||||
|
"""工具名称"""
|
||||||
|
|
||||||
|
created_at = Column(Float, nullable=False, default=lambda: time.time())
|
||||||
|
"""创建时间戳"""
|
||||||
|
|
||||||
|
last_accessed = Column(Float, nullable=False, default=lambda: time.time())
|
||||||
|
"""最后访问时间戳"""
|
||||||
|
|
||||||
|
access_count = Column(Integer, nullable=False, default=0)
|
||||||
|
"""访问次数"""
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_cache_entries_key", "cache_key"),
|
||||||
|
Index("idx_cache_entries_expires_at", "expires_at"),
|
||||||
|
Index("idx_cache_entries_tool_name", "tool_name"),
|
||||||
|
Index("idx_cache_entries_created_at", "created_at"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MonthlyPlan(Base):
|
||||||
|
"""月度计划模型"""
|
||||||
|
|
||||||
|
__tablename__ = "monthly_plans"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
plan_text = Column(Text, nullable=False)
|
||||||
|
target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM"
|
||||||
|
status = Column(
|
||||||
|
get_string_field(20), nullable=False, default="active", index=True
|
||||||
|
) # 'active', 'completed', 'archived'
|
||||||
|
usage_count = Column(Integer, nullable=False, default=0)
|
||||||
|
last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format
|
||||||
|
created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
|
||||||
|
|
||||||
|
# 保留 is_deleted 字段以兼容现有数据,但标记为已弃用
|
||||||
|
is_deleted = Column(Boolean, nullable=False, default=False)
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_monthlyplan_target_month_status", "target_month", "status"),
|
||||||
|
Index("idx_monthlyplan_last_used_date", "last_used_date"),
|
||||||
|
Index("idx_monthlyplan_usage_count", "usage_count"),
|
||||||
|
# 保留旧索引以兼容
|
||||||
|
Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# 数据库引擎和会话管理
|
||||||
|
_engine = None
|
||||||
|
_SessionLocal = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_database_url():
|
||||||
|
"""获取数据库连接URL"""
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
config = global_config.database
|
||||||
|
|
||||||
|
if config.database_type == "mysql":
|
||||||
|
# 对用户名和密码进行URL编码,处理特殊字符
|
||||||
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
encoded_user = quote_plus(config.mysql_user)
|
||||||
|
encoded_password = quote_plus(config.mysql_password)
|
||||||
|
|
||||||
|
# 检查是否配置了Unix socket连接
|
||||||
|
if config.mysql_unix_socket:
|
||||||
|
# 使用Unix socket连接
|
||||||
|
encoded_socket = quote_plus(config.mysql_unix_socket)
|
||||||
|
return (
|
||||||
|
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||||
|
f"@/{config.mysql_database}"
|
||||||
|
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 使用标准TCP连接
|
||||||
|
return (
|
||||||
|
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
|
||||||
|
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
|
||||||
|
f"?charset={config.mysql_charset}"
|
||||||
|
)
|
||||||
|
else: # SQLite
|
||||||
|
# 如果是相对路径,则相对于项目根目录
|
||||||
|
if not os.path.isabs(config.sqlite_path):
|
||||||
|
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
|
||||||
|
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
|
||||||
|
else:
|
||||||
|
db_path = config.sqlite_path
|
||||||
|
|
||||||
|
# 确保数据库目录存在
|
||||||
|
os.makedirs(os.path.dirname(db_path), exist_ok=True)
|
||||||
|
|
||||||
|
return f"sqlite+aiosqlite:///{db_path}"
|
||||||
|
|
||||||
|
|
||||||
|
async def initialize_database():
|
||||||
|
"""初始化异步数据库引擎和会话"""
|
||||||
|
global _engine, _SessionLocal
|
||||||
|
|
||||||
|
if _engine is not None:
|
||||||
|
return _engine, _SessionLocal
|
||||||
|
|
||||||
|
database_url = get_database_url()
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
config = global_config.database
|
||||||
|
|
||||||
|
# 配置引擎参数
|
||||||
|
engine_kwargs: dict[str, Any] = {
|
||||||
|
"echo": False, # 生产环境关闭SQL日志
|
||||||
|
"future": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.database_type == "mysql":
|
||||||
|
# MySQL连接池配置 - 异步引擎使用默认连接池
|
||||||
|
engine_kwargs.update(
|
||||||
|
{
|
||||||
|
"pool_size": config.connection_pool_size,
|
||||||
|
"max_overflow": config.connection_pool_size * 2,
|
||||||
|
"pool_timeout": config.connection_timeout,
|
||||||
|
"pool_recycle": 3600, # 1小时回收连接
|
||||||
|
"pool_pre_ping": True, # 连接前ping检查
|
||||||
|
"connect_args": {
|
||||||
|
"autocommit": config.mysql_autocommit,
|
||||||
|
"charset": config.mysql_charset,
|
||||||
|
"connect_timeout": config.connection_timeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# SQLite配置 - aiosqlite不支持连接池参数
|
||||||
|
engine_kwargs.update(
|
||||||
|
{
|
||||||
|
"connect_args": {
|
||||||
|
"check_same_thread": False,
|
||||||
|
"timeout": 60, # 增加超时时间
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_engine = create_async_engine(database_url, **engine_kwargs)
|
||||||
|
_SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
# 调用新的迁移函数,它会处理表的创建和列的添加
|
||||||
|
from src.common.database.db_migration import check_and_migrate_database
|
||||||
|
|
||||||
|
await check_and_migrate_database()
|
||||||
|
|
||||||
|
# 如果是 SQLite,启用 WAL 模式以提高并发性能
|
||||||
|
if config.database_type == "sqlite":
|
||||||
|
await enable_sqlite_wal_mode(_engine)
|
||||||
|
|
||||||
|
logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
|
||||||
|
return _engine, _SessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def get_db_session() -> AsyncGenerator[AsyncSession]:
|
||||||
|
"""
|
||||||
|
异步数据库会话上下文管理器。
|
||||||
|
在初始化失败时会yield None,调用方需要检查会话是否为None。
|
||||||
|
|
||||||
|
现在使用透明的连接池管理器来复用现有连接,提高并发性能。
|
||||||
|
"""
|
||||||
|
SessionLocal = None
|
||||||
|
try:
|
||||||
|
_, SessionLocal = await initialize_database()
|
||||||
|
if not SessionLocal:
|
||||||
|
raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"数据库初始化失败,无法创建会话: {e}")
|
||||||
|
raise
|
||||||
|
|
||||||
|
# 使用连接池管理器获取会话
|
||||||
|
pool_manager = get_connection_pool_manager()
|
||||||
|
|
||||||
|
async with pool_manager.get_session(SessionLocal) as session:
|
||||||
|
# 对于 SQLite,在会话开始时设置 PRAGMA(仅对新连接)
|
||||||
|
from src.config.config import global_config
|
||||||
|
|
||||||
|
if global_config.database.database_type == "sqlite":
|
||||||
|
try:
|
||||||
|
await session.execute(text("PRAGMA busy_timeout = 60000"))
|
||||||
|
await session.execute(text("PRAGMA foreign_keys = ON"))
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
async def get_engine():
|
||||||
|
"""获取异步数据库引擎"""
|
||||||
|
engine, _ = await initialize_database()
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
class PermissionNodes(Base):
|
||||||
|
"""权限节点模型"""
|
||||||
|
|
||||||
|
__tablename__ = "permission_nodes"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称
|
||||||
|
description = Column(Text, nullable=False) # 权限描述
|
||||||
|
plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件
|
||||||
|
default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_permission_plugin", "plugin_name"),
|
||||||
|
Index("idx_permission_node", "node_name"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserPermissions(Base):
|
||||||
|
"""用户权限模型"""
|
||||||
|
|
||||||
|
__tablename__ = "user_permissions"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID
|
||||||
|
permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称
|
||||||
|
granted = Column(Boolean, default=True, nullable=False) # 是否授权
|
||||||
|
granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间
|
||||||
|
granted_by = Column(get_string_field(100), nullable=True) # 授权者信息
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_user_platform_id", "platform", "user_id"),
|
||||||
|
Index("idx_user_permission", "platform", "user_id", "permission_node"),
|
||||||
|
Index("idx_permission_granted", "permission_node", "granted"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class UserRelationships(Base):
|
||||||
|
"""用户关系模型 - 存储用户与bot的关系数据"""
|
||||||
|
|
||||||
|
__tablename__ = "user_relationships"
|
||||||
|
|
||||||
|
id = Column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID
|
||||||
|
user_name = Column(get_string_field(100), nullable=True) # 用户名
|
||||||
|
relationship_text = Column(Text, nullable=True) # 关系印象描述
|
||||||
|
relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1)
|
||||||
|
last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间
|
||||||
|
created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间
|
||||||
|
|
||||||
|
__table_args__ = (
|
||||||
|
Index("idx_user_relationship_id", "user_id"),
|
||||||
|
Index("idx_relationship_score", "relationship_score"),
|
||||||
|
Index("idx_relationship_updated", "last_updated"),
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user