diff --git a/bot.py b/bot.py index 96c5f165f..debeaac5f 100644 --- a/bot.py +++ b/bot.py @@ -543,7 +543,7 @@ class MaiBotMain: """设置时区""" try: if platform.system().lower() != "windows": - time.tzset() + time.tzset() # type: ignore logger.info("时区设置完成") else: logger.info("Windows系统,跳过时区设置") diff --git a/plugins/echo_example/__init__.py b/plugins/echo_example/__init__.py deleted file mode 100644 index 0a78bbfa7..000000000 --- a/plugins/echo_example/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from src.plugin_system.base.plugin_metadata import PluginMetadata - -__plugin_meta__ = PluginMetadata( - name="Echo Example Plugin", - description="An example plugin that echoes messages.", - usage="!echo [message]", - version="1.0.0", - author="Your Name", - license="MIT", -) diff --git a/plugins/echo_example/_manifest.json b/plugins/echo_example/_manifest.json deleted file mode 100644 index 02cd1a72f..000000000 --- a/plugins/echo_example/_manifest.json +++ /dev/null @@ -1,53 +0,0 @@ -{ - "manifest_version": 1, - "format_version": "1.0.0", - "name": "Echo 示例插件", - "description": "展示增强命令系统的Echo命令示例插件", - "version": "1.0.0", - "author": { - "name": "MoFox" - }, - "license": "MIT", - "keywords": ["echo", "example", "command"], - "categories": ["utility", "example"], - "host_application": { - "name": "MaiBot", - "min_version": "0.10.0" - }, - "entry_points": { - "main": "plugin.py" - }, - "plugin_info": { - "is_built_in": false, - "plugin_type": "example", - "components": [ - { - "type": "command", - "name": "echo", - "description": "回显命令,支持别名 say, repeat" - }, - { - "type": "command", - "name": "hello", - "description": "问候命令,支持别名 hi, greet" - }, - { - "type": "command", - "name": "info", - "description": "显示插件信息,支持别名 about" - }, - { - "type": "command", - "name": "test", - "description": "测试命令,展示参数解析功能" - } - ], - "features": [ - "增强命令系统示例", - "无需正则表达式的命令定义", - "命令别名支持", - "参数解析功能", - "聊天类型限制" - ] - } -} diff --git a/plugins/echo_example/plugin.py b/plugins/echo_example/plugin.py 
deleted file mode 100644 index e03429805..000000000 --- a/plugins/echo_example/plugin.py +++ /dev/null @@ -1,204 +0,0 @@ -""" -Echo 示例插件 - -展示增强命令系统的使用方法 -""" - -from typing import Union - -from src.plugin_system import ( - BasePlugin, - ChatType, - CommandArgs, - ConfigField, - PlusCommand, - PlusCommandInfo, - register_plugin, -) -from src.plugin_system.base.component_types import PythonDependency - - -class EchoCommand(PlusCommand): - """Echo命令示例""" - - command_name = "echo" - command_description = "回显命令" - command_aliases = ["say", "repeat"] - priority = 5 - chat_type_allow = ChatType.ALL - intercept_message = True - - async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: - """执行echo命令""" - if args.is_empty(): - await self.send_text("❓ 请提供要回显的内容\n用法: /echo <内容>") - return True, "参数不足", True - - content = args.get_raw() - - # 检查内容长度限制 - max_length = self.get_config("commands.max_content_length", 500) - if len(content) > max_length: - await self.send_text(f"❌ 内容过长,最大允许 {max_length} 字符") - return True, "内容过长", True - - await self.send_text(f"🔊 {content}") - - return True, "Echo命令执行成功", True - - -class HelloCommand(PlusCommand): - """Hello命令示例""" - - command_name = "hello" - command_description = "问候命令" - command_aliases = ["hi", "greet"] - priority = 3 - chat_type_allow = ChatType.ALL - intercept_message = True - - async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: - """执行hello命令""" - if args.is_empty(): - await self.send_text("👋 Hello! 很高兴见到你!") - else: - name = args.get_first() - await self.send_text(f"👋 Hello, {name}! 
很高兴见到你!") - - return True, "Hello命令执行成功", True - - -class InfoCommand(PlusCommand): - """信息命令示例""" - - command_name = "info" - command_description = "显示插件信息" - command_aliases = ["about"] - priority = 1 - chat_type_allow = ChatType.ALL - intercept_message = True - - async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: - """执行info命令""" - info_text = ( - "📋 Echo 示例插件信息\n" - "版本: 1.0.0\n" - "作者: MaiBot Team\n" - "描述: 展示增强命令系统的使用方法\n\n" - "🎯 可用命令:\n" - "• /echo|/say|/repeat <内容> - 回显内容\n" - "• /hello|/hi|/greet [名字] - 问候\n" - "• /info|/about - 显示此信息\n" - "• /test <子命令> [参数] - 测试各种功能" - ) - await self.send_text(info_text) - - return True, "Info命令执行成功", True - - -class TestCommand(PlusCommand): - """测试命令示例,展示参数解析功能""" - - command_name = "test" - command_description = "测试命令,展示参数解析功能" - command_aliases = ["t"] - priority = 2 - chat_type_allow = ChatType.ALL - intercept_message = True - - async def execute(self, args: CommandArgs) -> tuple[bool, str | None, bool]: - """执行test命令""" - if args.is_empty(): - help_text = ( - "🧪 测试命令帮助\n" - "用法: /test <子命令> [参数]\n\n" - "可用子命令:\n" - "• args - 显示参数解析结果\n" - "• flags - 测试标志参数\n" - "• count - 计算参数数量\n" - "• join - 连接所有参数" - ) - await self.send_text(help_text) - return True, "显示帮助", True - - subcommand = args.get_first().lower() - - if subcommand == "args": - result = ( - f"🔍 参数解析结果:\n" - f"原始字符串: '{args.get_raw()}'\n" - f"解析后参数: {args.get_args()}\n" - f"参数数量: {args.count()}\n" - f"第一个参数: '{args.get_first()}'\n" - f"剩余参数: '{args.get_remaining()}'" - ) - await self.send_text(result) - - elif subcommand == "flags": - result = ( - f"🏴 标志测试结果:\n" - f"包含 --verbose: {args.has_flag('--verbose')}\n" - f"包含 -v: {args.has_flag('-v')}\n" - f"--output 的值: '{args.get_flag_value('--output', '未设置')}'\n" - f"--name 的值: '{args.get_flag_value('--name', '未设置')}'" - ) - await self.send_text(result) - - elif subcommand == "count": - count = args.count() - 1 # 减去子命令本身 - await self.send_text(f"📊 除子命令外的参数数量: {count}") - - elif 
subcommand == "join": - remaining = args.get_remaining() - if remaining: - await self.send_text(f"🔗 连接结果: {remaining}") - else: - await self.send_text("❌ 没有可连接的参数") - - else: - await self.send_text(f"❓ 未知的子命令: {subcommand}") - - return True, "Test命令执行成功", True - - -@register_plugin -class EchoExamplePlugin(BasePlugin): - """Echo 示例插件""" - - plugin_name: str = "echo_example_plugin" - enable_plugin: bool = True - dependencies: list[str] = [] - python_dependencies: list[Union[str, "PythonDependency"]] = [] - config_file_name: str = "config.toml" - - config_schema = { - "plugin": { - "enabled": ConfigField(bool, default=True, description="是否启用插件"), - "config_version": ConfigField(str, default="1.0.0", description="配置文件版本"), - }, - "commands": { - "echo_enabled": ConfigField(bool, default=True, description="是否启用 Echo 命令"), - "cooldown": ConfigField(int, default=0, description="命令冷却时间(秒)"), - "max_content_length": ConfigField(int, default=500, description="最大回显内容长度"), - }, - } - - config_section_descriptions = { - "plugin": "插件基本配置", - "commands": "命令相关配置", - } - - def get_plugin_components(self) -> list[tuple[PlusCommandInfo, type]]: - """获取插件组件""" - components = [] - - if self.get_config("plugin.enabled", True): - # 添加所有命令,直接使用PlusCommand类 - if self.get_config("commands.echo_enabled", True): - components.append((EchoCommand.get_plus_command_info(), EchoCommand)) - - components.append((HelloCommand.get_plus_command_info(), HelloCommand)) - components.append((InfoCommand.get_plus_command_info(), InfoCommand)) - components.append((TestCommand.get_plus_command_info(), TestCommand)) - - return components diff --git a/plugins/hello_world_plugin/plugin.py b/plugins/hello_world_plugin/plugin.py index 2c71293a1..e3a716429 100644 --- a/plugins/hello_world_plugin/plugin.py +++ b/plugins/hello_world_plugin/plugin.py @@ -107,7 +107,7 @@ class HelloWorldPlugin(BasePlugin): components.append((GetSystemInfoTool.get_tool_info(), GetSystemInfoTool)) if 
self.get_config("components.hello_command_enabled", True): - components.append((HelloCommand.get_command_info(), HelloCommand)) + components.append((HelloCommand.get_plus_command_info(), HelloCommand)) if self.get_config("components.random_emoji_action_enabled", True): components.append((RandomEmojiAction.get_action_info(), RandomEmojiAction)) diff --git a/pyrightconfig.json b/pyrightconfig.json new file mode 100644 index 000000000..0dff0f212 --- /dev/null +++ b/pyrightconfig.json @@ -0,0 +1,31 @@ +{ + "$schema": "https://raw.githubusercontent.com/microsoft/pyright/main/packages/vscode-pyright/schemas/pyrightconfig.schema.json", + "include": [ + "src", + "bot.py", + "__main__.py" + ], + "exclude": [ + "**/__pycache__", + "data", + "logs", + "tests", + "target", + "*.egg-info" + ], + "typeCheckingMode": "standard", + "reportMissingImports": false, + "reportMissingTypeStubs": false, + "reportMissingModuleSource": false, + "diagnosticSeverityOverrides": { + "reportMissingImports": "none", + "reportMissingTypeStubs": "none", + "reportMissingModuleSource": "none" + }, + "pythonVersion": "3.12", + "venvPath": ".", + "venv": ".venv", + "executionEnvironments": [ + {"root": "src"} + ] +} diff --git a/scripts/convert_sqlalchemy_models.py b/scripts/convert_sqlalchemy_models.py new file mode 100644 index 000000000..31cdc3a47 --- /dev/null +++ b/scripts/convert_sqlalchemy_models.py @@ -0,0 +1,220 @@ +"""批量将经典 SQLAlchemy 模型字段写法 + + field = Column(Integer, nullable=False, default=0) + +转换为 2.0 推荐的带类型注解写法: + + field: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + +脚本特点: +1. 仅处理指定文件(默认: src/common/database/sqlalchemy_models.py)。 +2. 自动识别多行 Column(...) 定义 (括号未闭合会继续合并)。 +3. 已经是 Mapped 写法的行会跳过。 +4. 根据类型名 (Integer / Float / Boolean / Text / String / DateTime / get_string_field) 推断 Python 类型。 +5. nullable=True 时自动添加 "| None"。 +6. 保留 Column(...) 内的原始参数顺序与内容。 +7. 生成 .bak 备份文件,确保可回滚。 +8. 
支持 --dry-run 查看差异,不写回文件。 + +局限/注意: +- 简单基于正则/括号计数,不解析完整 AST;非常规写法(例如变量中构造 Column 再赋值)不会处理。 +- 复杂工厂/自定义类型未在映射表中的,统一映射为 Any。 +- 不自动添加 from __future__ import annotations;如需 Python 3.10 以下更先进类型表达式,请自行处理。 + +使用方式: (在项目根目录执行) + + python scripts/convert_sqlalchemy_models.py \ + --file src/common/database/sqlalchemy_models.py --dry-run + +确认无误后去掉 --dry-run 真实写入。 +""" + +from __future__ import annotations + +import argparse +import re +import shutil +from pathlib import Path +from typing import Any + + +TYPE_MAP = { + "Integer": "int", + "Float": "float", + "Boolean": "bool", + "Text": "str", + "String": "str", + "DateTime": "datetime.datetime", + # 自定义帮助函数 get_string_field(...) 也返回字符串类型 + "get_string_field": "str", +} + + +COLUMN_ASSIGN_RE = re.compile(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=\s*Column\(") +ALREADY_MAPPED_RE = re.compile(r"^[ \t]*[A-Za-z_][A-Za-z0-9_]*\s*:\s*Mapped\[") + + +def detect_column_block(lines: list[str], start_index: int) -> tuple[int, int] | None: + """检测从 start_index 开始的 Column(...) 语句跨越的行范围 (包含结束行)。 + + 使用括号计数法处理多行。 + 返回 (start, end) 行号 (包含 end)。""" + line = lines[start_index] + if "Column(" not in line: + return None + open_parens = line.count("(") - line.count(")") + i = start_index + while open_parens > 0 and i + 1 < len(lines): + i += 1 + l2 = lines[i] + open_parens += l2.count("(") - l2.count(")") + return (start_index, i) + + +def extract_column_body(block_lines: list[str]) -> str: + """提取 Column(...) 
内部参数文本 (去掉首尾 Column( 和 最后一个 ) )。""" + joined = "\n".join(block_lines) + # 找到第一次出现 Column( + start_pos = joined.find("Column(") + if start_pos == -1: + return "" + inner = joined[start_pos + len("Column(") :] + # 去掉最后一个 ) —— 简单方式: 找到最后一个 ) 并截断 + last_paren = inner.rfind(")") + if last_paren != -1: + inner = inner[:last_paren] + return inner.strip() + + +def guess_python_type(column_body: str) -> str: + # 简单取第一个类型标识符 (去掉前导装饰/空格) + # 可能形式: Integer, Text, get_string_field(50), DateTime, Boolean + # 利用正则抓取第一个标识符 + m = re.search(r"([A-Za-z_][A-Za-z0-9_]*)", column_body) + if not m: + return "Any" + type_token = m.group(1) + py_type = TYPE_MAP.get(type_token, "Any") + # nullable 检测 + if "nullable=True" in column_body or "nullable = True" in column_body: + # 避免重复 Optional + if py_type != "Any" and not py_type.endswith(" | None"): + py_type = f"{py_type} | None" + elif py_type == "Any": + py_type = "Any | None" + return py_type + + +def convert_block(block_lines: list[str]) -> list[str]: + first_line = block_lines[0] + m_name = re.match(r"^(?P<indent>\s+)(?P<name>[A-Za-z_][A-Za-z0-9_]*)\s*=", first_line) + if not m_name: + return block_lines + indent = m_name.group("indent") + name = m_name.group("name") + body = extract_column_body(block_lines) + py_type = guess_python_type(body) + # 构造新的多行 mapped_column 写法 + # 保留内部参数的换行缩进: 重新缩进为 indent + 4 空格 (延续原风格: 在 indent 基础上再加 4 空格) + inner_lines = body.split("\n") + if len(inner_lines) == 1: + new_line = f"{indent}{name}: Mapped[{py_type}] = mapped_column({inner_lines[0].strip()})\n" + return [new_line] + else: + # 多行情况 + ind2 = indent + " " + rebuilt = [f"{indent}{name}: Mapped[{py_type}] = mapped_column(",] + for il in inner_lines: + if il.strip(): + rebuilt.append(f"{ind2}{il.rstrip()}") + rebuilt.append(f"{indent})\n") + return [l + ("\n" if not l.endswith("\n") else "") for l in rebuilt] + + +def ensure_imports(content: str) -> str: + if "Mapped," in content or "Mapped[" in content: + # 已经可能存在导入 + if "from sqlalchemy.orm import Mapped, 
mapped_column" not in content: + # 简单插到第一个 import sqlalchemy 之后 + lines = content.splitlines() + for i, line in enumerate(lines): + if "sqlalchemy" in line and line.startswith("from sqlalchemy"): + lines.insert(i + 1, "from sqlalchemy.orm import Mapped, mapped_column") + return "\n".join(lines) + return content + + +def process_file(path: Path) -> tuple[str, str]: + original = path.read_text(encoding="utf-8") + lines = original.splitlines(keepends=True) + i = 0 + out: list[str] = [] + changed = 0 + while i < len(lines): + line = lines[i] + # 跳过已是 Mapped 风格 + if ALREADY_MAPPED_RE.match(line): + out.append(line) + i += 1 + continue + if "= Column(" in line and re.match(r"^\s+[A-Za-z_][A-Za-z0-9_]*\s*=", line): + start, end = detect_column_block(lines, i) or (i, i) + block = lines[start : end + 1] + converted = convert_block(block) + out.extend(converted) + i = end + 1 + # 如果转换结果与原始不同,计数 + if "".join(converted) != "".join(block): + changed += 1 + else: + out.append(line) + i += 1 + new_content = "".join(out) + new_content = ensure_imports(new_content) + # 在文件末尾或头部预留统计信息打印(不写入文件,只返回) + return original, new_content if changed else original + + +def main(): + parser = argparse.ArgumentParser(description="批量转换 SQLAlchemy 模型字段为 2.0 Mapped 写法") + parser.add_argument("--file", default="src/common/database/sqlalchemy_models.py", help="目标模型文件") + parser.add_argument("--dry-run", action="store_true", help="仅显示差异,不写回") + parser.add_argument("--write", action="store_true", help="执行写回 (与 --dry-run 互斥)") + args = parser.parse_args() + + target = Path(args.file) + if not target.exists(): + raise SystemExit(f"文件不存在: {target}") + + original, new_content = process_file(target) + + if original == new_content: + print("[INFO] 没有需要转换的内容或转换后无差异。") + return + + # 简单差异输出 (行对比) + if args.dry_run or not args.write: + print("[DRY-RUN] 以下为转换后预览 (仅显示不同段落):") + import difflib + + diff = difflib.unified_diff( + original.splitlines(), new_content.splitlines(), fromfile="original", 
tofile="converted", lineterm="" + ) + count = 0 + for d in diff: + print(d) + count += 1 + if count == 0: + print("[INFO] 差异为空 (可能未匹配到 Column 定义)。") + if not args.write: + print("\n未写回。若确认无误,添加 --write 执行替换。") + return + + backup = target.with_suffix(target.suffix + ".bak") + shutil.copyfile(target, backup) + target.write_text(new_content, encoding="utf-8") + print(f"[DONE] 已写回: {target},备份文件: {backup.name}") + + +if __name__ == "__main__": # pragma: no cover + main() diff --git a/scripts/interest_value_analysis.py b/scripts/interest_value_analysis.py deleted file mode 100644 index e464c905c..000000000 --- a/scripts/interest_value_analysis.py +++ /dev/null @@ -1,284 +0,0 @@ -import os -import sys -import time -from datetime import datetime - -# Add project root to Python path -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, project_root) -from src.common.database.database_model import Messages, ChatStreams # noqa - - -def get_chat_name(chat_id: str) -> str: - """Get chat name from chat_id by querying ChatStreams table directly""" - try: - chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) - if chat_stream is None: - return f"未知聊天 ({chat_id})" - - if chat_stream.group_name: - return f"{chat_stream.group_name} ({chat_id})" - elif chat_stream.user_nickname: - return f"{chat_stream.user_nickname}的私聊 ({chat_id})" - else: - return f"未知聊天 ({chat_id})" - except Exception: - return f"查询失败 ({chat_id})" - - -def format_timestamp(timestamp: float) -> str: - """Format timestamp to readable date string""" - try: - return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %H:%M:%S") - except (ValueError, OSError): - return "未知时间" - - -def calculate_interest_value_distribution(messages) -> dict[str, int]: - """Calculate distribution of interest_value""" - distribution = { - "0.000-0.010": 0, - "0.010-0.050": 0, - "0.050-0.100": 0, - "0.100-0.500": 0, - "0.500-1.000": 0, - "1.000-2.000": 0, - "2.000-5.000": 0, - 
"5.000-10.000": 0, - "10.000+": 0, - } - - for msg in messages: - if msg.interest_value is None or msg.interest_value == 0.0: - continue - - value = float(msg.interest_value) - if value < 0.010: - distribution["0.000-0.010"] += 1 - elif value < 0.050: - distribution["0.010-0.050"] += 1 - elif value < 0.100: - distribution["0.050-0.100"] += 1 - elif value < 0.500: - distribution["0.100-0.500"] += 1 - elif value < 1.000: - distribution["0.500-1.000"] += 1 - elif value < 2.000: - distribution["1.000-2.000"] += 1 - elif value < 5.000: - distribution["2.000-5.000"] += 1 - elif value < 10.000: - distribution["5.000-10.000"] += 1 - else: - distribution["10.000+"] += 1 - - return distribution - - -def get_interest_value_stats(messages) -> dict[str, float]: - """Calculate basic statistics for interest_value""" - values = [ - float(msg.interest_value) for msg in messages if msg.interest_value is not None and msg.interest_value != 0.0 - ] - - if not values: - return {"count": 0, "min": 0, "max": 0, "avg": 0, "median": 0} - - values.sort() - count = len(values) - - return { - "count": count, - "min": min(values), - "max": max(values), - "avg": sum(values) / count, - "median": values[count // 2] if count % 2 == 1 else (values[count // 2 - 1] + values[count // 2]) / 2, - } - - -def get_available_chats() -> list[tuple[str, str, int]]: - """Get all available chats with message counts""" - try: - # 获取所有有消息的chat_id - chat_counts = {} - for msg in Messages.select(Messages.chat_id).distinct(): - chat_id = msg.chat_id - count = ( - Messages.select() - .where( - (Messages.chat_id == chat_id) - & (Messages.interest_value.is_null(False)) - & (Messages.interest_value != 0.0) - ) - .count() - ) - if count > 0: - chat_counts[chat_id] = count - - # 获取聊天名称 - result = [] - for chat_id, count in chat_counts.items(): - chat_name = get_chat_name(chat_id) - result.append((chat_id, chat_name, count)) - - # 按消息数量排序 - result.sort(key=lambda x: x[2], reverse=True) - return result - except Exception as 
e: - print(f"获取聊天列表失败: {e}") - return [] - - -def get_time_range_input() -> tuple[float | None, float | None]: - """Get time range input from user""" - print("\n时间范围选择:") - print("1. 最近1天") - print("2. 最近3天") - print("3. 最近7天") - print("4. 最近30天") - print("5. 自定义时间范围") - print("6. 不限制时间") - - choice = input("请选择时间范围 (1-6): ").strip() - - now = time.time() - - if choice == "1": - return now - 24 * 3600, now - elif choice == "2": - return now - 3 * 24 * 3600, now - elif choice == "3": - return now - 7 * 24 * 3600, now - elif choice == "4": - return now - 30 * 24 * 3600, now - elif choice == "5": - print("请输入开始时间 (格式: YYYY-MM-DD HH:MM:SS):") - start_str = input().strip() - print("请输入结束时间 (格式: YYYY-MM-DD HH:MM:SS):") - end_str = input().strip() - - try: - start_time = datetime.strptime(start_str, "%Y-%m-%d %H:%M:%S").timestamp() - end_time = datetime.strptime(end_str, "%Y-%m-%d %H:%M:%S").timestamp() - return start_time, end_time - except ValueError: - print("时间格式错误,将不限制时间范围") - return None, None - else: - return None, None - - -def analyze_interest_values( - chat_id: str | None = None, start_time: float | None = None, end_time: float | None = None -) -> None: - """Analyze interest values with optional filters""" - - # 构建查询条件 - query = Messages.select().where((Messages.interest_value.is_null(False)) & (Messages.interest_value != 0.0)) - - if chat_id: - query = query.where(Messages.chat_id == chat_id) - - if start_time: - query = query.where(Messages.time >= start_time) - - if end_time: - query = query.where(Messages.time <= end_time) - - messages = list(query) - - if not messages: - print("没有找到符合条件的消息") - return - - # 计算统计信息 - distribution = calculate_interest_value_distribution(messages) - stats = get_interest_value_stats(messages) - - # 显示结果 - print("\n=== Interest Value 分析结果 ===") - if chat_id: - print(f"聊天: {get_chat_name(chat_id)}") - else: - print("聊天: 全部聊天") - - if start_time and end_time: - print(f"时间范围: {format_timestamp(start_time)} 到 
{format_timestamp(end_time)}") - elif start_time: - print(f"时间范围: {format_timestamp(start_time)} 之后") - elif end_time: - print(f"时间范围: {format_timestamp(end_time)} 之前") - else: - print("时间范围: 不限制") - - print("\n基本统计:") - print(f"有效消息数量: {stats['count']} (排除null和0值)") - print(f"最小值: {stats['min']:.3f}") - print(f"最大值: {stats['max']:.3f}") - print(f"平均值: {stats['avg']:.3f}") - print(f"中位数: {stats['median']:.3f}") - - print("\nInterest Value 分布:") - total = stats["count"] - for range_name, count in distribution.items(): - if count > 0: - percentage = count / total * 100 - print(f"{range_name}: {count} ({percentage:.2f}%)") - - -def interactive_menu() -> None: - """Interactive menu for interest value analysis""" - - while True: - print("\n" + "=" * 50) - print("Interest Value 分析工具") - print("=" * 50) - print("1. 分析全部聊天") - print("2. 选择特定聊天分析") - print("q. 退出") - - choice = input("\n请选择分析模式 (1-2, q): ").strip() - - if choice.lower() == "q": - print("再见!") - break - - chat_id = None - - if choice == "2": - # 显示可用的聊天列表 - chats = get_available_chats() - if not chats: - print("没有找到有interest_value数据的聊天") - continue - - print(f"\n可用的聊天 (共{len(chats)}个):") - for i, (_cid, name, count) in enumerate(chats, 1): - print(f"{i}. 
{name} ({count}条有效消息)") - - try: - chat_choice = int(input(f"\n请选择聊天 (1-{len(chats)}): ").strip()) - if 1 <= chat_choice <= len(chats): - chat_id = chats[chat_choice - 1][0] - else: - print("无效选择") - continue - except ValueError: - print("请输入有效数字") - continue - - elif choice != "1": - print("无效选择") - continue - - # 获取时间范围 - start_time, end_time = get_time_range_input() - - # 执行分析 - analyze_interest_values(chat_id, start_time, end_time) - - input("\n按回车键继续...") - - -if __name__ == "__main__": - interactive_menu() diff --git a/scripts/log_viewer_optimized.py b/scripts/log_viewer_optimized.py deleted file mode 100644 index 950c725d6..000000000 --- a/scripts/log_viewer_optimized.py +++ /dev/null @@ -1,1433 +0,0 @@ -import os -import threading -import time -import tkinter as tk -from collections import defaultdict -from datetime import datetime -from pathlib import Path -from tkinter import colorchooser, filedialog, messagebox, ttk - -import orjson -import toml - - -class LogIndex: - """日志索引,用于快速检索和过滤""" - - def __init__(self): - self.entries = [] # 所有日志条目 - self.module_index = defaultdict(list) # 按模块索引 - self.level_index = defaultdict(list) # 按级别索引 - self.filtered_indices = [] # 当前过滤结果的索引 - self.total_entries = 0 - - def add_entry(self, index, entry): - """添加日志条目到索引""" - if index >= len(self.entries): - self.entries.extend([None] * (index - len(self.entries) + 1)) - - self.entries[index] = entry - self.total_entries = max(self.total_entries, index + 1) - - # 更新各种索引 - logger_name = entry.get("logger_name", "") - level = entry.get("level", "") - - self.module_index[logger_name].append(index) - self.level_index[level].append(index) - - def filter_entries(self, modules=None, level=None, search_text=None): - """根据条件过滤日志条目""" - if not modules and not level and not search_text: - self.filtered_indices = list(range(self.total_entries)) - return self.filtered_indices - - candidate_indices = set(range(self.total_entries)) - - # 模块过滤 - if modules and "全部" not in modules: - 
module_indices = set() - for module in modules: - module_indices.update(self.module_index.get(module, [])) - candidate_indices &= module_indices - - # 级别过滤 - if level and level != "全部": - level_indices = set(self.level_index.get(level, [])) - candidate_indices &= level_indices - - # 文本搜索过滤 - if search_text: - search_text = search_text.lower() - text_indices = set() - for i in candidate_indices: - if i < len(self.entries) and self.entries[i]: - entry = self.entries[i] - text_content = f"{entry.get('logger_name', '')} {entry.get('event', '')}".lower() - if search_text in text_content: - text_indices.add(i) - candidate_indices &= text_indices - - self.filtered_indices = sorted(candidate_indices) - return self.filtered_indices - - def get_filtered_count(self): - """获取过滤后的条目数量""" - return len(self.filtered_indices) - - def get_entry_at_filtered_position(self, position): - """获取过滤结果中指定位置的条目""" - if 0 <= position < len(self.filtered_indices): - index = self.filtered_indices[position] - return self.entries[index] if index < len(self.entries) else None - return None - - -class LogFormatter: - """日志格式化器""" - - def __init__(self, config, custom_module_colors=None, custom_level_colors=None): - self.config = config - - # 日志级别颜色 - self.level_colors = { - "debug": "#FFA500", - "info": "#0000FF", - "success": "#008000", - "warning": "#FFFF00", - "error": "#FF0000", - "critical": "#800080", - } - - # 模块颜色映射 - self.module_colors = { - "api": "#00FF00", - "emoji": "#00FF00", - "chat": "#0080FF", - "config": "#FFFF00", - "common": "#FF00FF", - "tools": "#00FFFF", - "lpmm": "#00FFFF", - "plugin_system": "#FF0080", - "experimental": "#FFFFFF", - "person_info": "#008000", - "individuality": "#000080", - "manager": "#800080", - "llm_models": "#008080", - "plugins": "#800000", - "plugin_api": "#808000", - "remote": "#8000FF", - } - - # 应用自定义颜色 - if custom_module_colors: - self.module_colors.update(custom_module_colors) - if custom_level_colors: - 
self.level_colors.update(custom_level_colors) - - # 根据配置决定颜色启用状态 - color_text = self.config.get("color_text", "full") - if color_text == "none": - self.enable_colors = False - self.enable_module_colors = False - self.enable_level_colors = False - elif color_text == "title": - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = False - elif color_text == "full": - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = True - else: - self.enable_colors = True - self.enable_module_colors = True - self.enable_level_colors = False - - def format_log_entry(self, log_entry): - """格式化日志条目,返回格式化后的文本和样式标签""" - timestamp = log_entry.get("timestamp", "") - level = log_entry.get("level", "info") - logger_name = log_entry.get("logger_name", "") - event = log_entry.get("event", "") - - # 格式化时间戳 - formatted_timestamp = self.format_timestamp(timestamp) - - # 构建输出部分 - parts = [] - tags = [] - - # 日志级别样式配置 - log_level_style = self.config.get("log_level_style", "lite") - - # 时间戳 - if formatted_timestamp: - if log_level_style == "lite" and self.enable_level_colors: - parts.append(formatted_timestamp) - tags.append(f"level_{level}") - else: - parts.append(formatted_timestamp) - tags.append("timestamp") - - # 日志级别显示 - if log_level_style == "full": - level_text = f"[{level.upper():>8}]" - parts.append(level_text) - if self.enable_level_colors: - tags.append(f"level_{level}") - else: - tags.append("level") - elif log_level_style == "compact": - level_text = f"[{level.upper()[0]:>8}]" - parts.append(level_text) - if self.enable_level_colors: - tags.append(f"level_{level}") - else: - tags.append("level") - - # 模块名称 - if logger_name: - module_text = f"[{logger_name}]" - parts.append(module_text) - if self.enable_module_colors: - tags.append(f"module_{logger_name}") - else: - tags.append("module") - - # 消息内容 - if isinstance(event, str): - parts.append(event) - elif isinstance(event, dict): - try: - 
parts.append(orjson.dumps(event).decode("utf-8")) - except (TypeError, ValueError): - parts.append(str(event)) - else: - parts.append(str(event)) - tags.append("message") - - # 处理其他字段 - extras = [] - for key, value in log_entry.items(): - if key not in ("timestamp", "level", "logger_name", "event"): - if isinstance(value, dict | list): - try: - value_str = orjson.dumps(value).decode("utf-8") - except (TypeError, ValueError): - value_str = str(value) - else: - value_str = str(value) - extras.append(f"{key}={value_str}") - - if extras: - parts.append(" ".join(extras)) - tags.append("extras") - - return parts, tags - - def format_timestamp(self, timestamp): - """格式化时间戳""" - if not timestamp: - return "" - - try: - if "T" in timestamp: - dt = datetime.fromisoformat(timestamp.replace("Z", "+00:00")) - else: - return timestamp - - date_style = self.config.get("date_style", "m-d H:i:s") - format_map = { - "Y": "%Y", - "m": "%m", - "d": "%d", - "H": "%H", - "i": "%M", - "s": "%S", - } - - python_format = date_style - for php_char, python_char in format_map.items(): - python_format = python_format.replace(php_char, python_char) - - return dt.strftime(python_format) - except Exception: - return timestamp - - -class VirtualLogDisplay: - """虚拟滚动日志显示组件""" - - def __init__(self, parent, formatter): - self.parent = parent - self.formatter = formatter - self.line_height = 20 # 每行高度(像素) - self.visible_lines = 30 # 可见行数 - - # 创建主框架 - self.main_frame = ttk.Frame(parent) - - # 创建文本框和滚动条 - self.scrollbar = ttk.Scrollbar(self.main_frame) - self.scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - self.text_widget = tk.Text( - self.main_frame, - wrap=tk.WORD, - yscrollcommand=self.scrollbar.set, - background="#1e1e1e", - foreground="#ffffff", - insertbackground="#ffffff", - selectbackground="#404040", - font=("Consolas", 10), - ) - self.text_widget.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - self.scrollbar.config(command=self.text_widget.yview) - - # 配置文本标签样式 - self.configure_text_tags() - - 
# 数据源 - self.log_index = None - self.current_page = 0 - self.page_size = 500 # 每页显示条数 - self.max_display_lines = 2000 # 最大显示行数 - - def pack(self, **kwargs): - """包装pack方法""" - self.main_frame.pack(**kwargs) - - def configure_text_tags(self): - """配置文本标签样式""" - # 基础标签 - self.text_widget.tag_configure("timestamp", foreground="#808080") - self.text_widget.tag_configure("level", foreground="#808080") - self.text_widget.tag_configure("module", foreground="#808080") - self.text_widget.tag_configure("message", foreground="#ffffff") - self.text_widget.tag_configure("extras", foreground="#808080") - - # 日志级别颜色标签 - for level, color in self.formatter.level_colors.items(): - self.text_widget.tag_configure(f"level_{level}", foreground=color) - - # 模块颜色标签 - for module, color in self.formatter.module_colors.items(): - self.text_widget.tag_configure(f"module_{module}", foreground=color) - - def set_log_index(self, log_index): - """设置日志索引数据源""" - self.log_index = log_index - self.current_page = 0 - self.refresh_display() - - def refresh_display(self): - """刷新显示""" - if not self.log_index: - self.text_widget.delete(1.0, tk.END) - return - - # 清空显示 - self.text_widget.delete(1.0, tk.END) - - # 批量加载和显示日志 - total_count = self.log_index.get_filtered_count() - if total_count == 0: - self.text_widget.insert(tk.END, "没有符合条件的日志记录\n") - return - - # 计算显示范围 - start_index = 0 - end_index = min(total_count, self.max_display_lines) - - # 批量处理和显示 - batch_size = 100 - for batch_start in range(start_index, end_index, batch_size): - batch_end = min(batch_start + batch_size, end_index) - self.display_batch(batch_start, batch_end) - - # 让UI有机会响应 - self.parent.update_idletasks() - - # 滚动到底部(如果需要) - self.text_widget.see(tk.END) - - def display_batch(self, start_index, end_index): - """批量显示日志条目""" - for i in range(start_index, end_index): - log_entry = self.log_index.get_entry_at_filtered_position(i) - if log_entry: - self.append_entry(log_entry, scroll=False) - - def append_entry(self, log_entry, 
scroll=True): - """将单个日志条目附加到文本小部件""" - # 检查在添加新内容之前视图是否已滚动到底部 - should_scroll = scroll and self.text_widget.yview()[1] > 0.99 - - parts, tags = self.formatter.format_log_entry(log_entry) - line_text = " ".join(parts) + "\n" - - # 获取插入前的末尾位置 - start_pos = self.text_widget.index(tk.END + "-1c") - self.text_widget.insert(tk.END, line_text) - - # 为每个部分应用正确的标签 - current_len = 0 - # Python 3.9 兼容性:不使用 strict=False 参数 - min_len = min(len(parts), len(tags)) - for i in range(min_len): - part = parts[i] - tag_name = tags[i] - start_index = f"{start_pos}+{current_len}c" - end_index = f"{start_pos}+{current_len + len(part)}c" - self.text_widget.tag_add(tag_name, start_index, end_index) - current_len += len(part) + 1 # 计入空格 - - if should_scroll: - self.text_widget.see(tk.END) - - -class AsyncLogLoader: - """异步日志加载器""" - - def __init__(self, callback): - self.callback = callback - self.loading = False - self.should_stop = False - - def load_file_async(self, file_path, progress_callback=None): - """异步加载日志文件""" - if self.loading: - return - - self.loading = True - self.should_stop = False - - def load_worker(): - try: - log_index = LogIndex() - - if not os.path.exists(file_path): - self.callback(log_index, "文件不存在") - return - - file_size = os.path.getsize(file_path) - processed_size = 0 - - with open(file_path, encoding="utf-8") as f: - line_count = 0 - batch_size = 1000 # 批量处理 - - while not self.should_stop: - lines = [] - for _ in range(batch_size): - line = f.readline() - if not line: - break - lines.append(line) - processed_size += len(line.encode("utf-8")) - - if not lines: - break - - # 处理这批数据 - for line in lines: - try: - log_entry = orjson.loads(line.strip()) - log_index.add_entry(line_count, log_entry) - line_count += 1 - except orjson.JSONDecodeError: - continue - - # 更新进度 - if progress_callback: - progress = min(100, (processed_size / file_size) * 100) - progress_callback(progress, line_count) - - if not self.should_stop: - self.callback(log_index, None) - - except 
Exception as e: - self.callback(None, str(e)) - finally: - self.loading = False - - thread = threading.Thread(target=load_worker) - thread.daemon = True - thread.start() - - def stop_loading(self): - """停止加载""" - self.should_stop = True - self.loading = False - - -class LogViewer: - def __init__(self, root): - self.root = root - self.root.title("MaiBot日志查看器 (优化版)") - self.root.geometry("1200x800") - - # 加载配置 - self.load_config() - - # 初始化日志格式化器 - self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) - - # 初始化日志文件路径 - self.current_log_file = Path("logs/app.log.jsonl") - self.last_file_size = 0 - self.watching_thread = None - self.is_watching = tk.BooleanVar(value=True) - - # 初始化异步加载器 - self.async_loader = AsyncLogLoader(self.on_file_loaded) - - # 初始化日志索引 - self.log_index = LogIndex() - - # 创建主框架 - self.main_frame = ttk.Frame(root) - self.main_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建菜单栏 - self.create_menu() - - # 创建控制面板 - self.create_control_panel() - - # 创建虚拟滚动日志显示区域 - self.log_display = VirtualLogDisplay(self.main_frame, self.formatter) - self.log_display.pack(fill=tk.BOTH, expand=True) - - # 模块名映射 - self.module_name_mapping = { - "api": "API接口", - "async_task_manager": "异步任务管理器", - "background_tasks": "后台任务", - "base_tool": "基础工具", - "chat_stream": "聊天流", - "component_registry": "组件注册器", - "config": "配置", - "database_model": "数据库模型", - "emoji": "表情", - "heartflow": "心流", - "local_storage": "本地存储", - "lpmm": "LPMM", - "maibot_statistic": "MaiBot统计", - "main_message": "主消息", - "main": "主程序", - "memory": "内存", - "mood": "情绪", - "plugin_manager": "插件管理器", - "remote": "远程", - "willing": "意愿", - } - - # 加载自定义映射 - self.load_module_mapping() - - # 选中的模块集合 - self.selected_modules = set() - self.modules = set() - - # 绑定事件 - self.level_combo.bind("<>", self.filter_logs) - self.search_var.trace("w", self.filter_logs) - - # 绑定快捷键 - self.root.bind("", lambda e: self.select_log_file()) - self.root.bind("", 
lambda e: self.refresh_log_file()) - self.root.bind("", lambda e: self.export_logs()) - - # 初始加载文件 - if self.current_log_file.exists(): - self.load_log_file_async() - - def load_config(self): - """加载配置文件""" - # 默认配置 - self.default_config = { - "log": {"date_style": "m-d H:i:s", "log_level_style": "lite", "color_text": "full", "log_level": "INFO"}, - "viewer": { - "theme": "dark", - "font_size": 10, - "max_lines": 1000, - "auto_scroll": True, - "show_milliseconds": False, - "window": {"width": 1200, "height": 800, "remember_position": True}, - }, - } - - # 从bot_config.toml加载日志配置 - config_path = Path("config/bot_config.toml") - self.log_config = self.default_config["log"].copy() - self.viewer_config = self.default_config["viewer"].copy() - - try: - if config_path.exists(): - with open(config_path, encoding="utf-8") as f: - bot_config = toml.load(f) - if "log" in bot_config: - self.log_config.update(bot_config["log"]) - except Exception as e: - print(f"加载bot配置失败: {e}") - - # 从viewer配置文件加载查看器配置 - viewer_config_path = Path("config/log_viewer_config.toml") - self.custom_module_colors = {} - self.custom_level_colors = {} - - try: - if viewer_config_path.exists(): - with open(viewer_config_path, encoding="utf-8") as f: - viewer_config = toml.load(f) - if "viewer" in viewer_config: - self.viewer_config.update(viewer_config["viewer"]) - - # 加载自定义模块颜色 - if "module_colors" in viewer_config["viewer"]: - self.custom_module_colors = viewer_config["viewer"]["module_colors"] - - # 加载自定义级别颜色 - if "level_colors" in viewer_config["viewer"]: - self.custom_level_colors = viewer_config["viewer"]["level_colors"] - - if "log" in viewer_config: - self.log_config.update(viewer_config["log"]) - except Exception as e: - print(f"加载查看器配置失败: {e}") - - # 应用窗口配置 - window_config = self.viewer_config.get("window", {}) - window_width = window_config.get("width", 1200) - window_height = window_config.get("height", 800) - self.root.geometry(f"{window_width}x{window_height}") - - def 
save_viewer_config(self): - """保存查看器配置""" - # 准备完整的配置数据 - viewer_config_copy = self.viewer_config.copy() - - # 保存自定义颜色(只保存与默认值不同的颜色) - if self.custom_module_colors: - viewer_config_copy["module_colors"] = self.custom_module_colors - if self.custom_level_colors: - viewer_config_copy["level_colors"] = self.custom_level_colors - - config_data = {"log": self.log_config, "viewer": viewer_config_copy} - - config_path = Path("config/log_viewer_config.toml") - config_path.parent.mkdir(exist_ok=True) - - try: - with open(config_path, "w", encoding="utf-8") as f: - toml.dump(config_data, f) - except Exception as e: - print(f"保存查看器配置失败: {e}") - - def create_menu(self): - """创建菜单栏""" - menubar = tk.Menu(self.root) - self.root.config(menu=menubar) - - # 配置菜单 - config_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="配置", menu=config_menu) - config_menu.add_command(label="日志格式设置", command=self.show_format_settings) - config_menu.add_command(label="颜色设置", command=self.show_color_settings) - config_menu.add_command(label="查看器设置", command=self.show_viewer_settings) - config_menu.add_separator() - config_menu.add_command(label="重新加载配置", command=self.reload_config) - - # 文件菜单 - file_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="文件", menu=file_menu) - file_menu.add_command(label="选择日志文件", command=self.select_log_file, accelerator="Ctrl+O") - file_menu.add_command(label="刷新当前文件", command=self.refresh_log_file, accelerator="F5") - file_menu.add_separator() - file_menu.add_command(label="导出当前日志", command=self.export_logs, accelerator="Ctrl+S") - - # 工具菜单 - tools_menu = tk.Menu(menubar, tearoff=0) - menubar.add_cascade(label="工具", menu=tools_menu) - tools_menu.add_command(label="清空日志显示", command=self.clear_log_display) - - def show_format_settings(self): - """显示格式设置窗口""" - format_window = tk.Toplevel(self.root) - format_window.title("日志格式设置") - format_window.geometry("400x300") - - frame = ttk.Frame(format_window) - frame.pack(fill=tk.BOTH, expand=True, 
padx=10, pady=10) - - # 日期格式 - ttk.Label(frame, text="日期格式:").pack(anchor="w", pady=2) - date_style_var = tk.StringVar(value=self.log_config.get("date_style", "m-d H:i:s")) - date_entry = ttk.Entry(frame, textvariable=date_style_var, width=30) - date_entry.pack(anchor="w", pady=2) - ttk.Label(frame, text="格式说明: Y=年份, m=月份, d=日期, H=小时, i=分钟, s=秒", font=("", 8)).pack( - anchor="w", pady=2 - ) - - # 日志级别样式 - ttk.Label(frame, text="日志级别样式:").pack(anchor="w", pady=(10, 2)) - level_style_var = tk.StringVar(value=self.log_config.get("log_level_style", "lite")) - level_frame = ttk.Frame(frame) - level_frame.pack(anchor="w", pady=2) - - ttk.Radiobutton(level_frame, text="简洁(lite)", variable=level_style_var, value="lite").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(level_frame, text="紧凑(compact)", variable=level_style_var, value="compact").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(level_frame, text="完整(full)", variable=level_style_var, value="full").pack( - side="left", padx=(0, 10) - ) - - # 颜色文本设置 - ttk.Label(frame, text="文本颜色设置:").pack(anchor="w", pady=(10, 2)) - color_text_var = tk.StringVar(value=self.log_config.get("color_text", "full")) - color_frame = ttk.Frame(frame) - color_frame.pack(anchor="w", pady=2) - - ttk.Radiobutton(color_frame, text="无颜色(none)", variable=color_text_var, value="none").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(color_frame, text="仅标题(title)", variable=color_text_var, value="title").pack( - side="left", padx=(0, 10) - ) - ttk.Radiobutton(color_frame, text="全部(full)", variable=color_text_var, value="full").pack( - side="left", padx=(0, 10) - ) - - # 按钮 - button_frame = ttk.Frame(frame) - button_frame.pack(fill="x", pady=(20, 0)) - - def apply_format(): - self.log_config["date_style"] = date_style_var.get() - self.log_config["log_level_style"] = level_style_var.get() - self.log_config["color_text"] = color_text_var.get() - - # 重新初始化格式化器 - self.formatter = LogFormatter(self.log_config, 
self.custom_module_colors, self.custom_level_colors) - self.log_display.formatter = self.formatter - self.log_display.configure_text_tags() - - # 保存配置 - self.save_viewer_config() - - # 重新过滤日志以应用新格式 - self.filter_logs() - - format_window.destroy() - - ttk.Button(button_frame, text="应用", command=apply_format).pack(side="right", padx=(5, 0)) - ttk.Button(button_frame, text="取消", command=format_window.destroy).pack(side="right") - - def show_viewer_settings(self): - """显示查看器设置窗口""" - viewer_window = tk.Toplevel(self.root) - viewer_window.title("查看器设置") - viewer_window.geometry("350x250") - - frame = ttk.Frame(viewer_window) - frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10) - - # 主题设置 - ttk.Label(frame, text="主题:").pack(anchor="w", pady=2) - theme_var = tk.StringVar(value=self.viewer_config.get("theme", "dark")) - theme_frame = ttk.Frame(frame) - theme_frame.pack(anchor="w", pady=2) - ttk.Radiobutton(theme_frame, text="深色", variable=theme_var, value="dark").pack(side="left", padx=(0, 10)) - ttk.Radiobutton(theme_frame, text="浅色", variable=theme_var, value="light").pack(side="left") - - # 字体大小 - ttk.Label(frame, text="字体大小:").pack(anchor="w", pady=(10, 2)) - font_size_var = tk.IntVar(value=self.viewer_config.get("font_size", 10)) - font_size_spin = ttk.Spinbox(frame, from_=8, to=20, textvariable=font_size_var, width=10) - font_size_spin.pack(anchor="w", pady=2) - - # 最大行数 - ttk.Label(frame, text="最大显示行数:").pack(anchor="w", pady=(10, 2)) - max_lines_var = tk.IntVar(value=self.viewer_config.get("max_lines", 1000)) - max_lines_spin = ttk.Spinbox(frame, from_=100, to=10000, increment=100, textvariable=max_lines_var, width=10) - max_lines_spin.pack(anchor="w", pady=2) - - # 自动滚动 - auto_scroll_var = tk.BooleanVar(value=self.viewer_config.get("auto_scroll", True)) - ttk.Checkbutton(frame, text="自动滚动到底部", variable=auto_scroll_var).pack(anchor="w", pady=(10, 2)) - - # 按钮 - button_frame = ttk.Frame(frame) - button_frame.pack(fill="x", pady=(20, 0)) - - def 
apply_viewer_settings(): - self.viewer_config["theme"] = theme_var.get() - self.viewer_config["font_size"] = font_size_var.get() - self.viewer_config["max_lines"] = max_lines_var.get() - self.viewer_config["auto_scroll"] = auto_scroll_var.get() - - # 应用主题 - self.apply_theme() - - # 保存配置 - self.save_viewer_config() - - viewer_window.destroy() - - ttk.Button(button_frame, text="应用", command=apply_viewer_settings).pack(side="right", padx=(5, 0)) - ttk.Button(button_frame, text="取消", command=viewer_window.destroy).pack(side="right") - - def apply_theme(self): - """应用主题设置""" - theme = self.viewer_config.get("theme", "dark") - font_size = self.viewer_config.get("font_size", 10) - - # 更新虚拟显示组件的主题 - if theme == "dark": - bg_color = "#1e1e1e" - fg_color = "#ffffff" - select_bg = "#404040" - else: - bg_color = "#ffffff" - fg_color = "#000000" - select_bg = "#c0c0c0" - - self.log_display.text_widget.config( - background=bg_color, foreground=fg_color, selectbackground=select_bg, font=("Consolas", font_size) - ) - - # 重新配置标签样式 - self.log_display.configure_text_tags() - - def reload_config(self): - """重新加载配置""" - self.load_config() - self.formatter = LogFormatter(self.log_config, self.custom_module_colors, self.custom_level_colors) - self.log_display.formatter = self.formatter - self.log_display.configure_text_tags() - self.apply_theme() - self.filter_logs() - - def clear_log_display(self): - """清空日志显示""" - self.log_display.text_widget.delete(1.0, tk.END) - - def export_logs(self): - """导出当前显示的日志""" - filename = filedialog.asksaveasfilename( - defaultextension=".txt", filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")] - ) - if filename: - try: - # 获取当前显示的所有日志条目 - if self.log_index: - filtered_count = self.log_index.get_filtered_count() - log_lines = [] - for i in range(filtered_count): - log_entry = self.log_index.get_entry_at_filtered_position(i) - if log_entry: - parts, tags = self.formatter.format_log_entry(log_entry) - line_text = " ".join(parts) - log_lines.append(line_text) - 
- with open(filename, "w", encoding="utf-8") as f: - f.write("\n".join(log_lines)) - messagebox.showinfo("导出成功", f"日志已导出到: {filename}") - else: - messagebox.showwarning("导出失败", "没有日志可导出") - except Exception as e: - messagebox.showerror("导出失败", f"导出日志时出错: {e}") - - def load_module_mapping(self): - """加载自定义模块映射""" - mapping_file = Path("config/module_mapping.json") - if mapping_file.exists(): - try: - with open(mapping_file, encoding="utf-8") as f: - custom_mapping = orjson.loads(f.read()) - self.module_name_mapping.update(custom_mapping) - except Exception as e: - print(f"加载模块映射失败: {e}") - - def save_module_mapping(self): - """保存自定义模块映射""" - mapping_file = Path("config/module_mapping.json") - mapping_file.parent.mkdir(exist_ok=True) - try: - with open(mapping_file, "w", encoding="utf-8") as f: - f.write(orjson.dumps(self.module_name_mapping, option=orjson.OPT_INDENT_2).decode("utf-8")) - except Exception as e: - print(f"保存模块映射失败: {e}") - - def show_color_settings(self): - """显示颜色设置窗口""" - color_window = tk.Toplevel(self.root) - color_window.title("颜色设置") - color_window.geometry("300x400") - - # 创建滚动框架 - frame = ttk.Frame(color_window) - frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建滚动条 - scrollbar = ttk.Scrollbar(frame) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - # 创建颜色设置列表 - canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) - canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - scrollbar.config(command=canvas.yview) - - # 创建内部框架 - inner_frame = ttk.Frame(canvas) - canvas.create_window((0, 0), window=inner_frame, anchor="nw") - - # 添加日志级别颜色设置 - ttk.Label(inner_frame, text="日志级别颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) - for level in ["info", "warning", "error"]: - frame = ttk.Frame(inner_frame) - frame.pack(fill=tk.X, padx=5, pady=2) - ttk.Label(frame, text=level).pack(side=tk.LEFT) - color_btn = ttk.Button( - frame, text="选择颜色", command=lambda level_name=level: self.choose_color(level_name) - ) - 
color_btn.pack(side=tk.RIGHT) - # 显示当前颜色 - color_label = ttk.Label(frame, text="■", foreground=self.formatter.level_colors[level]) - color_label.pack(side=tk.RIGHT, padx=5) - - # 添加模块颜色设置 - ttk.Label(inner_frame, text="\n模块颜色", font=("", 10, "bold")).pack(anchor="w", padx=5, pady=5) - for module in sorted(self.modules): - frame = ttk.Frame(inner_frame) - frame.pack(fill=tk.X, padx=5, pady=2) - ttk.Label(frame, text=module).pack(side=tk.LEFT) - color_btn = ttk.Button(frame, text="选择颜色", command=lambda m=module: self.choose_module_color(m)) - color_btn.pack(side=tk.RIGHT) - # 显示当前颜色 - color = self.formatter.module_colors.get(module, "black") - color_label = ttk.Label(frame, text="■", foreground=color) - color_label.pack(side=tk.RIGHT, padx=5) - - # 更新画布滚动区域 - inner_frame.update_idletasks() - canvas.config(scrollregion=canvas.bbox("all")) - - # 添加确定按钮 - ttk.Button(color_window, text="确定", command=color_window.destroy).pack(pady=5) - - def choose_color(self, level): - """选择日志级别颜色""" - color = colorchooser.askcolor(color=self.formatter.level_colors[level])[1] - if color: - self.formatter.level_colors[level] = color - self.custom_level_colors[level] = color # 保存到自定义颜色 - self.log_display.formatter = self.formatter - self.log_display.configure_text_tags() - self.save_viewer_config() # 自动保存配置 - self.filter_logs() - - def choose_module_color(self, module): - """选择模块颜色""" - color = colorchooser.askcolor(color=self.formatter.module_colors.get(module, "black"))[1] - if color: - self.formatter.module_colors[module] = color - self.custom_module_colors[module] = color # 保存到自定义颜色 - self.log_display.formatter = self.formatter - self.log_display.configure_text_tags() - self.save_viewer_config() # 自动保存配置 - self.filter_logs() - - def create_control_panel(self): - """创建控制面板""" - # 控制面板 - self.control_frame = ttk.Frame(self.main_frame) - self.control_frame.pack(fill=tk.X, pady=(0, 5)) - - # 文件选择框架 - self.file_frame = ttk.LabelFrame(self.control_frame, text="日志文件") - 
self.file_frame.pack(side=tk.TOP, fill=tk.X, padx=5, pady=(0, 5)) - - # 当前文件显示 - self.current_file_var = tk.StringVar(value=str(self.current_log_file)) - self.file_label = ttk.Label(self.file_frame, textvariable=self.current_file_var, foreground="blue") - self.file_label.pack(side=tk.LEFT, padx=5, pady=2) - - # 进度条 - self.progress_var = tk.DoubleVar() - self.progress_bar = ttk.Progressbar(self.file_frame, variable=self.progress_var, length=200) - self.progress_bar.pack(side=tk.LEFT, padx=5, pady=2) - self.progress_bar.pack_forget() - - # 状态标签 - self.status_var = tk.StringVar(value="就绪") - self.status_label = ttk.Label(self.file_frame, textvariable=self.status_var) - self.status_label.pack(side=tk.LEFT, padx=5, pady=2) - - # 按钮区域 - button_frame = ttk.Frame(self.file_frame) - button_frame.pack(side=tk.RIGHT, padx=5, pady=2) - - ttk.Button(button_frame, text="选择文件", command=self.select_log_file).pack(side=tk.LEFT, padx=2) - ttk.Button(button_frame, text="刷新", command=self.refresh_log_file).pack(side=tk.LEFT, padx=2) - ttk.Checkbutton(button_frame, text="实时更新", variable=self.is_watching, command=self.toggle_watching).pack( - side=tk.LEFT, padx=2 - ) - - # 模块选择框架 - self.module_frame = ttk.LabelFrame(self.control_frame, text="模块") - self.module_frame.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=5) - - # 创建模块选择滚动区域 - self.module_canvas = tk.Canvas(self.module_frame, height=80) - self.module_canvas.pack(side=tk.LEFT, fill=tk.X, expand=True) - - # 创建模块选择内部框架 - self.module_inner_frame = ttk.Frame(self.module_canvas) - self.module_canvas.create_window((0, 0), window=self.module_inner_frame, anchor="nw") - - # 创建右侧控制区域(级别和搜索) - self.right_control_frame = ttk.Frame(self.control_frame) - self.right_control_frame.pack(side=tk.RIGHT, padx=5) - - # 映射编辑按钮 - mapping_btn = ttk.Button(self.right_control_frame, text="模块映射", command=self.edit_module_mapping) - mapping_btn.pack(side=tk.TOP, fill=tk.X, pady=1) - - # 日志级别选择 - level_frame = ttk.Frame(self.right_control_frame) - 
level_frame.pack(side=tk.TOP, fill=tk.X, pady=1) - ttk.Label(level_frame, text="级别:").pack(side=tk.LEFT, padx=2) - self.level_var = tk.StringVar(value="全部") - self.level_combo = ttk.Combobox(level_frame, textvariable=self.level_var, width=8) - self.level_combo["values"] = ["全部", "debug", "info", "warning", "error", "critical"] - self.level_combo.pack(side=tk.LEFT, padx=2) - - # 搜索框 - search_frame = ttk.Frame(self.right_control_frame) - search_frame.pack(side=tk.TOP, fill=tk.X, pady=1) - ttk.Label(search_frame, text="搜索:").pack(side=tk.LEFT, padx=2) - self.search_var = tk.StringVar() - self.search_entry = ttk.Entry(search_frame, textvariable=self.search_var, width=15) - self.search_entry.pack(side=tk.LEFT, padx=2) - - def on_file_loaded(self, log_index, error): - """文件加载完成回调""" - self.progress_bar.pack_forget() - - if error: - self.status_var.set(f"加载失败: {error}") - messagebox.showerror("错误", f"加载日志文件失败: {error}") - return - - self.log_index = log_index - try: - self.last_file_size = os.path.getsize(self.current_log_file) - except OSError: - self.last_file_size = 0 - self.status_var.set(f"已加载 {log_index.total_entries} 条日志") - - # 更新模块列表 - self.modules = set(log_index.module_index.keys()) - self.update_module_list() - - # 应用过滤并显示 - self.filter_logs() - - # 如果开启了实时更新,则开始监视 - if self.is_watching.get(): - self.start_watching() - - def on_loading_progress(self, progress, line_count): - """加载进度回调""" - self.root.after(0, lambda: self.update_progress(progress, line_count)) - - def update_progress(self, progress, line_count): - """更新进度显示""" - self.progress_var.set(progress) - self.status_var.set(f"正在加载... 
{line_count} 条 ({progress:.1f}%)") - - def load_log_file_async(self): - """异步加载日志文件""" - self.stop_watching() # 停止任何正在运行的监视器 - - if not self.current_log_file.exists(): - self.status_var.set("文件不存在") - return - - # 显示进度条 - self.progress_bar.pack(side=tk.LEFT, padx=5, pady=2, before=self.status_label) - self.progress_var.set(0) - self.status_var.set("正在加载...") - - # 清空当前数据 - self.log_index = LogIndex() - self.selected_modules.clear() - - # 开始异步加载 - self.async_loader.load_file_async(str(self.current_log_file), self.on_loading_progress) - - def filter_logs(self, *args): - """过滤日志""" - if not self.log_index: - return - - # 获取过滤条件 - selected_modules = self.selected_modules if self.selected_modules else None - level = self.level_var.get() if self.level_var.get() != "全部" else None - search_text = self.search_var.get().strip() if self.search_var.get().strip() else None - - # 应用过滤 - self.log_index.filter_entries(selected_modules, level, search_text) - - # 更新显示 - self.log_display.set_log_index(self.log_index) - - # 更新状态 - filtered_count = self.log_index.get_filtered_count() - total_count = self.log_index.total_entries - if filtered_count == total_count: - self.status_var.set(f"显示 {total_count} 条日志") - else: - self.status_var.set(f"显示 {filtered_count}/{total_count} 条日志") - - def select_log_file(self): - """选择日志文件""" - filename = filedialog.askopenfilename( - title="选择日志文件", - filetypes=[("JSONL日志文件", "*.jsonl"), ("所有文件", "*.*")], - initialdir="logs" if Path("logs").exists() else ".", - ) - if filename: - new_file = Path(filename) - if new_file != self.current_log_file: - self.current_log_file = new_file - self.current_file_var.set(str(self.current_log_file)) - self.load_log_file_async() - - def refresh_log_file(self): - """刷新日志文件""" - self.load_log_file_async() - - def toggle_watching(self): - """切换实时更新状态""" - if self.is_watching.get(): - self.start_watching() - else: - self.stop_watching() - - def start_watching(self): - """开始监视文件变化""" - if self.watching_thread and 
self.watching_thread.is_alive(): - return # 已经在监视 - - if not self.current_log_file.exists(): - self.is_watching.set(False) - messagebox.showwarning("警告", "日志文件不存在,无法开启实时更新。") - return - - self.watching_thread = threading.Thread(target=self.watch_file_loop, daemon=True) - self.watching_thread.start() - - def stop_watching(self): - """停止监视文件变化""" - self.is_watching.set(False) - # 线程通过检查 is_watching 变量来停止,这里不需要强制干预 - self.watching_thread = None - - def watch_file_loop(self): - """监视文件循环""" - while self.is_watching.get(): - try: - if not self.current_log_file.exists(): - self.root.after( - 0, - lambda: messagebox.showwarning("警告", "日志文件丢失,已停止实时更新。"), - ) - self.root.after(0, self.is_watching.set, False) - break - - current_size = os.path.getsize(self.current_log_file) - if current_size > self.last_file_size: - new_entries = self.read_new_logs(self.last_file_size) - self.last_file_size = current_size - if new_entries: - self.root.after(0, self.append_new_logs, new_entries) - elif current_size < self.last_file_size: - # 文件被截断或替换 - self.last_file_size = 0 - self.root.after(0, self.refresh_log_file) - break # 刷新会重新启动监视(如果需要),所以结束当前循环 - - except Exception as e: - print(f"监视日志文件时出错: {e}") - self.root.after(0, self.is_watching.set, False) - break - - time.sleep(1) - - self.watching_thread = None - - def read_new_logs(self, from_position): - """读取新的日志条目并返回它们""" - new_entries = [] - new_modules = set() # 收集新发现的模块 - with open(self.current_log_file, encoding="utf-8") as f: - f.seek(from_position) - line_count = self.log_index.total_entries - for line in f: - if line.strip(): - try: - log_entry = orjson.loads(line) - self.log_index.add_entry(line_count, log_entry) - new_entries.append(log_entry) - - logger_name = log_entry.get("logger_name", "") - if logger_name and logger_name not in self.modules: - new_modules.add(logger_name) - - line_count += 1 - except orjson.JSONDecodeError: - continue - - # 如果发现了新模块,在主线程中更新模块集合 - if new_modules: - - def update_modules(): - 
self.modules.update(new_modules) - self.update_module_list() - - self.root.after(0, update_modules) - - return new_entries - - def append_new_logs(self, new_entries): - """将新日志附加到显示中""" - # 检查是否应附加或执行完全刷新(例如,如果过滤器处于活动状态) - selected_modules = ( - self.selected_modules if (self.selected_modules and "全部" not in self.selected_modules) else None - ) - level = self.level_var.get() if self.level_var.get() != "全部" else None - search_text = self.search_var.get().strip() if self.search_var.get().strip() else None - - is_filtered = selected_modules or level or search_text - - if is_filtered: - # 如果过滤器处于活动状态,我们必须执行完全刷新以应用它们 - self.filter_logs() - return - - # 如果没有过滤器,只需附加新日志 - for entry in new_entries: - self.log_display.append_entry(entry) - - # 更新状态 - total_count = self.log_index.total_entries - self.status_var.set(f"显示 {total_count} 条日志") - - def update_module_list(self): - """更新模块列表""" - # 清空现有选项 - for widget in self.module_inner_frame.winfo_children(): - widget.destroy() - - # 计算总模块数(包括"全部") - total_modules = len(self.modules) + 1 - max_cols = min(4, max(2, total_modules)) # 减少最大列数,避免超出边界 - - # 配置网格列权重,让每列平均分配空间 - for i in range(max_cols): - self.module_inner_frame.grid_columnconfigure(i, weight=1, uniform="module_col") - - # 创建一个多行布局 - current_row = 0 - current_col = 0 - - # 添加"全部"选项 - all_frame = ttk.Frame(self.module_inner_frame) - all_frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") - - all_var = tk.BooleanVar(value="全部" in self.selected_modules) - all_check = ttk.Checkbutton( - all_frame, text="全部", variable=all_var, command=lambda: self.toggle_module("全部", all_var) - ) - all_check.pack(side=tk.LEFT) - - # 使用颜色标签替代按钮 - all_color = self.formatter.module_colors.get("全部", "black") - all_color_label = ttk.Label(all_frame, text="■", foreground=all_color, width=2, cursor="hand2") - all_color_label.pack(side=tk.LEFT, padx=2) - all_color_label.bind("", lambda e: self.choose_module_color("全部")) - - current_col += 1 - - # 添加其他模块选项 - for module in 
sorted(self.modules): - if current_col >= max_cols: - current_row += 1 - current_col = 0 - - frame = ttk.Frame(self.module_inner_frame) - frame.grid(row=current_row, column=current_col, padx=3, pady=2, sticky="ew") - - var = tk.BooleanVar(value=module in self.selected_modules) - - # 使用中文映射名称显示 - display_name = self.get_display_name(module) - if len(display_name) > 12: - display_name = display_name[:10] + "..." - - check = ttk.Checkbutton( - frame, text=display_name, variable=var, command=lambda m=module, v=var: self.toggle_module(m, v) - ) - check.pack(side=tk.LEFT) - - # 添加工具提示显示完整名称和英文名 - full_tooltip = f"{self.get_display_name(module)}" - if module != self.get_display_name(module): - full_tooltip += f"\n({module})" - self.create_tooltip(check, full_tooltip) - - # 使用颜色标签替代按钮 - color = self.formatter.module_colors.get(module, "black") - color_label = ttk.Label(frame, text="■", foreground=color, width=2, cursor="hand2") - color_label.pack(side=tk.LEFT, padx=2) - color_label.bind("", lambda e, m=module: self.choose_module_color(m)) - - current_col += 1 - - # 更新画布滚动区域 - self.module_inner_frame.update_idletasks() - self.module_canvas.config(scrollregion=self.module_canvas.bbox("all")) - - # 添加垂直滚动条 - if not hasattr(self, "module_scrollbar"): - self.module_scrollbar = ttk.Scrollbar( - self.module_frame, orient=tk.VERTICAL, command=self.module_canvas.yview - ) - self.module_scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - self.module_canvas.config(yscrollcommand=self.module_scrollbar.set) - - def create_tooltip(self, widget, text): - """为控件创建工具提示""" - - def on_enter(event): - tooltip = tk.Toplevel() - tooltip.wm_overrideredirect(True) - tooltip.wm_geometry(f"+{event.x_root + 10}+{event.y_root + 10}") - label = ttk.Label(tooltip, text=text, background="lightyellow", relief="solid", borderwidth=1) - label.pack() - widget.tooltip = tooltip - - def on_leave(event): - if hasattr(widget, "tooltip"): - widget.tooltip.destroy() - del widget.tooltip - - widget.bind("", on_enter) - 
widget.bind("", on_leave) - - def toggle_module(self, module, var): - """切换模块选择状态""" - if module == "全部": - if var.get(): - self.selected_modules = {"全部"} - else: - self.selected_modules.clear() - else: - if var.get(): - self.selected_modules.add(module) - if "全部" in self.selected_modules: - self.selected_modules.remove("全部") - else: - self.selected_modules.discard(module) - - self.filter_logs() - - def get_display_name(self, module_name): - """获取模块的显示名称""" - return self.module_name_mapping.get(module_name, module_name) - - def edit_module_mapping(self): - """编辑模块映射""" - mapping_window = tk.Toplevel(self.root) - mapping_window.title("编辑模块映射") - mapping_window.geometry("500x600") - - # 创建滚动框架 - frame = ttk.Frame(mapping_window) - frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5) - - # 创建滚动条 - scrollbar = ttk.Scrollbar(frame) - scrollbar.pack(side=tk.RIGHT, fill=tk.Y) - - # 创建映射编辑列表 - canvas = tk.Canvas(frame, yscrollcommand=scrollbar.set) - canvas.pack(side=tk.LEFT, fill=tk.BOTH, expand=True) - scrollbar.config(command=canvas.yview) - - # 创建内部框架 - inner_frame = ttk.Frame(canvas) - canvas.create_window((0, 0), window=inner_frame, anchor="nw") - - # 添加标题 - ttk.Label(inner_frame, text="模块映射编辑", font=("", 12, "bold")).pack(anchor="w", padx=5, pady=5) - ttk.Label(inner_frame, text="英文名 -> 中文名", font=("", 10)).pack(anchor="w", padx=5, pady=2) - - # 映射编辑字典 - mapping_vars = {} - - # 添加现有模块的映射编辑 - all_modules = sorted(self.modules) - for module in all_modules: - frame_row = ttk.Frame(inner_frame) - frame_row.pack(fill=tk.X, padx=5, pady=2) - - ttk.Label(frame_row, text=module, width=20).pack(side=tk.LEFT, padx=5) - ttk.Label(frame_row, text="->").pack(side=tk.LEFT, padx=5) - - var = tk.StringVar(value=self.module_name_mapping.get(module, module)) - mapping_vars[module] = var - entry = ttk.Entry(frame_row, textvariable=var, width=25) - entry.pack(side=tk.LEFT, padx=5) - - # 更新画布滚动区域 - inner_frame.update_idletasks() - canvas.config(scrollregion=canvas.bbox("all")) - - def 
save_mappings(): - # 更新映射 - for module, var in mapping_vars.items(): - new_name = var.get().strip() - if new_name and new_name != module: - self.module_name_mapping[module] = new_name - elif module in self.module_name_mapping and not new_name: - del self.module_name_mapping[module] - - # 保存到文件 - self.save_module_mapping() - # 更新模块列表显示 - self.update_module_list() - mapping_window.destroy() - - # 添加按钮 - button_frame = ttk.Frame(mapping_window) - button_frame.pack(fill=tk.X, padx=5, pady=5) - ttk.Button(button_frame, text="保存", command=save_mappings).pack(side=tk.RIGHT, padx=5) - ttk.Button(button_frame, text="取消", command=mapping_window.destroy).pack(side=tk.RIGHT, padx=5) - - -def main(): - root = tk.Tk() - LogViewer(root) - root.mainloop() - - -if __name__ == "__main__": - main() diff --git a/scripts/manifest_tool.py b/scripts/manifest_tool.py deleted file mode 100644 index c18b6a208..000000000 --- a/scripts/manifest_tool.py +++ /dev/null @@ -1,239 +0,0 @@ -""" -插件Manifest管理命令行工具 - -提供插件manifest文件的创建、验证和管理功能 -""" - -import argparse -import os -import sys -from pathlib import Path - -import orjson - -from src.common.logger import get_logger -from src.plugin_system.utils.manifest_utils import ( - ManifestValidator, -) - -# 添加项目根目录到Python路径 -project_root = Path(__file__).parent.parent.parent.parent -sys.path.insert(0, str(project_root)) - - -logger = get_logger("manifest_tool") - - -def create_minimal_manifest(plugin_dir: str, plugin_name: str, description: str = "", author: str = "") -> bool: - """创建最小化的manifest文件 - - Args: - plugin_dir: 插件目录 - plugin_name: 插件名称 - description: 插件描述 - author: 插件作者 - - Returns: - bool: 是否创建成功 - """ - manifest_path = os.path.join(plugin_dir, "_manifest.json") - - if os.path.exists(manifest_path): - print(f"❌ Manifest文件已存在: {manifest_path}") - return False - - # 创建最小化manifest - minimal_manifest = { - "manifest_version": 1, - "name": plugin_name, - "version": "1.0.0", - "description": description or f"{plugin_name}插件", - "author": 
{"name": author or "Unknown"}, - } - - try: - with open(manifest_path, "w", encoding="utf-8") as f: - f.write(orjson.dumps(minimal_manifest, option=orjson.OPT_INDENT_2).decode("utf-8")) - print(f"✅ 已创建最小化manifest文件: {manifest_path}") - return True - except Exception as e: - print(f"❌ 创建manifest文件失败: {e}") - return False - - -def create_complete_manifest(plugin_dir: str, plugin_name: str) -> bool: - """创建完整的manifest模板文件 - - Args: - plugin_dir: 插件目录 - plugin_name: 插件名称 - - Returns: - bool: 是否创建成功 - """ - manifest_path = os.path.join(plugin_dir, "_manifest.json") - - if os.path.exists(manifest_path): - print(f"❌ Manifest文件已存在: {manifest_path}") - return False - - # 创建完整模板 - complete_manifest = { - "manifest_version": 1, - "name": plugin_name, - "version": "1.0.0", - "description": f"{plugin_name}插件描述", - "author": {"name": "插件作者", "url": "https://github.com/your-username"}, - "license": "MIT", - "host_application": {"min_version": "1.0.0", "max_version": "4.0.0"}, - "homepage_url": "https://github.com/your-repo", - "repository_url": "https://github.com/your-repo", - "keywords": ["keyword1", "keyword2"], - "categories": ["Category1"], - "default_locale": "zh-CN", - "locales_path": "_locales", - "plugin_info": { - "is_built_in": False, - "plugin_type": "general", - "components": [{"type": "action", "name": "sample_action", "description": "示例动作组件"}], - }, - } - - try: - with open(manifest_path, "w", encoding="utf-8") as f: - f.write(orjson.dumps(complete_manifest, option=orjson.OPT_INDENT_2).decode("utf-8")) - print(f"✅ 已创建完整manifest模板: {manifest_path}") - print("💡 请根据实际情况修改manifest文件中的内容") - return True - except Exception as e: - print(f"❌ 创建manifest文件失败: {e}") - return False - - -def validate_manifest_file(plugin_dir: str) -> bool: - """验证manifest文件 - - Args: - plugin_dir: 插件目录 - - Returns: - bool: 是否验证通过 - """ - manifest_path = os.path.join(plugin_dir, "_manifest.json") - - if not os.path.exists(manifest_path): - print(f"❌ 未找到manifest文件: {manifest_path}") - return 
False - - try: - with open(manifest_path, encoding="utf-8") as f: - manifest_data = orjson.loads(f.read()) - - validator = ManifestValidator() - is_valid = validator.validate_manifest(manifest_data) - - # 显示验证结果 - print("📋 Manifest验证结果:") - print(validator.get_validation_report()) - - if is_valid: - print("✅ Manifest文件验证通过") - else: - print("❌ Manifest文件验证失败") - - return is_valid - - except orjson.JSONDecodeError as e: - print(f"❌ Manifest文件格式错误: {e}") - return False - except Exception as e: - print(f"❌ 验证过程中发生错误: {e}") - return False - - -def scan_plugins_without_manifest(root_dir: str) -> None: - """扫描缺少manifest文件的插件 - - Args: - root_dir: 扫描的根目录 - """ - print(f"🔍 扫描目录: {root_dir}") - - plugins_without_manifest = [] - - for root, dirs, files in os.walk(root_dir): - # 跳过隐藏目录和__pycache__ - dirs[:] = [d for d in dirs if not d.startswith(".") and d != "__pycache__"] - - # 检查是否包含plugin.py文件(标识为插件目录) - if "plugin.py" in files: - manifest_path = os.path.join(root, "_manifest.json") - if not os.path.exists(manifest_path): - plugins_without_manifest.append(root) - - if plugins_without_manifest: - print(f"❌ 发现 {len(plugins_without_manifest)} 个插件缺少manifest文件:") - for plugin_dir in plugins_without_manifest: - plugin_name = os.path.basename(plugin_dir) - print(f" - {plugin_name}: {plugin_dir}") - print("💡 使用 'python manifest_tool.py create-minimal <插件目录>' 创建manifest文件") - else: - print("✅ 所有插件都有manifest文件") - - -def main(): - """主函数""" - parser = argparse.ArgumentParser(description="插件Manifest管理工具") - subparsers = parser.add_subparsers(dest="command", help="可用命令") - - # 创建最小化manifest命令 - create_minimal_parser = subparsers.add_parser("create-minimal", help="创建最小化manifest文件") - create_minimal_parser.add_argument("plugin_dir", help="插件目录路径") - create_minimal_parser.add_argument("--name", help="插件名称") - create_minimal_parser.add_argument("--description", help="插件描述") - create_minimal_parser.add_argument("--author", help="插件作者") - - # 创建完整manifest命令 - create_complete_parser = 
subparsers.add_parser("create-complete", help="创建完整manifest模板") - create_complete_parser.add_argument("plugin_dir", help="插件目录路径") - create_complete_parser.add_argument("--name", help="插件名称") - - # 验证manifest命令 - validate_parser = subparsers.add_parser("validate", help="验证manifest文件") - validate_parser.add_argument("plugin_dir", help="插件目录路径") - - # 扫描插件命令 - scan_parser = subparsers.add_parser("scan", help="扫描缺少manifest的插件") - scan_parser.add_argument("root_dir", help="扫描的根目录路径") - - args = parser.parse_args() - - if not args.command: - parser.print_help() - return - - try: - if args.command == "create-minimal": - plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir)) - success = create_minimal_manifest(args.plugin_dir, plugin_name, args.description or "", args.author or "") - sys.exit(0 if success else 1) - - elif args.command == "create-complete": - plugin_name = args.name or os.path.basename(os.path.abspath(args.plugin_dir)) - success = create_complete_manifest(args.plugin_dir, plugin_name) - sys.exit(0 if success else 1) - - elif args.command == "validate": - success = validate_manifest_file(args.plugin_dir) - sys.exit(0 if success else 1) - - elif args.command == "scan": - scan_plugins_without_manifest(args.root_dir) - - except Exception as e: - print(f"❌ 执行命令时发生错误: {e}") - sys.exit(1) - - -if __name__ == "__main__": - main() diff --git a/scripts/mongodb_to_sqlite.py b/scripts/mongodb_to_sqlite.py deleted file mode 100644 index a502e018f..000000000 --- a/scripts/mongodb_to_sqlite.py +++ /dev/null @@ -1,922 +0,0 @@ -import os - -# import time -import pickle -import sys # 新增系统模块导入 -from pathlib import Path - -import orjson - -sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) -from dataclasses import dataclass, field -from datetime import datetime -from typing import Any - -from peewee import Field, IntegrityError, Model -from pymongo import MongoClient -from pymongo.errors import ConnectionFailure - -# Rich 
进度条和显示组件 -from rich.console import Console -from rich.panel import Panel -from rich.progress import ( - BarColumn, - Progress, - SpinnerColumn, - TaskProgressColumn, - TextColumn, - TimeElapsedColumn, - TimeRemainingColumn, -) -from rich.table import Table - -# from rich.text import Text -from src.common.database.database import db -from src.common.database.sqlalchemy_models import ( - ChatStreams, - Emoji, - GraphEdges, - GraphNodes, - ImageDescriptions, - Images, - Knowledges, - Messages, - PersonInfo, - ThinkingLog, -) -from src.common.logger import get_logger - -logger = get_logger("mongodb_to_sqlite") - -ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) - - -@dataclass -class MigrationConfig: - """迁移配置类""" - - mongo_collection: str - target_model: type[Model] - field_mapping: dict[str, str] - batch_size: int = 500 - enable_validation: bool = True - skip_duplicates: bool = True - unique_fields: list[str] = field(default_factory=list) # 用于重复检查的字段 - - -# 数据验证相关类已移除 - 用户要求不要数据验证 - - -@dataclass -class MigrationCheckpoint: - """迁移断点数据""" - - collection_name: str - processed_count: int - last_processed_id: Any - timestamp: datetime - batch_errors: list[dict[str, Any]] = field(default_factory=list) - - -@dataclass -class MigrationStats: - """迁移统计信息""" - - total_documents: int = 0 - processed_count: int = 0 - success_count: int = 0 - error_count: int = 0 - skipped_count: int = 0 - duplicate_count: int = 0 - validation_errors: int = 0 - batch_insert_count: int = 0 - errors: list[dict[str, Any]] = field(default_factory=list) - start_time: datetime | None = None - end_time: datetime | None = None - - def add_error(self, doc_id: Any, error: str, doc_data: dict | None = None): - """添加错误记录""" - self.errors.append( - {"doc_id": str(doc_id), "error": error, "timestamp": datetime.now().isoformat(), "doc_data": doc_data} - ) - self.error_count += 1 - - def add_validation_error(self, doc_id: Any, field: str, error: str): - """添加验证错误""" - 
self.add_error(doc_id, f"验证失败 - {field}: {error}") - self.validation_errors += 1 - - -class MongoToSQLiteMigrator: - """MongoDB到SQLite数据迁移器 - 使用Peewee ORM""" - - def __init__(self, mongo_uri: str | None = None, database_name: str | None = None): - self.database_name = database_name or os.getenv("DATABASE_NAME", "MegBot") - self.mongo_uri = mongo_uri or self._build_mongo_uri() - self.mongo_client: MongoClient | None = None - self.mongo_db = None - - # 迁移配置 - self.migration_configs = self._initialize_migration_configs() - - # 进度条控制台 - self.console = Console() - # 检查点目录 - self.checkpoint_dir = Path(os.path.join(ROOT_PATH, "data", "checkpoints")) - self.checkpoint_dir.mkdir(exist_ok=True) - - # 验证规则已禁用 - self.validation_rules = self._initialize_validation_rules() - - def _build_mongo_uri(self) -> str: - """构建MongoDB连接URI""" - if mongo_uri := os.getenv("MONGODB_URI"): - return mongo_uri - - user = os.getenv("MONGODB_USER") - password = os.getenv("MONGODB_PASS") - host = os.getenv("MONGODB_HOST", "localhost") - port = os.getenv("MONGODB_PORT", "27017") - auth_source = os.getenv("MONGODB_AUTH_SOURCE", "admin") - - if user and password: - return f"mongodb://{user}:{password}@{host}:{port}/{self.database_name}?authSource={auth_source}" - else: - return f"mongodb://{host}:{port}/{self.database_name}" - - def _initialize_migration_configs(self) -> list[MigrationConfig]: - """初始化迁移配置""" - return [ # 表情包迁移配置 - MigrationConfig( - mongo_collection="emoji", - target_model=Emoji, - field_mapping={ - "full_path": "full_path", - "format": "format", - "hash": "emoji_hash", - "description": "description", - "emotion": "emotion", - "usage_count": "usage_count", - "last_used_time": "last_used_time", - # record_time字段将在转换时自动设置为当前时间 - }, - enable_validation=False, # 禁用数据验证 - unique_fields=["full_path", "emoji_hash"], - ), - # 聊天流迁移配置 - MigrationConfig( - mongo_collection="chat_streams", - target_model=ChatStreams, - field_mapping={ - "stream_id": "stream_id", - "create_time": 
"create_time", - "group_info.platform": "group_platform", # 由于Mongodb处理私聊时会让group_info值为null,而新的数据库不允许为null,所以私聊聊天流是没法迁移的,等更新吧。 - "group_info.group_id": "group_id", # 同上 - "group_info.group_name": "group_name", # 同上 - "last_active_time": "last_active_time", - "platform": "platform", - "user_info.platform": "user_platform", - "user_info.user_id": "user_id", - "user_info.user_nickname": "user_nickname", - "user_info.user_cardname": "user_cardname", - }, - enable_validation=False, # 禁用数据验证 - unique_fields=["stream_id"], - ), - # 消息迁移配置 - MigrationConfig( - mongo_collection="messages", - target_model=Messages, - field_mapping={ - "message_id": "message_id", - "time": "time", - "chat_id": "chat_id", - "chat_info.stream_id": "chat_info_stream_id", - "chat_info.platform": "chat_info_platform", - "chat_info.user_info.platform": "chat_info_user_platform", - "chat_info.user_info.user_id": "chat_info_user_id", - "chat_info.user_info.user_nickname": "chat_info_user_nickname", - "chat_info.user_info.user_cardname": "chat_info_user_cardname", - "chat_info.group_info.platform": "chat_info_group_platform", - "chat_info.group_info.group_id": "chat_info_group_id", - "chat_info.group_info.group_name": "chat_info_group_name", - "chat_info.create_time": "chat_info_create_time", - "chat_info.last_active_time": "chat_info_last_active_time", - "user_info.platform": "user_platform", - "user_info.user_id": "user_id", - "user_info.user_nickname": "user_nickname", - "user_info.user_cardname": "user_cardname", - "processed_plain_text": "processed_plain_text", - "memorized_times": "memorized_times", - }, - enable_validation=False, # 禁用数据验证 - unique_fields=["message_id"], - ), - # 图片迁移配置 - MigrationConfig( - mongo_collection="images", - target_model=Images, - field_mapping={ - "hash": "emoji_hash", - "description": "description", - "path": "path", - "timestamp": "timestamp", - "type": "type", - }, - unique_fields=["path"], - ), - # 图片描述迁移配置 - MigrationConfig( - 
mongo_collection="image_descriptions", - target_model=ImageDescriptions, - field_mapping={ - "type": "type", - "hash": "image_description_hash", - "description": "description", - "timestamp": "timestamp", - }, - unique_fields=["image_description_hash", "type"], - ), - # 个人信息迁移配置 - MigrationConfig( - mongo_collection="person_info", - target_model=PersonInfo, - field_mapping={ - "person_id": "person_id", - "person_name": "person_name", - "name_reason": "name_reason", - "platform": "platform", - "user_id": "user_id", - "nickname": "nickname", - "relationship_value": "relationship_value", - "konw_time": "know_time", - }, - unique_fields=["person_id"], - ), - # 知识库迁移配置 - MigrationConfig( - mongo_collection="knowledges", - target_model=Knowledges, - field_mapping={"content": "content", "embedding": "embedding"}, - unique_fields=["content"], # 假设内容唯一 - ), - # 思考日志迁移配置 - MigrationConfig( - mongo_collection="thinking_log", - target_model=ThinkingLog, - field_mapping={ - "chat_id": "chat_id", - "trigger_text": "trigger_text", - "response_text": "response_text", - "trigger_info": "trigger_info_json", - "response_info": "response_info_json", - "timing_results": "timing_results_json", - "chat_history": "chat_history_json", - "chat_history_in_thinking": "chat_history_in_thinking_json", - "chat_history_after_response": "chat_history_after_response_json", - "heartflow_data": "heartflow_data_json", - "reasoning_data": "reasoning_data_json", - }, - unique_fields=["chat_id", "trigger_text"], - ), - # 图节点迁移配置 - MigrationConfig( - mongo_collection="graph_data.nodes", - target_model=GraphNodes, - field_mapping={ - "concept": "concept", - "memory_items": "memory_items", - "hash": "hash", - "created_time": "created_time", - "last_modified": "last_modified", - }, - unique_fields=["concept"], - ), - # 图边迁移配置 - MigrationConfig( - mongo_collection="graph_data.edges", - target_model=GraphEdges, - field_mapping={ - "source": "source", - "target": "target", - "strength": "strength", - "hash": 
"hash", - "created_time": "created_time", - "last_modified": "last_modified", - }, - unique_fields=["source", "target"], # 组合唯一性 - ), - ] - - def _initialize_validation_rules(self) -> dict[str, Any]: - """数据验证已禁用 - 返回空字典""" - return {} - - def connect_mongodb(self) -> bool: - """连接到MongoDB""" - try: - self.mongo_client = MongoClient( - self.mongo_uri, serverSelectionTimeoutMS=5000, connectTimeoutMS=10000, maxPoolSize=10 - ) - - # 测试连接 - self.mongo_client.admin.command("ping") - self.mongo_db = self.mongo_client[self.database_name] - - logger.info(f"成功连接到MongoDB: {self.database_name}") - return True - - except ConnectionFailure as e: - logger.error(f"MongoDB连接失败: {e}") - return False - except Exception as e: - logger.error(f"MongoDB连接异常: {e}") - return False - - def disconnect_mongodb(self): - """断开MongoDB连接""" - if self.mongo_client: - self.mongo_client.close() - logger.info("MongoDB连接已关闭") - - def _get_nested_value(self, document: dict[str, Any], field_path: str) -> Any: - """获取嵌套字段的值""" - if "." 
not in field_path: - return document.get(field_path) - - parts = field_path.split(".") - value = document - - for part in parts: - if isinstance(value, dict): - value = value.get(part) - else: - return None - - if value is None: - break - - return value - - def _convert_field_value(self, value: Any, target_field: Field) -> Any: - """根据目标字段类型转换值""" - if value is None: - return None - - field_type = target_field.__class__.__name__ - - try: - if target_field.name == "record_time" and field_type == "DateTimeField": - return datetime.now() - - if field_type in ["CharField", "TextField"]: - if isinstance(value, list | dict): - return orjson.dumps(value, ensure_ascii=False) - return str(value) if value is not None else "" - - elif field_type == "IntegerField": - if isinstance(value, str): - # 处理字符串数字 - clean_value = value.strip() - if clean_value.replace(".", "").replace("-", "").isdigit(): - return int(float(clean_value)) - return 0 - return int(value) if value is not None else 0 - - elif field_type in ["FloatField", "DoubleField"]: - return float(value) if value is not None else 0.0 - - elif field_type == "BooleanField": - if isinstance(value, str): - return value.lower() in ("true", "1", "yes", "on") - return bool(value) - - elif field_type == "DateTimeField": - if isinstance(value, int | float): - return datetime.fromtimestamp(value) - elif isinstance(value, str): - try: - # 尝试解析ISO格式日期 - return datetime.fromisoformat(value.replace("Z", "+00:00")) - except ValueError: - try: - # 尝试解析时间戳字符串 - return datetime.fromtimestamp(float(value)) - except ValueError: - return datetime.now() - return datetime.now() - - return value - - except (ValueError, TypeError) as e: - logger.warning(f"字段值转换失败 ({field_type}): {value} -> {e}") - return self._get_default_value_for_field(target_field) - - def _get_default_value_for_field(self, field: Field) -> Any: - """获取字段的默认值""" - field_type = field.__class__.__name__ - - if hasattr(field, "default") and field.default is not None: - return 
field.default - - if field.null: - return None - - # 根据字段类型返回默认值 - if field_type in ["CharField", "TextField"]: - return "" - elif field_type == "IntegerField": - return 0 - elif field_type in ["FloatField", "DoubleField"]: - return 0.0 - elif field_type == "BooleanField": - return False - elif field_type == "DateTimeField": - return datetime.now() - - return None - - def _validate_data(self, collection_name: str, data: dict[str, Any], doc_id: Any, stats: MigrationStats) -> bool: - """数据验证已禁用 - 始终返回True""" - return True - - def _save_checkpoint(self, collection_name: str, processed_count: int, last_id: Any): - """保存迁移断点""" - checkpoint = MigrationCheckpoint( - collection_name=collection_name, - processed_count=processed_count, - last_processed_id=last_id, - timestamp=datetime.now(), - ) - - checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" - try: - with open(checkpoint_file, "wb") as f: - pickle.dump(checkpoint, f) - except Exception as e: - logger.warning(f"保存断点失败: {e}") - - def _load_checkpoint(self, collection_name: str) -> MigrationCheckpoint | None: - """加载迁移断点""" - checkpoint_file = self.checkpoint_dir / f"{collection_name}_checkpoint.pkl" - if not checkpoint_file.exists(): - return None - - try: - with open(checkpoint_file, "rb") as f: - return pickle.load(f) - except Exception as e: - logger.warning(f"加载断点失败: {e}") - return None - - def _batch_insert(self, model: type[Model], data_list: list[dict[str, Any]]) -> int: - """批量插入数据""" - if not data_list: - return 0 - - success_count = 0 - try: - with db.atomic(): - # 分批插入,避免SQL语句过长 - batch_size = 100 - for i in range(0, len(data_list), batch_size): - batch = data_list[i : i + batch_size] - model.insert_many(batch).execute() - success_count += len(batch) - except Exception as e: - logger.error(f"批量插入失败: {e}") - # 如果批量插入失败,尝试逐个插入 - for data in data_list: - try: - model.create(**data) - success_count += 1 - except Exception: - pass # 忽略单个插入失败 - - return success_count - - def 
_check_duplicate_by_unique_fields( - self, model: type[Model], data: dict[str, Any], unique_fields: list[str] - ) -> bool: - """根据唯一字段检查重复""" - if not unique_fields: - return False - - try: - query = model.select() - for field_name in unique_fields: - if field_name in data and data[field_name] is not None: - field_obj = getattr(model, field_name) - query = query.where(field_obj == data[field_name]) - - return query.exists() - except Exception as e: - logger.debug(f"重复检查失败: {e}") - return False - - def _create_model_instance(self, model: type[Model], data: dict[str, Any]) -> Model | None: - """使用ORM创建模型实例""" - try: - # 过滤掉不存在的字段 - valid_data = {} - for field_name, value in data.items(): - if hasattr(model, field_name): - valid_data[field_name] = value - else: - logger.debug(f"跳过未知字段: {field_name}") - - # 创建实例 - instance = model.create(**valid_data) - return instance - - except IntegrityError as e: - # 处理唯一约束冲突等完整性错误 - logger.debug(f"完整性约束冲突: {e}") - return None - except Exception as e: - logger.error(f"创建模型实例失败: {e}") - return None - - def migrate_collection(self, config: MigrationConfig) -> MigrationStats: - """迁移单个集合 - 使用优化的批量插入和进度条""" - stats = MigrationStats() - stats.start_time = datetime.now() - - # 检查是否有断点 - checkpoint = self._load_checkpoint(config.mongo_collection) - start_from_id = checkpoint.last_processed_id if checkpoint else None - if checkpoint: - stats.processed_count = checkpoint.processed_count - logger.info(f"从断点恢复: 已处理 {checkpoint.processed_count} 条记录") - - logger.info(f"开始迁移: {config.mongo_collection} -> {config.target_model._meta.table_name}") - - try: - # 获取MongoDB集合 - mongo_collection = self.mongo_db[config.mongo_collection] - - # 构建查询条件(用于断点恢复) - query = {} - if start_from_id: - query = {"_id": {"$gt": start_from_id}} - - stats.total_documents = mongo_collection.count_documents(query) - - if stats.total_documents == 0: - logger.warning(f"集合 {config.mongo_collection} 为空,跳过迁移") - return stats - - logger.info(f"待迁移文档数量: 
{stats.total_documents}") - - # 创建Rich进度条 - with Progress( - SpinnerColumn(), - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - TimeRemainingColumn(), - console=self.console, - refresh_per_second=10, - ) as progress: - task = progress.add_task(f"迁移 {config.mongo_collection}", total=stats.total_documents) - # 批量处理数据 - batch_data = [] - batch_count = 0 - last_processed_id = None - - for mongo_doc in mongo_collection.find(query).batch_size(config.batch_size): - try: - doc_id = mongo_doc.get("_id", "unknown") - last_processed_id = doc_id - - # 构建目标数据 - target_data = {} - for mongo_field, sqlite_field in config.field_mapping.items(): - value = self._get_nested_value(mongo_doc, mongo_field) - - # 获取目标字段对象并转换类型 - if hasattr(config.target_model, sqlite_field): - field_obj = getattr(config.target_model, sqlite_field) - converted_value = self._convert_field_value(value, field_obj) - target_data[sqlite_field] = converted_value - - # 数据验证已禁用 - # if config.enable_validation: - # if not self._validate_data(config.mongo_collection, target_data, doc_id, stats): - # stats.skipped_count += 1 - # continue - - # 重复检查 - if config.skip_duplicates and self._check_duplicate_by_unique_fields( - config.target_model, target_data, config.unique_fields - ): - stats.duplicate_count += 1 - stats.skipped_count += 1 - logger.debug(f"跳过重复记录: {doc_id}") - continue - - # 添加到批量数据 - batch_data.append(target_data) - stats.processed_count += 1 - - # 执行批量插入 - if len(batch_data) >= config.batch_size: - success_count = self._batch_insert(config.target_model, batch_data) - stats.success_count += success_count - stats.batch_insert_count += 1 - - # 保存断点 - self._save_checkpoint(config.mongo_collection, stats.processed_count, last_processed_id) - - batch_data.clear() - batch_count += 1 - - # 更新进度条 - progress.update(task, advance=config.batch_size) - - except Exception as e: - doc_id = mongo_doc.get("_id", "unknown") - stats.add_error(doc_id, 
f"处理文档异常: {e}", mongo_doc) - logger.error(f"处理文档失败 (ID: {doc_id}): {e}") - - # 处理剩余的批量数据 - if batch_data: - success_count = self._batch_insert(config.target_model, batch_data) - stats.success_count += success_count - stats.batch_insert_count += 1 - progress.update(task, advance=len(batch_data)) - - # 完成进度条 - progress.update(task, completed=stats.total_documents) - - stats.end_time = datetime.now() - duration = stats.end_time - stats.start_time - - logger.info( - f"迁移完成: {config.mongo_collection} -> {config.target_model._meta.table_name}\n" - f"总计: {stats.total_documents}, 成功: {stats.success_count}, " - f"错误: {stats.error_count}, 跳过: {stats.skipped_count}, 重复: {stats.duplicate_count}\n" - f"耗时: {duration.total_seconds():.2f}秒, 批量插入次数: {stats.batch_insert_count}" - ) - - # 清理断点文件 - checkpoint_file = self.checkpoint_dir / f"{config.mongo_collection}_checkpoint.pkl" - if checkpoint_file.exists(): - checkpoint_file.unlink() - - except Exception as e: - logger.error(f"迁移集合 {config.mongo_collection} 时发生异常: {e}") - stats.add_error("collection_error", str(e)) - - return stats - - def migrate_all(self) -> dict[str, MigrationStats]: - """执行所有迁移任务""" - logger.info("开始执行数据库迁移...") - - if not self.connect_mongodb(): - logger.error("无法连接到MongoDB,迁移终止") - return {} - - all_stats = {} - - try: - # 创建总体进度表格 - total_collections = len(self.migration_configs) - self.console.print( - Panel( - f"[bold blue]MongoDB 到 SQLite 数据迁移[/bold blue]\n" - f"[yellow]总集合数: {total_collections}[/yellow]", - title="迁移开始", - expand=False, - ) - ) - for idx, config in enumerate(self.migration_configs, 1): - self.console.print( - f"\n[bold green]正在处理集合 {idx}/{total_collections}: {config.mongo_collection}[/bold green]" - ) - stats = self.migrate_collection(config) - all_stats[config.mongo_collection] = stats - - # 显示单个集合的快速统计 - if stats.processed_count > 0: - success_rate = stats.success_count / stats.processed_count * 100 - if success_rate >= 95: - status_emoji = "✅" - status_color = "bright_green" - elif 
success_rate >= 80: - status_emoji = "⚠️" - status_color = "yellow" - else: - status_emoji = "❌" - status_color = "red" - - self.console.print( - f" {status_emoji} [{status_color}]完成: {stats.success_count}/{stats.processed_count} " - f"({success_rate:.1f}%) 错误: {stats.error_count}[/{status_color}]" - ) - - # 错误率检查 - if stats.processed_count > 0: - error_rate = stats.error_count / stats.processed_count - if error_rate > 0.1: # 错误率超过10% - self.console.print( - f" [red]⚠️ 警告: 错误率较高 {error_rate:.1%} " - f"({stats.error_count}/{stats.processed_count})[/red]" - ) - - finally: - self.disconnect_mongodb() - - self._print_migration_summary(all_stats) - return all_stats - - def _print_migration_summary(self, all_stats: dict[str, MigrationStats]): - """使用Rich打印美观的迁移汇总信息""" - # 计算总体统计 - total_processed = sum(stats.processed_count for stats in all_stats.values()) - total_success = sum(stats.success_count for stats in all_stats.values()) - total_errors = sum(stats.error_count for stats in all_stats.values()) - total_skipped = sum(stats.skipped_count for stats in all_stats.values()) - total_duplicates = sum(stats.duplicate_count for stats in all_stats.values()) - total_validation_errors = sum(stats.validation_errors for stats in all_stats.values()) - total_batch_inserts = sum(stats.batch_insert_count for stats in all_stats.values()) - - # 计算总耗时 - total_duration_seconds = 0 - for stats in all_stats.values(): - if stats.start_time and stats.end_time: - duration = stats.end_time - stats.start_time - total_duration_seconds += duration.total_seconds() - - # 创建详细统计表格 - table = Table(title="[bold blue]数据迁移汇总报告[/bold blue]", show_header=True, header_style="bold magenta") - table.add_column("集合名称", style="cyan", width=20) - table.add_column("文档总数", justify="right", style="blue") - table.add_column("处理数量", justify="right", style="green") - table.add_column("成功数量", justify="right", style="green") - table.add_column("错误数量", justify="right", style="red") - table.add_column("跳过数量", 
justify="right", style="yellow") - table.add_column("重复数量", justify="right", style="bright_yellow") - table.add_column("验证错误", justify="right", style="red") - table.add_column("批次数", justify="right", style="purple") - table.add_column("成功率", justify="right", style="bright_green") - table.add_column("耗时(秒)", justify="right", style="blue") - - for collection_name, stats in all_stats.items(): - success_rate = (stats.success_count / stats.processed_count * 100) if stats.processed_count > 0 else 0 - duration = 0 - if stats.start_time and stats.end_time: - duration = (stats.end_time - stats.start_time).total_seconds() - - # 根据成功率设置颜色 - if success_rate >= 95: - success_rate_style = "[bright_green]" - elif success_rate >= 80: - success_rate_style = "[yellow]" - else: - success_rate_style = "[red]" - - table.add_row( - collection_name, - str(stats.total_documents), - str(stats.processed_count), - str(stats.success_count), - f"[red]{stats.error_count}[/red]" if stats.error_count > 0 else "0", - f"[yellow]{stats.skipped_count}[/yellow]" if stats.skipped_count > 0 else "0", - f"[bright_yellow]{stats.duplicate_count}[/bright_yellow]" if stats.duplicate_count > 0 else "0", - f"[red]{stats.validation_errors}[/red]" if stats.validation_errors > 0 else "0", - str(stats.batch_insert_count), - f"{success_rate_style}{success_rate:.1f}%[/{success_rate_style[1:]}", - f"{duration:.2f}", - ) - - # 添加总计行 - total_success_rate = (total_success / total_processed * 100) if total_processed > 0 else 0 - if total_success_rate >= 95: - total_rate_style = "[bright_green]" - elif total_success_rate >= 80: - total_rate_style = "[yellow]" - else: - total_rate_style = "[red]" - - table.add_section() - table.add_row( - "[bold]总计[/bold]", - f"[bold]{sum(stats.total_documents for stats in all_stats.values())}[/bold]", - f"[bold]{total_processed}[/bold]", - f"[bold]{total_success}[/bold]", - f"[bold red]{total_errors}[/bold red]" if total_errors > 0 else "[bold]0[/bold]", - f"[bold 
yellow]{total_skipped}[/bold yellow]" if total_skipped > 0 else "[bold]0[/bold]", - f"[bold bright_yellow]{total_duplicates}[/bold bright_yellow]" - if total_duplicates > 0 - else "[bold]0[/bold]", - f"[bold red]{total_validation_errors}[/bold red]" if total_validation_errors > 0 else "[bold]0[/bold]", - f"[bold]{total_batch_inserts}[/bold]", - f"[bold]{total_rate_style}{total_success_rate:.1f}%[/{total_rate_style[1:]}[/bold]", - f"[bold]{total_duration_seconds:.2f}[/bold]", - ) - - self.console.print(table) - - # 创建状态面板 - status_items = [] - if total_errors > 0: - status_items.append(f"[red]⚠️ 发现 {total_errors} 个错误,请检查日志详情[/red]") - - if total_validation_errors > 0: - status_items.append(f"[red]🔍 数据验证失败: {total_validation_errors} 条记录[/red]") - - if total_duplicates > 0: - status_items.append(f"[yellow]📋 跳过重复记录: {total_duplicates} 条[/yellow]") - - if total_success_rate >= 95: - status_items.append(f"[bright_green]✅ 迁移成功率优秀: {total_success_rate:.1f}%[/bright_green]") - elif total_success_rate >= 80: - status_items.append(f"[yellow]⚡ 迁移成功率良好: {total_success_rate:.1f}%[/yellow]") - else: - status_items.append(f"[red]❌ 迁移成功率较低: {total_success_rate:.1f}%,需要检查[/red]") - - if status_items: - status_panel = Panel( - "\n".join(status_items), title="[bold yellow]迁移状态总结[/bold yellow]", border_style="yellow" - ) - self.console.print(status_panel) - - # 性能统计面板 - avg_speed = total_processed / total_duration_seconds if total_duration_seconds > 0 else 0 - performance_info = ( - f"[cyan]总处理时间:[/cyan] {total_duration_seconds:.2f} 秒\n" - f"[cyan]平均处理速度:[/cyan] {avg_speed:.1f} 条记录/秒\n" - f"[cyan]批量插入优化:[/cyan] 执行了 {total_batch_inserts} 次批量操作" - ) - - performance_panel = Panel(performance_info, title="[bold green]性能统计[/bold green]", border_style="green") - self.console.print(performance_panel) - - def add_migration_config(self, config: MigrationConfig): - """添加新的迁移配置""" - self.migration_configs.append(config) - - def migrate_single_collection(self, collection_name: str) -> 
MigrationStats | None: - """迁移单个指定的集合""" - config = next((c for c in self.migration_configs if c.mongo_collection == collection_name), None) - if not config: - logger.error(f"未找到集合 {collection_name} 的迁移配置") - return None - - if not self.connect_mongodb(): - logger.error("无法连接到MongoDB") - return None - - try: - stats = self.migrate_collection(config) - self._print_migration_summary({collection_name: stats}) - return stats - finally: - self.disconnect_mongodb() - - def export_error_report(self, all_stats: dict[str, MigrationStats], filepath: str): - """导出错误报告""" - error_report = { - "timestamp": datetime.now().isoformat(), - "summary": { - collection: { - "total": stats.total_documents, - "processed": stats.processed_count, - "success": stats.success_count, - "errors": stats.error_count, - "skipped": stats.skipped_count, - "duplicates": stats.duplicate_count, - } - for collection, stats in all_stats.items() - }, - "errors": {collection: stats.errors for collection, stats in all_stats.items() if stats.errors}, - } - - try: - with open(filepath, "w", encoding="utf-8") as f: - orjson.dumps(error_report, f, ensure_ascii=False, indent=2) - logger.info(f"错误报告已导出到: {filepath}") - except Exception as e: - logger.error(f"导出错误报告失败: {e}") - - -def main(): - """主程序入口""" - migrator = MongoToSQLiteMigrator() - - # 执行迁移 - migration_results = migrator.migrate_all() - - # 导出错误报告(如果有错误) - if any(stats.error_count > 0 for stats in migration_results.values()): - error_report_path = f"migration_errors_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json" - migrator.export_error_report(migration_results, error_report_path) - - logger.info("数据迁移完成!") - - -if __name__ == "__main__": - main() diff --git a/scripts/run.sh b/scripts/run.sh deleted file mode 100644 index d702323a6..000000000 --- a/scripts/run.sh +++ /dev/null @@ -1,556 +0,0 @@ -#!/bin/bash - -# MaiCore & NapCat Adapter一键安装脚本 by Cookie_987 -# 适用于Arch/Ubuntu 24.10/Debian 12/CentOS 9 -# 请小心使用任何一键脚本! 
- -INSTALLER_VERSION="0.0.5-refactor" -LANG=C.UTF-8 - -# 如无法访问GitHub请修改此处镜像地址 -GITHUB_REPO="https://ghfast.top/https://github.com" - -# 颜色输出 -GREEN="\e[32m" -RED="\e[31m" -RESET="\e[0m" - -# 需要的基本软件包 - -declare -A REQUIRED_PACKAGES=( - ["common"]="git sudo python3 curl gnupg" - ["debian"]="python3-venv python3-pip build-essential" - ["ubuntu"]="python3-venv python3-pip build-essential" - ["centos"]="epel-release python3-pip python3-devel gcc gcc-c++ make" - ["arch"]="python-virtualenv python-pip base-devel" -) - -# 默认项目目录 -DEFAULT_INSTALL_DIR="/opt/maicore" - -# 服务名称 -SERVICE_NAME="maicore" -SERVICE_NAME_WEB="maicore-web" -SERVICE_NAME_NBADAPTER="maibot-napcat-adapter" - -IS_INSTALL_NAPCAT=false -IS_INSTALL_DEPENDENCIES=false - -# 检查是否已安装 -check_installed() { - [[ -f /etc/systemd/system/${SERVICE_NAME}.service ]] -} - -# 加载安装信息 -load_install_info() { - if [[ -f /etc/maicore_install.conf ]]; then - source /etc/maicore_install.conf - else - INSTALL_DIR="$DEFAULT_INSTALL_DIR" - BRANCH="refactor" - fi -} - -# 显示管理菜单 -show_menu() { - while true; do - choice=$(whiptail --title "MaiCore管理菜单" --menu "请选择要执行的操作:" 15 60 7 \ - "1" "启动MaiCore" \ - "2" "停止MaiCore" \ - "3" "重启MaiCore" \ - "4" "启动NapCat Adapter" \ - "5" "停止NapCat Adapter" \ - "6" "重启NapCat Adapter" \ - "7" "拉取最新MaiCore仓库" \ - "8" "切换分支" \ - "9" "退出" 3>&1 1>&2 2>&3) - - [[ $? 
-ne 0 ]] && exit 0 - - case "$choice" in - 1) - systemctl start ${SERVICE_NAME} - whiptail --msgbox "✅MaiCore已启动" 10 60 - ;; - 2) - systemctl stop ${SERVICE_NAME} - whiptail --msgbox "🛑MaiCore已停止" 10 60 - ;; - 3) - systemctl restart ${SERVICE_NAME} - whiptail --msgbox "🔄MaiCore已重启" 10 60 - ;; - 4) - systemctl start ${SERVICE_NAME_NBADAPTER} - whiptail --msgbox "✅NapCat Adapter已启动" 10 60 - ;; - 5) - systemctl stop ${SERVICE_NAME_NBADAPTER} - whiptail --msgbox "🛑NapCat Adapter已停止" 10 60 - ;; - 6) - systemctl restart ${SERVICE_NAME_NBADAPTER} - whiptail --msgbox "🔄NapCat Adapter已重启" 10 60 - ;; - 7) - update_dependencies - ;; - 8) - switch_branch - ;; - 9) - exit 0 - ;; - *) - whiptail --msgbox "无效选项!" 10 60 - ;; - esac - done -} - -# 更新依赖 -update_dependencies() { - whiptail --title "⚠" --msgbox "更新后请阅读教程" 10 60 - systemctl stop ${SERVICE_NAME} - cd "${INSTALL_DIR}/MaiBot" || { - whiptail --msgbox "🚫 无法进入安装目录!" 10 60 - return 1 - } - if ! git pull origin "${BRANCH}"; then - whiptail --msgbox "🚫 代码更新失败!" 10 60 - return 1 - fi - source "${INSTALL_DIR}/venv/bin/activate" - if ! pip install -r requirements.txt; then - whiptail --msgbox "🚫 依赖安装失败!" 10 60 - deactivate - return 1 - fi - deactivate - whiptail --msgbox "✅ 已停止服务并拉取最新仓库提交" 10 60 -} - -# 切换分支 -switch_branch() { - new_branch=$(whiptail --inputbox "请输入要切换的分支名称:" 10 60 "${BRANCH}" 3>&1 1>&2 2>&3) - [[ -z "$new_branch" ]] && { - whiptail --msgbox "🚫 分支名称不能为空!" 10 60 - return 1 - } - - cd "${INSTALL_DIR}/MaiBot" || { - whiptail --msgbox "🚫 无法进入安装目录!" 10 60 - return 1 - } - - if ! git ls-remote --exit-code --heads origin "${new_branch}" >/dev/null 2>&1; then - whiptail --msgbox "🚫 分支 ${new_branch} 不存在!" 10 60 - return 1 - fi - - if ! git checkout "${new_branch}"; then - whiptail --msgbox "🚫 分支切换失败!" 10 60 - return 1 - fi - - if ! git pull origin "${new_branch}"; then - whiptail --msgbox "🚫 代码拉取失败!" 
10 60 - return 1 - fi - systemctl stop ${SERVICE_NAME} - source "${INSTALL_DIR}/venv/bin/activate" - pip install -r requirements.txt - deactivate - - sed -i "s/^BRANCH=.*/BRANCH=${new_branch}/" /etc/maicore_install.conf - BRANCH="${new_branch}" - check_eula - whiptail --msgbox "✅ 已停止服务并切换到分支 ${new_branch} !" 10 60 -} - -check_eula() { - # 首先计算当前EULA的MD5值 - current_md5=$(md5sum "${INSTALL_DIR}/MaiBot/EULA.md" | awk '{print $1}') - - # 首先计算当前隐私条款文件的哈希值 - current_md5_privacy=$(md5sum "${INSTALL_DIR}/MaiBot/PRIVACY.md" | awk '{print $1}') - - # 如果当前的md5值为空,则直接返回 - if [[ -z $current_md5 || -z $current_md5_privacy ]]; then - whiptail --msgbox "🚫 未找到使用协议\n 请检查PRIVACY.md和EULA.md是否存在" 10 60 - fi - - # 检查eula.confirmed文件是否存在 - if [[ -f ${INSTALL_DIR}/MaiBot/eula.confirmed ]]; then - # 如果存在则检查其中包含的md5与current_md5是否一致 - confirmed_md5=$(cat ${INSTALL_DIR}/MaiBot/eula.confirmed) - else - confirmed_md5="" - fi - - # 检查privacy.confirmed文件是否存在 - if [[ -f ${INSTALL_DIR}/MaiBot/privacy.confirmed ]]; then - # 如果存在则检查其中包含的md5与current_md5是否一致 - confirmed_md5_privacy=$(cat ${INSTALL_DIR}/MaiBot/privacy.confirmed) - else - confirmed_md5_privacy="" - fi - - # 如果EULA或隐私条款有更新,提示用户重新确认 - if [[ $current_md5 != $confirmed_md5 || $current_md5_privacy != $confirmed_md5_privacy ]]; then - whiptail --title "📜 使用协议更新" --yesno "检测到MaiCore EULA或隐私条款已更新。\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议? \n\n " 12 70 - if [[ $? -eq 0 ]]; then - echo -n $current_md5 > ${INSTALL_DIR}/MaiBot/eula.confirmed - echo -n $current_md5_privacy > ${INSTALL_DIR}/MaiBot/privacy.confirmed - else - exit 1 - fi - fi - -} - -# ----------- 主安装流程 ----------- -run_installation() { - # 1/6: 检测是否安装 whiptail - if ! 
command -v whiptail &>/dev/null; then - echo -e "${RED}[1/6] whiptail 未安装,正在安装...${RESET}" - - if command -v apt-get &>/dev/null; then - apt-get update && apt-get install -y whiptail - elif command -v pacman &>/dev/null; then - pacman -Syu --noconfirm whiptail - elif command -v yum &>/dev/null; then - yum install -y whiptail - else - echo -e "${RED}[Error] 无受支持的包管理器,无法安装 whiptail!${RESET}" - exit 1 - fi - fi - - whiptail --title "ℹ️ 提示" --msgbox "如果您没有特殊需求,请优先使用docker方式部署。" 10 60 - - # 协议确认 - if ! (whiptail --title "ℹ️ [1/6] 使用协议" --yes-button "我同意" --no-button "我拒绝" --yesno "使用MaiCore及此脚本前请先阅读EULA协议及隐私协议\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/EULA.md\nhttps://github.com/MaiM-with-u/MaiBot/blob/refactor/PRIVACY.md\n\n您是否同意上述协议?" 12 70); then - exit 1 - fi - - # 欢迎信息 - whiptail --title "[2/6] 欢迎使用MaiCore一键安装脚本 by Cookie987" --msgbox "检测到您未安装MaiCore,将自动进入安装流程,安装完成后再次运行此脚本即可进入管理菜单。\n\n项目处于活跃开发阶段,代码可能随时更改\n文档未完善,有问题可以提交 Issue 或者 Discussion\nQQ机器人存在被限制风险,请自行了解,谨慎使用\n由于持续迭代,可能存在一些已知或未知的bug\n由于开发中,可能消耗较多token\n\n本脚本可能更新不及时,如遇到bug请优先尝试手动部署以确定是否为脚本问题" 17 60 - - # 系统检查 - check_system() { - if [[ "$(id -u)" -ne 0 ]]; then - whiptail --title "🚫 权限不足" --msgbox "请使用 root 用户运行此脚本!\n执行方式: sudo bash $0" 10 60 - exit 1 - fi - - if [[ -f /etc/os-release ]]; then - source /etc/os-release - if [[ "$ID" == "debian" && "$VERSION_ID" == "12" ]]; then - return - elif [[ "$ID" == "ubuntu" && "$VERSION_ID" == "24.10" ]]; then - return - elif [[ "$ID" == "centos" && "$VERSION_ID" == "9" ]]; then - return - elif [[ "$ID" == "arch" ]]; then - whiptail --title "⚠️ 兼容性警告" --msgbox "NapCat无可用的 Arch Linux 官方安装方法,将无法自动安装NapCat。\n\n您可尝试在AUR中搜索相关包。" 10 60 - return - else - whiptail --title "🚫 不支持的系统" --msgbox "此脚本仅支持 Arch/Debian 12 (Bookworm)/Ubuntu 24.10 (Oracular Oriole)/CentOS9!\n当前系统: $PRETTY_NAME\n安装已终止。" 10 60 - exit 1 - fi - else - whiptail --title "⚠️ 无法检测系统" --msgbox "无法识别系统版本,安装已终止。" 10 60 - exit 1 - fi - } - check_system - - # 设置包管理器 - case "$ID" in - debian|ubuntu) - 
PKG_MANAGER="apt" - ;; - centos) - PKG_MANAGER="yum" - ;; - arch) - # 添加arch包管理器 - PKG_MANAGER="pacman" - ;; - esac - - # 检查NapCat - check_napcat() { - if command -v napcat &>/dev/null; then - NAPCAT_INSTALLED=true - else - NAPCAT_INSTALLED=false - fi - } - check_napcat - - # 安装必要软件包 - install_packages() { - missing_packages=() - # 检查 common 及当前系统专属依赖 - for package in ${REQUIRED_PACKAGES["common"]} ${REQUIRED_PACKAGES["$ID"]}; do - case "$PKG_MANAGER" in - apt) - dpkg -s "$package" &>/dev/null || missing_packages+=("$package") - ;; - yum) - rpm -q "$package" &>/dev/null || missing_packages+=("$package") - ;; - pacman) - pacman -Qi "$package" &>/dev/null || missing_packages+=("$package") - ;; - esac - done - - if [[ ${#missing_packages[@]} -gt 0 ]]; then - whiptail --title "📦 [3/6] 依赖检查" --yesno "以下软件包缺失:\n${missing_packages[*]}\n\n是否自动安装?" 10 60 - if [[ $? -eq 0 ]]; then - IS_INSTALL_DEPENDENCIES=true - else - whiptail --title "⚠️ 注意" --yesno "未安装某些依赖,可能影响运行!\n是否继续?" 10 60 || exit 1 - fi - fi - } - install_packages - - # 安装NapCat - install_napcat() { - [[ $NAPCAT_INSTALLED == true ]] && return - whiptail --title "📦 [3/6] 软件包检查" --yesno "检测到未安装NapCat,是否安装?\n如果您想使用远程NapCat,请跳过此步。" 10 60 && { - IS_INSTALL_NAPCAT=true - } - } - - # 仅在非Arch系统上安装NapCat - [[ "$ID" != "arch" ]] && install_napcat - - # Python版本检查 - check_python() { - PYTHON_VERSION=$(python3 -c 'import sys; print(f"{sys.version_info.major}.{sys.version_info.minor}")') - if ! 
python3 -c "import sys; exit(0) if sys.version_info >= (3,10) else exit(1)"; then - whiptail --title "⚠️ [4/6] Python 版本过低" --msgbox "检测到 Python 版本为 $PYTHON_VERSION,需要 3.10 或以上!\n请升级 Python 后重新运行本脚本。" 10 60 - exit 1 - fi - } - - # 如果没安装python则不检查python版本 - if command -v python3 &>/dev/null; then - check_python - fi - - - # 选择分支 - choose_branch() { - BRANCH=$(whiptail --title "🔀 选择分支" --radiolist "请选择要安装的分支:" 15 60 4 \ - "main" "稳定版本(推荐)" ON \ - "dev" "开发版(不知道什么意思就别选)" OFF \ - "classical" "经典版(0.6.0以前的版本)" OFF \ - "custom" "自定义分支" OFF 3>&1 1>&2 2>&3) - RETVAL=$? - if [ $RETVAL -ne 0 ]; then - whiptail --msgbox "🚫 操作取消!" 10 60 - exit 1 - fi - - if [[ "$BRANCH" == "custom" ]]; then - BRANCH=$(whiptail --title "🔀 自定义分支" --inputbox "请输入自定义分支名称:" 10 60 "refactor" 3>&1 1>&2 2>&3) - RETVAL=$? - if [ $RETVAL -ne 0 ]; then - whiptail --msgbox "🚫 输入取消!" 10 60 - exit 1 - fi - if [[ -z "$BRANCH" ]]; then - whiptail --msgbox "🚫 分支名称不能为空!" 10 60 - exit 1 - fi - fi - } - choose_branch - - # 选择安装路径 - choose_install_dir() { - INSTALL_DIR=$(whiptail --title "📂 [6/6] 选择安装路径" --inputbox "请输入MaiCore的安装目录:" 10 60 "$DEFAULT_INSTALL_DIR" 3>&1 1>&2 2>&3) - [[ -z "$INSTALL_DIR" ]] && { - whiptail --title "⚠️ 取消输入" --yesno "未输入安装路径,是否退出安装?" 
10 60 && exit 1 - INSTALL_DIR="$DEFAULT_INSTALL_DIR" - } - } - choose_install_dir - - # 确认安装 - confirm_install() { - local confirm_msg="请确认以下更改:\n\n" - confirm_msg+="📂 安装MaiCore、NapCat Adapter到: $INSTALL_DIR\n" - confirm_msg+="🔀 分支: $BRANCH\n" - [[ $IS_INSTALL_DEPENDENCIES == true ]] && confirm_msg+="📦 安装依赖:${missing_packages[@]}\n" - [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+="📦 安装额外组件:\n" - - [[ $IS_INSTALL_NAPCAT == true ]] && confirm_msg+=" - NapCat\n" - confirm_msg+="\n注意:本脚本默认使用ghfast.top为GitHub进行加速,如不想使用请手动修改脚本开头的GITHUB_REPO变量。" - - whiptail --title "🔧 安装确认" --yesno "$confirm_msg" 20 60 || exit 1 - } - confirm_install - - # 开始安装 - echo -e "${GREEN}安装${missing_packages[@]}...${RESET}" - - if [[ $IS_INSTALL_DEPENDENCIES == true ]]; then - case "$PKG_MANAGER" in - apt) - apt update && apt install -y "${missing_packages[@]}" - ;; - yum) - yum install -y "${missing_packages[@]}" --nobest - ;; - pacman) - pacman -S --noconfirm "${missing_packages[@]}" - ;; - esac - fi - - if [[ $IS_INSTALL_NAPCAT == true ]]; then - echo -e "${GREEN}安装 NapCat...${RESET}" - curl -o napcat.sh https://nclatest.znin.net/NapNeko/NapCat-Installer/main/script/install.sh && bash napcat.sh --cli y --docker n - fi - - echo -e "${GREEN}创建安装目录...${RESET}" - mkdir -p "$INSTALL_DIR" - cd "$INSTALL_DIR" || exit 1 - - echo -e "${GREEN}设置Python虚拟环境...${RESET}" - python3 -m venv venv - source venv/bin/activate - - echo -e "${GREEN}克隆MaiCore仓库...${RESET}" - git clone -b "$BRANCH" "$GITHUB_REPO/MaiM-with-u/MaiBot" MaiBot || { - echo -e "${RED}克隆MaiCore仓库失败!${RESET}" - exit 1 - } - - echo -e "${GREEN}克隆 maim_message 包仓库...${RESET}" - git clone $GITHUB_REPO/MaiM-with-u/maim_message.git || { - echo -e "${RED}克隆 maim_message 包仓库失败!${RESET}" - exit 1 - } - - echo -e "${GREEN}克隆 nonebot-plugin-maibot-adapters 仓库...${RESET}" - git clone $GITHUB_REPO/MaiM-with-u/MaiBot-Napcat-Adapter.git || { - echo -e "${RED}克隆 MaiBot-Napcat-Adapter.git 仓库失败!${RESET}" - exit 1 - } - - - echo -e 
"${GREEN}安装Python依赖...${RESET}" - pip install -r MaiBot/requirements.txt - cd MaiBot - pip install uv - uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt - cd .. - - echo -e "${GREEN}安装maim_message依赖...${RESET}" - cd maim_message - uv pip install -i https://mirrors.aliyun.com/pypi/simple -e . - cd .. - - echo -e "${GREEN}部署MaiBot Napcat Adapter...${RESET}" - cd MaiBot-Napcat-Adapter - uv pip install -i https://mirrors.aliyun.com/pypi/simple -r requirements.txt - cd .. - - echo -e "${GREEN}同意协议...${RESET}" - - # 首先计算当前EULA的MD5值 - current_md5=$(md5sum "MaiBot/EULA.md" | awk '{print $1}') - - # 首先计算当前隐私条款文件的哈希值 - current_md5_privacy=$(md5sum "MaiBot/PRIVACY.md" | awk '{print $1}') - - echo -n $current_md5 > MaiBot/eula.confirmed - echo -n $current_md5_privacy > MaiBot/privacy.confirmed - - echo -e "${GREEN}创建系统服务...${RESET}" - cat > /etc/systemd/system/${SERVICE_NAME}.service < /etc/systemd/system/${SERVICE_NAME_WEB}.service < /etc/systemd/system/${SERVICE_NAME_NBADAPTER}.service < /etc/maicore_install.conf - echo "INSTALL_DIR=${INSTALL_DIR}" >> /etc/maicore_install.conf - echo "BRANCH=${BRANCH}" >> /etc/maicore_install.conf - - whiptail --title "🎉 安装完成" --msgbox "MaiCore安装完成!\n已创建系统服务:${SERVICE_NAME}、${SERVICE_NAME_WEB}、${SERVICE_NAME_NBADAPTER}\n\n使用以下命令管理服务:\n启动服务:systemctl start ${SERVICE_NAME}\n查看状态:systemctl status ${SERVICE_NAME}" 14 60 -} - -# ----------- 主执行流程 ----------- -# 检查root权限 -[[ $(id -u) -ne 0 ]] && { - echo -e "${RED}请使用root用户运行此脚本!${RESET}" - exit 1 -} - -# 如果已安装显示菜单,并检查协议是否更新 -if check_installed; then - load_install_info - check_eula - show_menu -else - run_installation - # 安装完成后询问是否启动 - if whiptail --title "安装完成" --yesno "是否立即启动MaiCore服务?" 
10 60; then - systemctl start ${SERVICE_NAME} - whiptail --msgbox "✅ 服务已启动!\n使用 systemctl status ${SERVICE_NAME} 查看状态" 10 60 - fi -fi diff --git a/src/chat/antipromptinjector/anti_injector.py b/src/chat/antipromptinjector/anti_injector.py index 23ff3a7ee..979b75fe4 100644 --- a/src/chat/antipromptinjector/anti_injector.py +++ b/src/chat/antipromptinjector/anti_injector.py @@ -265,7 +265,8 @@ class AntiPromptInjector: async with get_db_session() as session: # 删除对应的消息记录 stmt = delete(Messages).where(Messages.message_id == message_id) - result = session.execute(stmt) + # 注意: 异步会话需要 await 执行,否则 result 是 coroutine,无法获取 rowcount + result = await session.execute(stmt) await session.commit() if result.rowcount > 0: @@ -296,7 +297,7 @@ class AntiPromptInjector: .where(Messages.message_id == message_id) .values(processed_plain_text=new_content, display_message=new_content) ) - result = session.execute(stmt) + result = await session.execute(stmt) await session.commit() if result.rowcount > 0: diff --git a/src/chat/antipromptinjector/management/statistics.py b/src/chat/antipromptinjector/management/statistics.py index 0525754f1..9820ea525 100644 --- a/src/chat/antipromptinjector/management/statistics.py +++ b/src/chat/antipromptinjector/management/statistics.py @@ -5,9 +5,9 @@ """ import datetime -from typing import Any +from typing import Any, Optional, TypedDict, Literal, Union, Callable, TypeVar, cast -from sqlalchemy import select +from sqlalchemy import select, delete from src.common.database.sqlalchemy_models import AntiInjectionStats, get_db_session from src.common.logger import get_logger @@ -16,8 +16,30 @@ from src.config.config import global_config logger = get_logger("anti_injector.statistics") +TNum = TypeVar("TNum", int, float) + + +def _add_optional(a: Optional[TNum], b: TNum) -> TNum: + """安全相加:左值可能为 None。 + + Args: + a: 可能为 None 的当前值 + b: 要累加的增量(非 None) + Returns: + 新的累加结果(与 b 同类型) + """ + if a is None: + return b + return cast(TNum, a + b) # a 不为 None,此处显式 
cast 便于类型检查 + + class AntiInjectionStatistics: - """反注入系统统计管理类""" + """反注入系统统计管理类 + + 主要改进: + - 对 "可能为 None" 的数值字段做集中安全处理,减少在业务逻辑里反复判空。 + - 补充类型注解,便于静态检查器(Pylance/Pyright)识别。 + """ def __init__(self): """初始化统计管理器""" @@ -25,8 +47,12 @@ class AntiInjectionStatistics: """当前会话开始时间""" @staticmethod - async def get_or_create_stats(): - """获取或创建统计记录""" + async def get_or_create_stats() -> Optional[AntiInjectionStats]: # type: ignore[name-defined] + """获取或创建统计记录 + + Returns: + AntiInjectionStats | None: 成功返回模型实例,否则 None + """ try: async with get_db_session() as session: # 获取最新的统计记录,如果没有则创建 @@ -46,8 +72,15 @@ class AntiInjectionStatistics: return None @staticmethod - async def update_stats(**kwargs): - """更新统计数据""" + async def update_stats(**kwargs: Any) -> None: + """更新统计数据(批量可选字段) + + 支持字段: + - processing_time_delta: float 累加到 processing_time_total + - last_processing_time: float 设置 last_process_time + - total_messages / detected_injections / blocked_messages / shielded_messages / error_count: 累加 + - 其他任意字段:直接赋值(若模型存在该属性) + """ try: async with get_db_session() as session: stats = ( @@ -62,14 +95,13 @@ class AntiInjectionStatistics: # 更新统计字段 for key, value in kwargs.items(): if key == "processing_time_delta": - # 处理 时间累加 - 确保不为None - if stats.processing_time_total is None: - stats.processing_time_total = 0.0 - stats.processing_time_total += value + # 处理时间累加 - 确保不为 None + delta = float(value) + stats.processing_time_total = _add_optional(stats.processing_time_total, delta) # type: ignore[attr-defined] continue elif key == "last_processing_time": # 直接设置最后处理时间 - stats.last_process_time = value + stats.last_process_time = float(value) continue elif hasattr(stats, key): if key in [ @@ -79,12 +111,10 @@ class AntiInjectionStatistics: "shielded_messages", "error_count", ]: - # 累加类型的字段 - 确保不为None - current_value = getattr(stats, key) - if current_value is None: - setattr(stats, key, value) - else: - setattr(stats, key, current_value + value) + # 累加类型的字段 - 统一用辅助函数 + current_value = 
cast(Optional[int], getattr(stats, key)) + increment = int(value) + setattr(stats, key, _add_optional(current_value, increment)) else: # 直接设置的字段 setattr(stats, key, value) @@ -114,10 +144,11 @@ class AntiInjectionStatistics: stats = await self.get_or_create_stats() - # 计算派生统计信息 - 处理None值 - total_messages = stats.total_messages or 0 - detected_injections = stats.detected_injections or 0 - processing_time_total = stats.processing_time_total or 0.0 + + # 计算派生统计信息 - 处理 None 值 + total_messages = stats.total_messages or 0 # type: ignore[attr-defined] + detected_injections = stats.detected_injections or 0 # type: ignore[attr-defined] + processing_time_total = stats.processing_time_total or 0.0 # type: ignore[attr-defined] detection_rate = (detected_injections / total_messages * 100) if total_messages > 0 else 0 avg_processing_time = (processing_time_total / total_messages) if total_messages > 0 else 0 @@ -127,17 +158,22 @@ class AntiInjectionStatistics: current_time = datetime.datetime.now() uptime = current_time - self.session_start_time + last_proc = stats.last_process_time # type: ignore[attr-defined] + blocked_messages = stats.blocked_messages or 0 # type: ignore[attr-defined] + shielded_messages = stats.shielded_messages or 0 # type: ignore[attr-defined] + error_count = stats.error_count or 0 # type: ignore[attr-defined] + return { "status": "enabled", "uptime": str(uptime), "total_messages": total_messages, "detected_injections": detected_injections, - "blocked_messages": stats.blocked_messages or 0, - "shielded_messages": stats.shielded_messages or 0, + "blocked_messages": blocked_messages, + "shielded_messages": shielded_messages, "detection_rate": f"{detection_rate:.2f}%", "average_processing_time": f"{avg_processing_time:.3f}s", - "last_processing_time": f"{stats.last_process_time:.3f}s" if stats.last_process_time else "0.000s", - "error_count": stats.error_count or 0, + "last_processing_time": f"{last_proc:.3f}s" if last_proc else "0.000s", + "error_count": 
error_count, } except Exception as e: logger.error(f"获取统计信息失败: {e}") @@ -149,7 +185,7 @@ class AntiInjectionStatistics: try: async with get_db_session() as session: # 删除现有统计记录 - await session.execute(select(AntiInjectionStats).delete()) + await session.execute(delete(AntiInjectionStats)) await session.commit() logger.info("统计信息已重置") except Exception as e: diff --git a/src/chat/antipromptinjector/management/user_ban.py b/src/chat/antipromptinjector/management/user_ban.py index f1b82a8dc..34bf185c6 100644 --- a/src/chat/antipromptinjector/management/user_ban.py +++ b/src/chat/antipromptinjector/management/user_ban.py @@ -51,7 +51,7 @@ class UserBanManager: remaining_time = ban_duration - (datetime.datetime.now() - ban_record.created_at) return False, None, f"用户被封禁中,剩余时间: {remaining_time}" else: - # 封禁已过期,重置违规次数 + # 封禁已过期,重置违规次数与时间(模型已使用 Mapped 类型,可直接赋值) ban_record.violation_num = 0 ban_record.created_at = datetime.datetime.now() await session.commit() @@ -92,7 +92,6 @@ class UserBanManager: await session.commit() - # 检查是否需要自动封禁 if ban_record.violation_num >= self.config.auto_ban_violation_threshold: logger.warning(f"用户 {platform}:{user_id} 违规次数达到 {ban_record.violation_num},触发自动封禁") # 只有在首次达到阈值时才更新封禁开始时间 diff --git a/src/chat/emoji_system/emoji_manager.py b/src/chat/emoji_system/emoji_manager.py index 0d7d71a7c..e6ad20cec 100644 --- a/src/chat/emoji_system/emoji_manager.py +++ b/src/chat/emoji_system/emoji_manager.py @@ -377,11 +377,12 @@ async def clean_unused_emojis(emoji_dir: str, emoji_objects: list["MaiEmoji"], r class EmojiManager: _instance = None + _initialized: bool = False # 显式声明,避免属性未定义错误 def __new__(cls) -> "EmojiManager": if cls._instance is None: cls._instance = super().__new__(cls) - cls._instance._initialized = False + # 类属性已声明,无需再次赋值 return cls._instance def __init__(self) -> None: @@ -399,7 +400,8 @@ class EmojiManager: self.emoji_num_max = global_config.emoji.max_reg_num self.emoji_num_max_reach_deletion = global_config.emoji.do_replace 
self.emoji_objects: list[MaiEmoji] = [] # 存储MaiEmoji对象的列表,使用类型注解明确列表元素类型 - + logger.info("启动表情包管理器") + self._initialized = True logger.info("启动表情包管理器") def shutdown(self) -> None: @@ -752,8 +754,8 @@ class EmojiManager: try: emoji_record = await self.get_emoji_from_db(emoji_hash) if emoji_record and emoji_record[0].emotion: - logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.emotion[:50]}...") - return emoji_record.emotion + logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record[0].emotion[:50]}...") # read the first record, matching the guard above + return emoji_record[0].emotion # the query result is indexed, not a single row object except Exception as e: logger.error(f"从数据库查询表情包描述时出错: {e}") diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index 91545e5d5..d229d5823 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -7,6 +7,7 @@ import asyncio import time from typing import Any +from src.chat.message_manager.adaptive_stream_manager import StreamPriority from src.chat.chatter_manager import ChatterManager from src.chat.energy_system import energy_manager from src.common.data_models.message_manager_data_model import StreamContext diff --git a/src/common/database/sqlalchemy_models.py b/src/common/database/sqlalchemy_models.py index 274723ea8..9319e11f4 100644 --- a/src/common/database/sqlalchemy_models.py +++ b/src/common/database/sqlalchemy_models.py @@ -1,6 +1,14 @@ """SQLAlchemy数据库模型定义 替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 + +说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到 +SQLAlchemy 2.0 推荐的带类型注解的声明式风格: + + field_name: Mapped[PyType] = mapped_column(Type, ...) 
+ +这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。 +当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。 """ import datetime @@ -103,31 +111,31 @@ class ChatStreams(Base): __tablename__ = "chat_streams" - id = Column(Integer, primary_key=True, autoincrement=True) - stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True) - create_time = Column(Float, nullable=False) - group_platform = Column(Text, nullable=True) - group_id = Column(get_string_field(100), nullable=True, index=True) - group_name = Column(Text, nullable=True) - last_active_time = Column(Float, nullable=False) - platform = Column(Text, nullable=False) - user_platform = Column(Text, nullable=False) - user_id = Column(get_string_field(100), nullable=False, index=True) - user_nickname = Column(Text, nullable=False) - user_cardname = Column(Text, nullable=True) - energy_value = Column(Float, nullable=True, default=5.0) - sleep_pressure = Column(Float, nullable=True, default=0.0) - focus_energy = Column(Float, nullable=True, default=0.5) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + stream_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, unique=True, index=True) + create_time: Mapped[float] = mapped_column(Float, nullable=False) + group_platform: Mapped[str | None] = mapped_column(Text, nullable=True) + group_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) + group_name: Mapped[str | None] = mapped_column(Text, nullable=True) + last_active_time: Mapped[float] = mapped_column(Float, nullable=False) + platform: Mapped[str] = mapped_column(Text, nullable=False) + user_platform: Mapped[str] = mapped_column(Text, nullable=False) + user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + user_nickname: Mapped[str] = mapped_column(Text, nullable=False) + user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) + energy_value: 
Mapped[float | None] = mapped_column(Float, nullable=True, default=5.0) + sleep_pressure: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0) + focus_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) # 动态兴趣度系统字段 - base_interest_energy = Column(Float, nullable=True, default=0.5) - message_interest_total = Column(Float, nullable=True, default=0.0) - message_count = Column(Integer, nullable=True, default=0) - action_count = Column(Integer, nullable=True, default=0) - reply_count = Column(Integer, nullable=True, default=0) - last_interaction_time = Column(Float, nullable=True, default=None) - consecutive_no_reply = Column(Integer, nullable=True, default=0) + base_interest_energy: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.5) + message_interest_total: Mapped[float | None] = mapped_column(Float, nullable=True, default=0.0) + message_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) + action_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) + reply_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) + last_interaction_time: Mapped[float | None] = mapped_column(Float, nullable=True, default=None) + consecutive_no_reply: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) # 消息打断系统字段 - interruption_count = Column(Integer, nullable=True, default=0) + interruption_count: Mapped[int | None] = mapped_column(Integer, nullable=True, default=0) __table_args__ = ( Index("idx_chatstreams_stream_id", "stream_id"), @@ -141,20 +149,20 @@ class LLMUsage(Base): __tablename__ = "llm_usage" - id = Column(Integer, primary_key=True, autoincrement=True) - model_name = Column(get_string_field(100), nullable=False, index=True) - model_assign_name = Column(get_string_field(100), index=True) # 添加索引 - model_api_provider = Column(get_string_field(100), index=True) # 添加索引 - user_id = Column(get_string_field(50), nullable=False, 
index=True) - request_type = Column(get_string_field(50), nullable=False, index=True) - endpoint = Column(Text, nullable=False) - prompt_tokens = Column(Integer, nullable=False) - completion_tokens = Column(Integer, nullable=False) - time_cost = Column(Float, nullable=True) - total_tokens = Column(Integer, nullable=False) - cost = Column(Float, nullable=False) - status = Column(Text, nullable=False) - timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + model_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + model_assign_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) # stays nullable: the replaced Column set no nullable=False + model_api_provider: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) # stays nullable: the replaced Column set no nullable=False + user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) + request_type: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) + endpoint: Mapped[str] = mapped_column(Text, nullable=False) + prompt_tokens: Mapped[int] = mapped_column(Integer, nullable=False) + completion_tokens: Mapped[int] = mapped_column(Integer, nullable=False) + time_cost: Mapped[float | None] = mapped_column(Float, nullable=True) + total_tokens: Mapped[int] = mapped_column(Integer, nullable=False) + cost: Mapped[float] = mapped_column(Float, nullable=False) + status: Mapped[str] = mapped_column(Text, nullable=False) + timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True, default=datetime.datetime.now) __table_args__ = ( Index("idx_llmusage_model_name", "model_name"), @@ -172,19 +180,19 @@ class Emoji(Base): __tablename__ = "emoji" - id = Column(Integer, primary_key=True, autoincrement=True) - full_path = Column(get_string_field(500), nullable=False, unique=True, index=True) - format = Column(Text, nullable=False) - emoji_hash = Column(get_string_field(64), nullable=False, 
index=True) - description = Column(Text, nullable=False) - query_count = Column(Integer, nullable=False, default=0) - is_registered = Column(Boolean, nullable=False, default=False) - is_banned = Column(Boolean, nullable=False, default=False) - emotion = Column(Text, nullable=True) - record_time = Column(Float, nullable=False) - register_time = Column(Float, nullable=True) - usage_count = Column(Integer, nullable=False, default=0) - last_used_time = Column(Float, nullable=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + full_path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True) + format: Mapped[str] = mapped_column(Text, nullable=False) + emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + description: Mapped[str] = mapped_column(Text, nullable=False) + query_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + is_registered: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_banned: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + emotion: Mapped[str | None] = mapped_column(Text, nullable=True) + record_time: Mapped[float] = mapped_column(Float, nullable=False) + register_time: Mapped[float | None] = mapped_column(Float, nullable=True) + usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + last_used_time: Mapped[float | None] = mapped_column(Float, nullable=True) __table_args__ = ( Index("idx_emoji_full_path", "full_path"), @@ -197,50 +205,50 @@ class Messages(Base): __tablename__ = "messages" - id = Column(Integer, primary_key=True, autoincrement=True) - message_id = Column(get_string_field(100), nullable=False, index=True) - time = Column(Float, nullable=False) - chat_id = Column(get_string_field(64), nullable=False, index=True) - reply_to = Column(Text, nullable=True) - interest_value = Column(Float, nullable=True) - key_words = Column(Text, 
nullable=True) - key_words_lite = Column(Text, nullable=True) - is_mentioned = Column(Boolean, nullable=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + message_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + time: Mapped[float] = mapped_column(Float, nullable=False) + chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + reply_to: Mapped[str | None] = mapped_column(Text, nullable=True) + interest_value: Mapped[float | None] = mapped_column(Float, nullable=True) + key_words: Mapped[str | None] = mapped_column(Text, nullable=True) + key_words_lite: Mapped[str | None] = mapped_column(Text, nullable=True) + is_mentioned: Mapped[bool | None] = mapped_column(Boolean, nullable=True) # 从 chat_info 扁平化而来的字段 - chat_info_stream_id = Column(Text, nullable=False) - chat_info_platform = Column(Text, nullable=False) - chat_info_user_platform = Column(Text, nullable=False) - chat_info_user_id = Column(Text, nullable=False) - chat_info_user_nickname = Column(Text, nullable=False) - chat_info_user_cardname = Column(Text, nullable=True) - chat_info_group_platform = Column(Text, nullable=True) - chat_info_group_id = Column(Text, nullable=True) - chat_info_group_name = Column(Text, nullable=True) - chat_info_create_time = Column(Float, nullable=False) - chat_info_last_active_time = Column(Float, nullable=False) + chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_user_platform: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_user_id: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_user_nickname: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_info_group_platform: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_info_group_id: Mapped[str | 
None] = mapped_column(Text, nullable=True) + chat_info_group_name: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_info_create_time: Mapped[float] = mapped_column(Float, nullable=False) + chat_info_last_active_time: Mapped[float] = mapped_column(Float, nullable=False) # 从顶层 user_info 扁平化而来的字段 - user_platform = Column(Text, nullable=True) - user_id = Column(get_string_field(100), nullable=True, index=True) - user_nickname = Column(Text, nullable=True) - user_cardname = Column(Text, nullable=True) + user_platform: Mapped[str | None] = mapped_column(Text, nullable=True) + user_id: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True, index=True) + user_nickname: Mapped[str | None] = mapped_column(Text, nullable=True) + user_cardname: Mapped[str | None] = mapped_column(Text, nullable=True) - processed_plain_text = Column(Text, nullable=True) - display_message = Column(Text, nullable=True) - memorized_times = Column(Integer, nullable=False, default=0) - priority_mode = Column(Text, nullable=True) - priority_info = Column(Text, nullable=True) - additional_config = Column(Text, nullable=True) - is_emoji = Column(Boolean, nullable=False, default=False) - is_picid = Column(Boolean, nullable=False, default=False) - is_command = Column(Boolean, nullable=False, default=False) - is_notify = Column(Boolean, nullable=False, default=False) + processed_plain_text: Mapped[str | None] = mapped_column(Text, nullable=True) + display_message: Mapped[str | None] = mapped_column(Text, nullable=True) + memorized_times: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + priority_mode: Mapped[str | None] = mapped_column(Text, nullable=True) + priority_info: Mapped[str | None] = mapped_column(Text, nullable=True) + additional_config: Mapped[str | None] = mapped_column(Text, nullable=True) + is_emoji: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_picid: Mapped[bool] = mapped_column(Boolean, nullable=False, 
default=False) + is_command: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + is_notify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # 兴趣度系统字段 - actions = Column(Text, nullable=True) # JSON格式存储动作列表 - should_reply = Column(Boolean, nullable=True, default=False) - should_act = Column(Boolean, nullable=True, default=False) + actions: Mapped[str | None] = mapped_column(Text, nullable=True) + should_reply: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False) + should_act: Mapped[bool | None] = mapped_column(Boolean, nullable=True, default=False) __table_args__ = ( Index("idx_messages_message_id", "message_id"), @@ -257,17 +265,17 @@ class ActionRecords(Base): __tablename__ = "action_records" - id = Column(Integer, primary_key=True, autoincrement=True) - action_id = Column(get_string_field(100), nullable=False, index=True) - time = Column(Float, nullable=False) - action_name = Column(Text, nullable=False) - action_data = Column(Text, nullable=False) - action_done = Column(Boolean, nullable=False, default=False) - action_build_into_prompt = Column(Boolean, nullable=False, default=False) - action_prompt_display = Column(Text, nullable=False) - chat_id = Column(get_string_field(64), nullable=False, index=True) - chat_info_stream_id = Column(Text, nullable=False) - chat_info_platform = Column(Text, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + action_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + time: Mapped[float] = mapped_column(Float, nullable=False) + action_name: Mapped[str] = mapped_column(Text, nullable=False) + action_data: Mapped[str] = mapped_column(Text, nullable=False) + action_done: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + action_build_into_prompt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + action_prompt_display: Mapped[str] = mapped_column(Text, 
nullable=False) + chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + chat_info_stream_id: Mapped[str] = mapped_column(Text, nullable=False) + chat_info_platform: Mapped[str] = mapped_column(Text, nullable=False) __table_args__ = ( Index("idx_actionrecords_action_id", "action_id"), @@ -281,15 +289,15 @@ class Images(Base): __tablename__ = "images" - id = Column(Integer, primary_key=True, autoincrement=True) - image_id = Column(Text, nullable=False, default="") - emoji_hash = Column(get_string_field(64), nullable=False, index=True) - description = Column(Text, nullable=True) - path = Column(get_string_field(500), nullable=False, unique=True) - count = Column(Integer, nullable=False, default=1) - timestamp = Column(Float, nullable=False) - type = Column(Text, nullable=False) - vlm_processed = Column(Boolean, nullable=False, default=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + image_id: Mapped[str] = mapped_column(Text, nullable=False, default="") + emoji_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + path: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True) + count: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + timestamp: Mapped[float] = mapped_column(Float, nullable=False) + type: Mapped[str] = mapped_column(Text, nullable=False) + vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) __table_args__ = ( Index("idx_images_emoji_hash", "emoji_hash"), @@ -302,11 +310,11 @@ class ImageDescriptions(Base): __tablename__ = "image_descriptions" - id = Column(Integer, primary_key=True, autoincrement=True) - type = Column(Text, nullable=False) - image_description_hash = Column(get_string_field(64), nullable=False, index=True) - description = Column(Text, nullable=False) - timestamp = Column(Float, 
nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + type: Mapped[str] = mapped_column(Text, nullable=False) + image_description_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + description: Mapped[str] = mapped_column(Text, nullable=False) + timestamp: Mapped[float] = mapped_column(Float, nullable=False) __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),) @@ -316,20 +324,20 @@ class Videos(Base): __tablename__ = "videos" - id = Column(Integer, primary_key=True, autoincrement=True) - video_id = Column(Text, nullable=False, default="") - video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True) - description = Column(Text, nullable=True) - count = Column(Integer, nullable=False, default=1) - timestamp = Column(Float, nullable=False) - vlm_processed = Column(Boolean, nullable=False, default=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + video_id: Mapped[str] = mapped_column(Text, nullable=False, default="") + video_hash: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True, unique=True) + description: Mapped[str | None] = mapped_column(Text, nullable=True) + count: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + timestamp: Mapped[float] = mapped_column(Float, nullable=False) + vlm_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # 视频特有属性 - duration = Column(Float, nullable=True) # 视频时长(秒) - frame_count = Column(Integer, nullable=True) # 总帧数 - fps = Column(Float, nullable=True) # 帧率 - resolution = Column(Text, nullable=True) # 分辨率 - file_size = Column(Integer, nullable=True) # 文件大小(字节) + duration: Mapped[float | None] = mapped_column(Float, nullable=True) + frame_count: Mapped[int | None] = mapped_column(Integer, nullable=True) + fps: Mapped[float | None] = mapped_column(Float, nullable=True) + resolution: Mapped[str | None] = 
mapped_column(Text, nullable=True) + file_size: Mapped[int | None] = mapped_column(Integer, nullable=True) __table_args__ = ( Index("idx_videos_video_hash", "video_hash"), @@ -342,11 +350,11 @@ class OnlineTime(Base): __tablename__ = "online_time" - id = Column(Integer, primary_key=True, autoincrement=True) - timestamp = Column(Text, nullable=False, default=str(datetime.datetime.now)) - duration = Column(Integer, nullable=False) - start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now) - end_timestamp = Column(DateTime, nullable=False, index=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + timestamp: Mapped[str] = mapped_column(Text, nullable=False, default=str(datetime.datetime.now)) + duration: Mapped[int] = mapped_column(Integer, nullable=False) + start_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) + end_timestamp: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, index=True) __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),) @@ -356,22 +364,22 @@ class PersonInfo(Base): __tablename__ = "person_info" - id = Column(Integer, primary_key=True, autoincrement=True) - person_id = Column(get_string_field(100), nullable=False, unique=True, index=True) - person_name = Column(Text, nullable=True) - name_reason = Column(Text, nullable=True) - platform = Column(Text, nullable=False) - user_id = Column(get_string_field(50), nullable=False, index=True) - nickname = Column(Text, nullable=True) - impression = Column(Text, nullable=True) - short_impression = Column(Text, nullable=True) - points = Column(Text, nullable=True) - forgotten_points = Column(Text, nullable=True) - info_list = Column(Text, nullable=True) - know_times = Column(Float, nullable=True) - know_since = Column(Float, nullable=True) - last_know = Column(Float, nullable=True) - attitude = Column(Integer, nullable=True, default=50) + id: 
Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + person_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, unique=True, index=True) + person_name: Mapped[str | None] = mapped_column(Text, nullable=True) + name_reason: Mapped[str | None] = mapped_column(Text, nullable=True) + platform: Mapped[str] = mapped_column(Text, nullable=False) + user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) + nickname: Mapped[str | None] = mapped_column(Text, nullable=True) + impression: Mapped[str | None] = mapped_column(Text, nullable=True) + short_impression: Mapped[str | None] = mapped_column(Text, nullable=True) + points: Mapped[str | None] = mapped_column(Text, nullable=True) + forgotten_points: Mapped[str | None] = mapped_column(Text, nullable=True) + info_list: Mapped[str | None] = mapped_column(Text, nullable=True) + know_times: Mapped[float | None] = mapped_column(Float, nullable=True) + know_since: Mapped[float | None] = mapped_column(Float, nullable=True) + last_know: Mapped[float | None] = mapped_column(Float, nullable=True) + attitude: Mapped[int | None] = mapped_column(Integer, nullable=True, default=50) __table_args__ = ( Index("idx_personinfo_person_id", "person_id"), @@ -384,13 +392,13 @@ class BotPersonalityInterests(Base): __tablename__ = "bot_personality_interests" - id = Column(Integer, primary_key=True, autoincrement=True) - personality_id = Column(get_string_field(100), nullable=False, index=True) - personality_description = Column(Text, nullable=False) - interest_tags = Column(Text, nullable=False) # JSON格式存储的兴趣标签列表 - embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002") - version = Column(Integer, nullable=False, default=1) - last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + personality_id: Mapped[str] = 
mapped_column(get_string_field(100), nullable=False, index=True) + personality_description: Mapped[str] = mapped_column(Text, nullable=False) + interest_tags: Mapped[str] = mapped_column(Text, nullable=False) + embedding_model: Mapped[str] = mapped_column(get_string_field(100), nullable=False, default="text-embedding-ada-002") + version: Mapped[int] = mapped_column(Integer, nullable=False, default=1) + last_updated: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, index=True) __table_args__ = ( Index("idx_botpersonality_personality_id", "personality_id"), @@ -404,13 +412,13 @@ class Memory(Base): __tablename__ = "memory" - id = Column(Integer, primary_key=True, autoincrement=True) - memory_id = Column(get_string_field(64), nullable=False, index=True) - chat_id = Column(Text, nullable=True) - memory_text = Column(Text, nullable=True) - keywords = Column(Text, nullable=True) - create_time = Column(Float, nullable=True) - last_view_time = Column(Float, nullable=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + memory_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + chat_id: Mapped[str | None] = mapped_column(Text, nullable=True) + memory_text: Mapped[str | None] = mapped_column(Text, nullable=True) + keywords: Mapped[str | None] = mapped_column(Text, nullable=True) + create_time: Mapped[float | None] = mapped_column(Float, nullable=True) + last_view_time: Mapped[float | None] = mapped_column(Float, nullable=True) __table_args__ = (Index("idx_memory_memory_id", "memory_id"),) @@ -437,19 +445,19 @@ class ThinkingLog(Base): __tablename__ = "thinking_logs" - id = Column(Integer, primary_key=True, autoincrement=True) - chat_id = Column(get_string_field(64), nullable=False, index=True) - trigger_text = Column(Text, nullable=True) - response_text = Column(Text, nullable=True) - trigger_info_json = Column(Text, nullable=True) - response_info_json = 
Column(Text, nullable=True) - timing_results_json = Column(Text, nullable=True) - chat_history_json = Column(Text, nullable=True) - chat_history_in_thinking_json = Column(Text, nullable=True) - chat_history_after_response_json = Column(Text, nullable=True) - heartflow_data_json = Column(Text, nullable=True) - reasoning_data_json = Column(Text, nullable=True) - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True) + trigger_text: Mapped[str | None] = mapped_column(Text, nullable=True) + response_text: Mapped[str | None] = mapped_column(Text, nullable=True) + trigger_info_json: Mapped[str | None] = mapped_column(Text, nullable=True) + response_info_json: Mapped[str | None] = mapped_column(Text, nullable=True) + timing_results_json: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_history_json: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_history_in_thinking_json: Mapped[str | None] = mapped_column(Text, nullable=True) + chat_history_after_response_json: Mapped[str | None] = mapped_column(Text, nullable=True) + heartflow_data_json: Mapped[str | None] = mapped_column(Text, nullable=True) + reasoning_data_json: Mapped[str | None] = mapped_column(Text, nullable=True) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),) @@ -459,13 +467,13 @@ class GraphNodes(Base): __tablename__ = "graph_nodes" - id = Column(Integer, primary_key=True, autoincrement=True) - concept = Column(get_string_field(255), nullable=False, unique=True, index=True) - memory_items = Column(Text, nullable=False) - hash = Column(Text, nullable=False) - weight = Column(Float, nullable=False, default=1.0) - created_time = Column(Float, nullable=False) - 
last_modified = Column(Float, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + concept: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True) + memory_items: Mapped[str] = mapped_column(Text, nullable=False) + hash: Mapped[str] = mapped_column(Text, nullable=False) + weight: Mapped[float] = mapped_column(Float, nullable=False, default=1.0) + created_time: Mapped[float] = mapped_column(Float, nullable=False) + last_modified: Mapped[float] = mapped_column(Float, nullable=False) __table_args__ = (Index("idx_graphnodes_concept", "concept"),) @@ -475,13 +483,13 @@ class GraphEdges(Base): __tablename__ = "graph_edges" - id = Column(Integer, primary_key=True, autoincrement=True) - source = Column(get_string_field(255), nullable=False, index=True) - target = Column(get_string_field(255), nullable=False, index=True) - strength = Column(Integer, nullable=False) - hash = Column(Text, nullable=False) - created_time = Column(Float, nullable=False) - last_modified = Column(Float, nullable=False) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + source: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) + target: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) + strength: Mapped[int] = mapped_column(Integer, nullable=False) + hash: Mapped[str] = mapped_column(Text, nullable=False) + created_time: Mapped[float] = mapped_column(Float, nullable=False) + last_modified: Mapped[float] = mapped_column(Float, nullable=False) __table_args__ = ( Index("idx_graphedges_source", "source"), @@ -494,11 +502,11 @@ class Schedule(Base): __tablename__ = "schedule" - id = Column(Integer, primary_key=True, autoincrement=True) - date = Column(get_string_field(10), nullable=False, unique=True, index=True) # YYYY-MM-DD格式 - schedule_data = Column(Text, nullable=False) # JSON格式的日程数据 - created_at = Column(DateTime, 
nullable=False, default=datetime.datetime.now) - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + date: Mapped[str] = mapped_column(get_string_field(10), nullable=False, unique=True, index=True) + schedule_data: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) + updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) __table_args__ = (Index("idx_schedule_date", "date"),) @@ -508,17 +516,15 @@ class MaiZoneScheduleStatus(Base): __tablename__ = "maizone_schedule_status" - id = Column(Integer, primary_key=True, autoincrement=True) - datetime_hour = Column( - get_string_field(13), nullable=False, unique=True, index=True - ) # YYYY-MM-DD HH格式,精确到小时 - activity = Column(Text, nullable=False) # 该小时的活动内容 - is_processed = Column(Boolean, nullable=False, default=False) # 是否已处理 - processed_at = Column(DateTime, nullable=True) # 处理时间 - story_content = Column(Text, nullable=True) # 生成的说说内容 - send_success = Column(Boolean, nullable=False, default=False) # 是否发送成功 - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + datetime_hour: Mapped[str] = mapped_column(get_string_field(13), nullable=False, unique=True, index=True) + activity: Mapped[str] = mapped_column(Text, nullable=False) + is_processed: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + processed_at: Mapped[datetime.datetime | None] = mapped_column(DateTime, nullable=True) + story_content: Mapped[str | None] = mapped_column(Text, nullable=True) + 
send_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) + updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) __table_args__ = ( Index("idx_maizone_datetime_hour", "datetime_hour"), @@ -527,16 +533,20 @@ class MaiZoneScheduleStatus(Base): class BanUser(Base): - """被禁用用户模型""" + """被禁用用户模型 + + 使用 SQLAlchemy 2.0 类型标注写法,方便静态类型检查器识别实际字段类型, + 避免在业务代码中对属性赋值时报 `Column[...]` 不可赋值的告警。 + """ __tablename__ = "ban_users" - id = Column(Integer, primary_key=True, autoincrement=True) - platform = Column(Text, nullable=False) - user_id = Column(get_string_field(50), nullable=False, index=True) - violation_num = Column(Integer, nullable=False, default=0) - reason = Column(Text, nullable=False) - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + platform: Mapped[str] = mapped_column(Text, nullable=False) + user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) + violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) + reason: Mapped[str] = mapped_column(Text, nullable=False) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) __table_args__ = ( Index("idx_violation_num", "violation_num"), @@ -551,38 +561,38 @@ class AntiInjectionStats(Base): __tablename__ = "anti_injection_stats" - id = Column(Integer, primary_key=True, autoincrement=True) - total_messages = Column(Integer, nullable=False, default=0) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + total_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """总处理消息数""" - detected_injections = Column(Integer, 
nullable=False, default=0) + detected_injections: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """检测到的注入攻击数""" - blocked_messages = Column(Integer, nullable=False, default=0) + blocked_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """被阻止的消息数""" - shielded_messages = Column(Integer, nullable=False, default=0) + shielded_messages: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """被加盾的消息数""" - processing_time_total = Column(Float, nullable=False, default=0.0) + processing_time_total: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) """总处理时间""" - total_process_time = Column(Float, nullable=False, default=0.0) + total_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) """累计总处理时间""" - last_process_time = Column(Float, nullable=False, default=0.0) + last_process_time: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) """最近一次处理时间""" - error_count = Column(Integer, nullable=False, default=0) + error_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """错误计数""" - start_time = Column(DateTime, nullable=False, default=datetime.datetime.now) + start_time: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) """统计开始时间""" - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) """记录创建时间""" - updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) + updated_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now) """记录更新时间""" __table_args__ = ( @@ -596,26 +606,26 @@ class CacheEntries(Base): __tablename__ = "cache_entries" - id = Column(Integer, primary_key=True, autoincrement=True) - cache_key = Column(get_string_field(500), 
nullable=False, unique=True, index=True) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + cache_key: Mapped[str] = mapped_column(get_string_field(500), nullable=False, unique=True, index=True) """缓存键,包含工具名、参数和代码哈希""" - cache_value = Column(Text, nullable=False) + cache_value: Mapped[str] = mapped_column(Text, nullable=False) """缓存的数据,JSON格式""" - expires_at = Column(Float, nullable=False, index=True) + expires_at: Mapped[float] = mapped_column(Float, nullable=False, index=True) """过期时间戳""" - tool_name = Column(get_string_field(100), nullable=False, index=True) + tool_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) """工具名称""" - created_at = Column(Float, nullable=False, default=lambda: time.time()) + created_at: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time()) """创建时间戳""" - last_accessed = Column(Float, nullable=False, default=lambda: time.time()) + last_accessed: Mapped[float] = mapped_column(Float, nullable=False, default=lambda: time.time()) """最后访问时间戳""" - access_count = Column(Integer, nullable=False, default=0) + access_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) """访问次数""" __table_args__ = ( @@ -631,18 +641,16 @@ class MonthlyPlan(Base): __tablename__ = "monthly_plans" - id = Column(Integer, primary_key=True, autoincrement=True) - plan_text = Column(Text, nullable=False) - target_month = Column(String(7), nullable=False, index=True) # "YYYY-MM" - status = Column( - get_string_field(20), nullable=False, default="active", index=True - ) # 'active', 'completed', 'archived' - usage_count = Column(Integer, nullable=False, default=0) - last_used_date = Column(String(10), nullable=True, index=True) # "YYYY-MM-DD" format - created_at = Column(DateTime, nullable=False, default=datetime.datetime.now) + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + plan_text: Mapped[str] = mapped_column(Text, nullable=False) 
+ target_month: Mapped[str] = mapped_column(String(7), nullable=False, index=True) + status: Mapped[str] = mapped_column(get_string_field(20), nullable=False, default="active", index=True) + usage_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0) + last_used_date: Mapped[str | None] = mapped_column(String(10), nullable=True, index=True) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now) # 保留 is_deleted 字段以兼容现有数据,但标记为已弃用 - is_deleted = Column(Boolean, nullable=False, default=False) + is_deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) __table_args__ = ( Index("idx_monthlyplan_target_month_status", "target_month", "status"), @@ -807,12 +815,12 @@ class PermissionNodes(Base): __tablename__ = "permission_nodes" - id = Column(Integer, primary_key=True, autoincrement=True) - node_name = Column(get_string_field(255), nullable=False, unique=True, index=True) # 权限节点名称 - description = Column(Text, nullable=False) # 权限描述 - plugin_name = Column(get_string_field(100), nullable=False, index=True) # 所属插件 - default_granted = Column(Boolean, default=False, nullable=False) # 默认是否授权 - created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + node_name: Mapped[str] = mapped_column(get_string_field(255), nullable=False, unique=True, index=True) + description: Mapped[str] = mapped_column(Text, nullable=False) + plugin_name: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + default_granted: Mapped[bool] = mapped_column(Boolean, default=False, nullable=False) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) __table_args__ = ( Index("idx_permission_plugin", "plugin_name"), @@ -825,13 +833,13 @@ class UserPermissions(Base): __tablename__ = "user_permissions" - id = 
Column(Integer, primary_key=True, autoincrement=True) - platform = Column(get_string_field(50), nullable=False, index=True) # 平台类型 - user_id = Column(get_string_field(100), nullable=False, index=True) # 用户ID - permission_node = Column(get_string_field(255), nullable=False, index=True) # 权限节点名称 - granted = Column(Boolean, default=True, nullable=False) # 是否授权 - granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 授权时间 - granted_by = Column(get_string_field(100), nullable=True) # 授权者信息 + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) + user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, index=True) + permission_node: Mapped[str] = mapped_column(get_string_field(255), nullable=False, index=True) + granted: Mapped[bool] = mapped_column(Boolean, default=True, nullable=False) + granted_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) + granted_by: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) __table_args__ = ( Index("idx_user_platform_id", "platform", "user_id"), @@ -845,13 +853,13 @@ class UserRelationships(Base): __tablename__ = "user_relationships" - id = Column(Integer, primary_key=True, autoincrement=True) - user_id = Column(get_string_field(100), nullable=False, unique=True, index=True) # 用户ID - user_name = Column(get_string_field(100), nullable=True) # 用户名 - relationship_text = Column(Text, nullable=True) # 关系印象描述 - relationship_score = Column(Float, nullable=False, default=0.3) # 关系分数(0-1) - last_updated = Column(Float, nullable=False, default=time.time) # 最后更新时间 - created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False) # 创建时间 + id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) + user_id: Mapped[str] = mapped_column(get_string_field(100), nullable=False, 
unique=True, index=True) + user_name: Mapped[str | None] = mapped_column(get_string_field(100), nullable=True) + relationship_text: Mapped[str | None] = mapped_column(Text, nullable=True) + relationship_score: Mapped[float] = mapped_column(Float, nullable=False, default=0.3) # 关系分数(0-1) + last_updated: Mapped[float] = mapped_column(Float, nullable=False, default=time.time) + created_at: Mapped[datetime.datetime] = mapped_column(DateTime, default=datetime.datetime.utcnow, nullable=False) __table_args__ = ( Index("idx_user_relationship_id", "user_id"), diff --git a/src/common/database/sqlalchemy_models.py.bak b/src/common/database/sqlalchemy_models.py.bak new file mode 100644 index 000000000..061ac6fad --- /dev/null +++ b/src/common/database/sqlalchemy_models.py.bak @@ -0,0 +1,872 @@ +"""SQLAlchemy数据库模型定义 + +替换Peewee ORM,使用SQLAlchemy提供更好的连接池管理和错误恢复能力 + +说明: 部分旧模型仍使用 `Column = Column(Type, ...)` 的经典风格。本文件开始逐步迁移到 +SQLAlchemy 2.0 推荐的带类型注解的声明式风格: + + field_name: Mapped[PyType] = mapped_column(Type, ...) 
+ +这样 IDE / Pylance 能正确推断实例属性的真实 Python 类型,避免将其视为不可赋值的 Column 对象。 +当前仅对产生类型检查问题的模型 (BanUser) 进行了迁移,其余模型保持不变以减少一次性改动范围。 +""" + +import datetime +import os +import time +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Any + +from sqlalchemy import Boolean, Column, DateTime, Float, Index, Integer, String, Text, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Mapped, mapped_column + +from src.common.database.connection_pool_manager import get_connection_pool_manager +from src.common.logger import get_logger + +logger = get_logger("sqlalchemy_models") + +# 创建基类 +Base = declarative_base() + + +async def enable_sqlite_wal_mode(engine): + """为 SQLite 启用 WAL 模式以提高并发性能""" + try: + async with engine.begin() as conn: + # 启用 WAL 模式 + await conn.execute(text("PRAGMA journal_mode = WAL")) + # 设置适中的同步级别,平衡性能和安全性 + await conn.execute(text("PRAGMA synchronous = NORMAL")) + # 启用外键约束 + await conn.execute(text("PRAGMA foreign_keys = ON")) + # 设置 busy_timeout,避免锁定错误 + await conn.execute(text("PRAGMA busy_timeout = 60000")) # 60秒 + + logger.info("[SQLite] WAL 模式已启用,并发性能已优化") + except Exception as e: + logger.warning(f"[SQLite] 启用 WAL 模式失败: {e},将使用默认配置") + + +async def maintain_sqlite_database(): + """定期维护 SQLite 数据库性能""" + try: + engine, SessionLocal = await initialize_database() + if not engine: + return + + async with engine.begin() as conn: + # 检查并确保 WAL 模式仍然启用 + result = await conn.execute(text("PRAGMA journal_mode")) + journal_mode = result.scalar() + + if journal_mode != "wal": + await conn.execute(text("PRAGMA journal_mode = WAL")) + logger.info("[SQLite] WAL 模式已重新启用") + + # 优化数据库性能 + await conn.execute(text("PRAGMA synchronous = NORMAL")) + await conn.execute(text("PRAGMA busy_timeout = 60000")) + await conn.execute(text("PRAGMA foreign_keys = ON")) + + # 定期清理(可选,根据需要启用) + # await 
conn.execute(text("PRAGMA optimize")) + + logger.info("[SQLite] 数据库维护完成") + except Exception as e: + logger.warning(f"[SQLite] 数据库维护失败: {e}") + + +def get_sqlite_performance_config(): + """获取 SQLite 性能优化配置""" + return { + "journal_mode": "WAL", # 提高并发性能 + "synchronous": "NORMAL", # 平衡性能和安全性 + "busy_timeout": 60000, # 60秒超时 + "foreign_keys": "ON", # 启用外键约束 + "cache_size": -10000, # 10MB 缓存 + "temp_store": "MEMORY", # 临时存储使用内存 + "mmap_size": 268435456, # 256MB 内存映射 + } + + +# MySQL兼容的字段类型辅助函数 +def get_string_field(max_length=255, **kwargs): + """ + 根据数据库类型返回合适的字符串字段 + MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text + """ + from src.config.config import global_config + + if global_config.database.database_type == "mysql": + return String(max_length, **kwargs) + else: + return Text(**kwargs) + + +class ChatStreams(Base): + """聊天流模型""" + + __tablename__ = "chat_streams" + + id = Column(Integer, primary_key=True, autoincrement=True) + stream_id = Column(get_string_field(64), nullable=False, unique=True, index=True) + create_time = Column(Float, nullable=False) + group_platform = Column(Text, nullable=True) + group_id = Column(get_string_field(100), nullable=True, index=True) + group_name = Column(Text, nullable=True) + last_active_time = Column(Float, nullable=False) + platform = Column(Text, nullable=False) + user_platform = Column(Text, nullable=False) + user_id = Column(get_string_field(100), nullable=False, index=True) + user_nickname = Column(Text, nullable=False) + user_cardname = Column(Text, nullable=True) + energy_value = Column(Float, nullable=True, default=5.0) + sleep_pressure = Column(Float, nullable=True, default=0.0) + focus_energy = Column(Float, nullable=True, default=0.5) + # 动态兴趣度系统字段 + base_interest_energy = Column(Float, nullable=True, default=0.5) + message_interest_total = Column(Float, nullable=True, default=0.0) + message_count = Column(Integer, nullable=True, default=0) + action_count = Column(Integer, nullable=True, default=0) + reply_count = 
Column(Integer, nullable=True, default=0) + last_interaction_time = Column(Float, nullable=True, default=None) + consecutive_no_reply = Column(Integer, nullable=True, default=0) + # 消息打断系统字段 + interruption_count = Column(Integer, nullable=True, default=0) + + __table_args__ = ( + Index("idx_chatstreams_stream_id", "stream_id"), + Index("idx_chatstreams_user_id", "user_id"), + Index("idx_chatstreams_group_id", "group_id"), + ) + + +class LLMUsage(Base): + """LLM使用记录模型""" + + __tablename__ = "llm_usage" + + id = Column(Integer, primary_key=True, autoincrement=True) + model_name = Column(get_string_field(100), nullable=False, index=True) + model_assign_name = Column(get_string_field(100), index=True) # 添加索引 + model_api_provider = Column(get_string_field(100), index=True) # 添加索引 + user_id = Column(get_string_field(50), nullable=False, index=True) + request_type = Column(get_string_field(50), nullable=False, index=True) + endpoint = Column(Text, nullable=False) + prompt_tokens = Column(Integer, nullable=False) + completion_tokens = Column(Integer, nullable=False) + time_cost = Column(Float, nullable=True) + total_tokens = Column(Integer, nullable=False) + cost = Column(Float, nullable=False) + status = Column(Text, nullable=False) + timestamp = Column(DateTime, nullable=False, index=True, default=datetime.datetime.now) + + __table_args__ = ( + Index("idx_llmusage_model_name", "model_name"), + Index("idx_llmusage_model_assign_name", "model_assign_name"), + Index("idx_llmusage_model_api_provider", "model_api_provider"), + Index("idx_llmusage_time_cost", "time_cost"), + Index("idx_llmusage_user_id", "user_id"), + Index("idx_llmusage_request_type", "request_type"), + Index("idx_llmusage_timestamp", "timestamp"), + ) + + +class Emoji(Base): + """表情包模型""" + + __tablename__ = "emoji" + + id = Column(Integer, primary_key=True, autoincrement=True) + full_path = Column(get_string_field(500), nullable=False, unique=True, index=True) + format = Column(Text, nullable=False) + 
emoji_hash = Column(get_string_field(64), nullable=False, index=True) + description = Column(Text, nullable=False) + query_count = Column(Integer, nullable=False, default=0) + is_registered = Column(Boolean, nullable=False, default=False) + is_banned = Column(Boolean, nullable=False, default=False) + emotion = Column(Text, nullable=True) + record_time = Column(Float, nullable=False) + register_time = Column(Float, nullable=True) + usage_count = Column(Integer, nullable=False, default=0) + last_used_time = Column(Float, nullable=True) + + __table_args__ = ( + Index("idx_emoji_full_path", "full_path"), + Index("idx_emoji_hash", "emoji_hash"), + ) + + +class Messages(Base): + """消息模型""" + + __tablename__ = "messages" + + id = Column(Integer, primary_key=True, autoincrement=True) + message_id = Column(get_string_field(100), nullable=False, index=True) + time = Column(Float, nullable=False) + chat_id = Column(get_string_field(64), nullable=False, index=True) + reply_to = Column(Text, nullable=True) + interest_value = Column(Float, nullable=True) + key_words = Column(Text, nullable=True) + key_words_lite = Column(Text, nullable=True) + is_mentioned = Column(Boolean, nullable=True) + + # 从 chat_info 扁平化而来的字段 + chat_info_stream_id = Column(Text, nullable=False) + chat_info_platform = Column(Text, nullable=False) + chat_info_user_platform = Column(Text, nullable=False) + chat_info_user_id = Column(Text, nullable=False) + chat_info_user_nickname = Column(Text, nullable=False) + chat_info_user_cardname = Column(Text, nullable=True) + chat_info_group_platform = Column(Text, nullable=True) + chat_info_group_id = Column(Text, nullable=True) + chat_info_group_name = Column(Text, nullable=True) + chat_info_create_time = Column(Float, nullable=False) + chat_info_last_active_time = Column(Float, nullable=False) + + # 从顶层 user_info 扁平化而来的字段 + user_platform = Column(Text, nullable=True) + user_id = Column(get_string_field(100), nullable=True, index=True) + user_nickname = Column(Text, 
nullable=True) + user_cardname = Column(Text, nullable=True) + + processed_plain_text = Column(Text, nullable=True) + display_message = Column(Text, nullable=True) + memorized_times = Column(Integer, nullable=False, default=0) + priority_mode = Column(Text, nullable=True) + priority_info = Column(Text, nullable=True) + additional_config = Column(Text, nullable=True) + is_emoji = Column(Boolean, nullable=False, default=False) + is_picid = Column(Boolean, nullable=False, default=False) + is_command = Column(Boolean, nullable=False, default=False) + is_notify = Column(Boolean, nullable=False, default=False) + + # 兴趣度系统字段 + actions = Column(Text, nullable=True) # JSON格式存储动作列表 + should_reply = Column(Boolean, nullable=True, default=False) + should_act = Column(Boolean, nullable=True, default=False) + + __table_args__ = ( + Index("idx_messages_message_id", "message_id"), + Index("idx_messages_chat_id", "chat_id"), + Index("idx_messages_time", "time"), + Index("idx_messages_user_id", "user_id"), + Index("idx_messages_should_reply", "should_reply"), + Index("idx_messages_should_act", "should_act"), + ) + + +class ActionRecords(Base): + """动作记录模型""" + + __tablename__ = "action_records" + + id = Column(Integer, primary_key=True, autoincrement=True) + action_id = Column(get_string_field(100), nullable=False, index=True) + time = Column(Float, nullable=False) + action_name = Column(Text, nullable=False) + action_data = Column(Text, nullable=False) + action_done = Column(Boolean, nullable=False, default=False) + action_build_into_prompt = Column(Boolean, nullable=False, default=False) + action_prompt_display = Column(Text, nullable=False) + chat_id = Column(get_string_field(64), nullable=False, index=True) + chat_info_stream_id = Column(Text, nullable=False) + chat_info_platform = Column(Text, nullable=False) + + __table_args__ = ( + Index("idx_actionrecords_action_id", "action_id"), + Index("idx_actionrecords_chat_id", "chat_id"), + Index("idx_actionrecords_time", "time"), + ) 


class Images(Base):
    """Image information model (table ``images``)."""

    __tablename__ = "images"

    id = Column(Integer, primary_key=True, autoincrement=True)
    image_id = Column(Text, nullable=False, default="")
    # NOTE(review): named "emoji_hash" even though this is the images table;
    # presumably the content hash of the image file -- confirm against writers.
    emoji_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=True)
    path = Column(get_string_field(500), nullable=False, unique=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    type = Column(Text, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)

    # NOTE(review): ``emoji_hash`` already has ``index=True`` and ``path`` already
    # has ``unique=True``; the explicit Index objects below look like they add a
    # second index on the same columns -- confirm whether both are wanted.
    __table_args__ = (
        Index("idx_images_emoji_hash", "emoji_hash"),
        Index("idx_images_path", "path"),
    )


class ImageDescriptions(Base):
    """Image description information model (table ``image_descriptions``)."""

    __tablename__ = "image_descriptions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    type = Column(Text, nullable=False)
    image_description_hash = Column(get_string_field(64), nullable=False, index=True)
    description = Column(Text, nullable=False)
    timestamp = Column(Float, nullable=False)

    # NOTE(review): the column is also declared with ``index=True`` above.
    __table_args__ = (Index("idx_imagedesc_hash", "image_description_hash"),)


class Videos(Base):
    """Video information model (table ``videos``)."""

    __tablename__ = "videos"

    id = Column(Integer, primary_key=True, autoincrement=True)
    video_id = Column(Text, nullable=False, default="")
    video_hash = Column(get_string_field(64), nullable=False, index=True, unique=True)
    description = Column(Text, nullable=True)
    count = Column(Integer, nullable=False, default=1)
    timestamp = Column(Float, nullable=False)
    vlm_processed = Column(Boolean, nullable=False, default=False)

    # Video-specific attributes
    duration = Column(Float, nullable=True)  # video duration (seconds)
    frame_count = Column(Integer, nullable=True)  # total number of frames
    fps = Column(Float, nullable=True)  # frame rate
    resolution = Column(Text, nullable=True)  # resolution
    file_size = Column(Integer, nullable=True)  # file size (bytes)

    __table_args__ = (
        Index("idx_videos_video_hash", "video_hash"),
        Index("idx_videos_timestamp", "timestamp"),
    )


class OnlineTime(Base):
    """Online-duration record model (table ``online_time``)."""

    __tablename__ = "online_time"

    id = Column(Integer, primary_key=True, autoincrement=True)
    # The default must be a callable so it is evaluated per INSERT.  The previous
    # ``str(datetime.datetime.now)`` stringified the *method object* once at
    # import time, storing "<built-in method now of type object at 0x...>".
    timestamp = Column(Text, nullable=False, default=lambda: str(datetime.datetime.now()))
    duration = Column(Integer, nullable=False)
    start_timestamp = Column(DateTime, nullable=False, default=datetime.datetime.now)
    end_timestamp = Column(DateTime, nullable=False, index=True)

    __table_args__ = (Index("idx_onlinetime_end_timestamp", "end_timestamp"),)


class PersonInfo(Base):
    """Person information model (table ``person_info``)."""

    __tablename__ = "person_info"

    id = Column(Integer, primary_key=True, autoincrement=True)
    person_id = Column(get_string_field(100), nullable=False, unique=True, index=True)
    person_name = Column(Text, nullable=True)
    name_reason = Column(Text, nullable=True)
    platform = Column(Text, nullable=False)
    user_id = Column(get_string_field(50), nullable=False, index=True)
    nickname = Column(Text, nullable=True)
    impression = Column(Text, nullable=True)
    short_impression = Column(Text, nullable=True)
    points = Column(Text, nullable=True)
    forgotten_points = Column(Text, nullable=True)
    info_list = Column(Text, nullable=True)
    know_times = Column(Float, nullable=True)
    know_since = Column(Float, nullable=True)
    last_know = Column(Float, nullable=True)
    attitude = Column(Integer, nullable=True, default=50)

    __table_args__ = (
        Index("idx_personinfo_person_id", "person_id"),
        Index("idx_personinfo_user_id", "user_id"),
    )


class BotPersonalityInterests(Base):
    """Bot personality interest-tag model (table ``bot_personality_interests``)."""

    __tablename__ = "bot_personality_interests"

    id = Column(Integer, primary_key=True, autoincrement=True)
    personality_id = Column(get_string_field(100), nullable=False, index=True)
    personality_description = Column(Text, nullable=False)
    interest_tags = Column(Text, nullable=False)  # list of interest tags stored as JSON
    embedding_model = Column(get_string_field(100), nullable=False, default="text-embedding-ada-002")
    version = Column(Integer, nullable=False, default=1)
    last_updated = Column(DateTime, nullable=False, default=datetime.datetime.now, index=True)

    __table_args__ = (
        Index("idx_botpersonality_personality_id", "personality_id"),
        Index("idx_botpersonality_version", "version"),
        Index("idx_botpersonality_last_updated", "last_updated"),
    )


class Memory(Base):
    """Memory model (table ``memory``)."""

    __tablename__ = "memory"

    id = Column(Integer, primary_key=True, autoincrement=True)
    memory_id = Column(get_string_field(64), nullable=False, index=True)
    chat_id = Column(Text, nullable=True)
    memory_text = Column(Text, nullable=True)
    keywords = Column(Text, nullable=True)
    create_time = Column(Float, nullable=True)
    last_view_time = Column(Float, nullable=True)

    __table_args__ = (Index("idx_memory_memory_id", "memory_id"),)


class Expression(Base):
    """Expression style model (table ``expression``).

    Declared in the SQLAlchemy 2.0 ``Mapped[...]`` / ``mapped_column`` style.
    """

    __tablename__ = "expression"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    situation: Mapped[str] = mapped_column(Text, nullable=False)
    style: Mapped[str] = mapped_column(Text, nullable=False)
    count: Mapped[float] = mapped_column(Float, nullable=False)
    last_active_time: Mapped[float] = mapped_column(Float, nullable=False)
    chat_id: Mapped[str] = mapped_column(get_string_field(64), nullable=False, index=True)
    type: Mapped[str] = mapped_column(Text, nullable=False)
    create_date: Mapped[float | None] = mapped_column(Float, nullable=True)

    __table_args__ = (Index("idx_expression_chat_id", "chat_id"),)


class ThinkingLog(Base):
    """Thinking log model (table ``thinking_logs``)."""

    __tablename__ = "thinking_logs"

    id = Column(Integer, primary_key=True, autoincrement=True)
    chat_id = Column(get_string_field(64), nullable=False, index=True)
    trigger_text = Column(Text, nullable=True)
    response_text = Column(Text, nullable=True)
    trigger_info_json = Column(Text, nullable=True)
    response_info_json = Column(Text, nullable=True)
    timing_results_json = Column(Text, nullable=True)
    chat_history_json = Column(Text, nullable=True)
    chat_history_in_thinking_json = Column(Text, nullable=True)
    chat_history_after_response_json = Column(Text, nullable=True)
    heartflow_data_json = Column(Text, nullable=True)
    reasoning_data_json = Column(Text, nullable=True)
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    __table_args__ = (Index("idx_thinkinglog_chat_id", "chat_id"),)


class GraphNodes(Base):
    """Memory graph node model (table ``graph_nodes``)."""

    __tablename__ = "graph_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    concept = Column(get_string_field(255), nullable=False, unique=True, index=True)
    memory_items = Column(Text, nullable=False)
    hash = Column(Text, nullable=False)
    weight = Column(Float, nullable=False, default=1.0)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (Index("idx_graphnodes_concept", "concept"),)


class GraphEdges(Base):
    """Memory graph edge model (table ``graph_edges``)."""

    __tablename__ = "graph_edges"

    id = Column(Integer, primary_key=True, autoincrement=True)
    source = Column(get_string_field(255), nullable=False, index=True)
    target = Column(get_string_field(255), nullable=False, index=True)
    strength = Column(Integer, nullable=False)
    hash = Column(Text, nullable=False)
    created_time = Column(Float, nullable=False)
    last_modified = Column(Float, nullable=False)

    __table_args__ = (
        Index("idx_graphedges_source", "source"),
        Index("idx_graphedges_target", "target"),
    )


class Schedule(Base):
    """Schedule model (table ``schedule``)."""

    __tablename__ = "schedule"

    id = Column(Integer, primary_key=True, autoincrement=True)
    date = Column(get_string_field(10), nullable=False, unique=True, index=True)  # YYYY-MM-DD format
    schedule_data = Column(Text, nullable=False)  # schedule data in JSON format
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (Index("idx_schedule_date", "date"),)


class MaiZoneScheduleStatus(Base):
    """MaiZone schedule processing-status model (table ``maizone_schedule_status``)."""

    __tablename__ = "maizone_schedule_status"

    id = Column(Integer, primary_key=True, autoincrement=True)
    datetime_hour = Column(
        get_string_field(13), nullable=False, unique=True, index=True
    )  # "YYYY-MM-DD HH" format, hour precision
    activity = Column(Text, nullable=False)  # activity content for that hour
    is_processed = Column(Boolean, nullable=False, default=False)  # whether it has been processed
    processed_at = Column(DateTime, nullable=True)  # processing time
    story_content = Column(Text, nullable=True)  # generated post content
    send_success = Column(Boolean, nullable=False, default=False)  # whether sending succeeded
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)

    __table_args__ = (
        Index("idx_maizone_datetime_hour", "datetime_hour"),
        Index("idx_maizone_is_processed", "is_processed"),
    )


class BanUser(Base):
    """Banned user model (table ``ban_users``).

    Uses the SQLAlchemy 2.0 typed-annotation style so static type checkers can
    see the real field types and do not warn that ``Column[...]`` is not
    assignable when business code assigns to these attributes.
    """

    __tablename__ = "ban_users"

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    platform: Mapped[str] = mapped_column(Text, nullable=False)
    user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
    violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
    reason: Mapped[str] = mapped_column(Text, nullable=False)
    created_at: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False, default=datetime.datetime.now)

    # NOTE(review): ``platform`` is a plain Text column; MySQL normally needs a
    # key length to index TEXT -- confirm these indexes are creatable on MySQL.
    __table_args__ = (
        Index("idx_violation_num", "violation_num"),
        Index("idx_banuser_user_id", "user_id"),
        Index("idx_banuser_platform", "platform"),
        Index("idx_banuser_platform_user_id", "platform", "user_id"),
    )


class AntiInjectionStats(Base):
    """Anti-injection system statistics model (table ``anti_injection_stats``)."""

    __tablename__ = "anti_injection_stats"

    id = Column(Integer, primary_key=True, autoincrement=True)
    total_messages = Column(Integer, nullable=False, default=0)
    """Total number of processed messages."""

    detected_injections = Column(Integer, nullable=False, default=0)
    """Number of injection attacks detected."""

    blocked_messages = Column(Integer, nullable=False, default=0)
    """Number of blocked messages."""

    shielded_messages = Column(Integer, nullable=False, default=0)
    """Number of shielded messages."""

    processing_time_total = Column(Float, nullable=False, default=0.0)
    """Total processing time."""

    total_process_time = Column(Float, nullable=False, default=0.0)
    """Cumulative total processing time."""

    last_process_time = Column(Float, nullable=False, default=0.0)
    """Most recent processing time."""

    error_count = Column(Integer, nullable=False, default=0)
    """Error count."""

    start_time = Column(DateTime, nullable=False, default=datetime.datetime.now)
    """Time the statistics started."""

    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)
    """Record creation time."""

    updated_at = Column(DateTime, nullable=False, default=datetime.datetime.now, onupdate=datetime.datetime.now)
    """Record update time."""

    __table_args__ = (
        Index("idx_anti_injection_stats_created_at", "created_at"),
        Index("idx_anti_injection_stats_updated_at", "updated_at"),
    )


class CacheEntries(Base):
    """Tool cache entry model (table ``cache_entries``)."""

    __tablename__ = "cache_entries"

    id = Column(Integer, primary_key=True, autoincrement=True)
    cache_key = Column(get_string_field(500), nullable=False, unique=True, index=True)
    """Cache key, comprising the tool name, parameters and code hash."""

    cache_value = Column(Text, nullable=False)
    """Cached data, in JSON format."""

    expires_at = Column(Float, nullable=False, index=True)
    """Expiry timestamp."""

    tool_name = Column(get_string_field(100), nullable=False, index=True)
    """Tool name."""

    created_at = Column(Float, nullable=False, default=lambda: time.time())
    """Creation timestamp."""

    last_accessed = Column(Float, nullable=False, default=lambda: time.time())
    """Last-access timestamp."""

    access_count = Column(Integer, nullable=False, default=0)
    """Access count."""

    __table_args__ = (
        Index("idx_cache_entries_key", "cache_key"),
        Index("idx_cache_entries_expires_at", "expires_at"),
        Index("idx_cache_entries_tool_name", "tool_name"),
        Index("idx_cache_entries_created_at", "created_at"),
    )


class MonthlyPlan(Base):
    """Monthly plan model (table ``monthly_plans``)."""

    __tablename__ = "monthly_plans"

    id = Column(Integer, primary_key=True, autoincrement=True)
    plan_text = Column(Text, nullable=False)
    target_month = Column(String(7), nullable=False, index=True)  # "YYYY-MM"
    status = Column(
        get_string_field(20), nullable=False, default="active", index=True
    )  # 'active', 'completed', 'archived'
    usage_count = Column(Integer, nullable=False, default=0)
    last_used_date = Column(String(10), nullable=True, index=True)  # "YYYY-MM-DD" format
    created_at = Column(DateTime, nullable=False, default=datetime.datetime.now)

    # is_deleted is kept for compatibility with existing data but is deprecated.
    is_deleted = Column(Boolean, nullable=False, default=False)

    __table_args__ = (
        Index("idx_monthlyplan_target_month_status", "target_month", "status"),
        Index("idx_monthlyplan_last_used_date", "last_used_date"),
        Index("idx_monthlyplan_usage_count", "usage_count"),
        # Old index kept for compatibility.
        Index("idx_monthlyplan_target_month_is_deleted", "target_month", "is_deleted"),
    )


# Database engine and session management: module-level singletons that are
# populated by initialize_database() on first use.
_engine = None
_SessionLocal = None


def get_database_url():
    """Build the SQLAlchemy async connection URL from ``global_config.database``.

    Returns:
        A ``mysql+aiomysql://`` URL when ``database_type`` is ``"mysql"`` (using
        the Unix socket when ``mysql_unix_socket`` is set, TCP otherwise), or a
        ``sqlite+aiosqlite:///`` URL for any other database type.

    Side effects:
        For SQLite, creates the directory that will hold the database file.
    """
    from src.config.config import global_config

    config = global_config.database

    if config.database_type == "mysql":
        # URL-encode the user name and password so special characters survive.
        from urllib.parse import quote_plus

        encoded_user = quote_plus(config.mysql_user)
        encoded_password = quote_plus(config.mysql_password)

        # Prefer a Unix socket connection when one is configured.
        if config.mysql_unix_socket:
            encoded_socket = quote_plus(config.mysql_unix_socket)
            return (
                f"mysql+aiomysql://{encoded_user}:{encoded_password}"
                f"@/{config.mysql_database}"
                f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
            )
        else:
            # Standard TCP connection.
            return (
                f"mysql+aiomysql://{encoded_user}:{encoded_password}"
                f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
                f"?charset={config.mysql_charset}"
            )
    else:  # SQLite
        # A relative path is resolved against the project root
        # (three directory levels above this file).
        if not os.path.isabs(config.sqlite_path):
            ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", ".."))
            db_path = os.path.join(ROOT_PATH, config.sqlite_path)
        else:
            db_path = config.sqlite_path

        # Make sure the database directory exists.
        os.makedirs(os.path.dirname(db_path), exist_ok=True)

        return f"sqlite+aiosqlite:///{db_path}"


async def initialize_database():
    """Create the async engine and session factory once and return them.

    Subsequent calls return the cached ``(_engine, _SessionLocal)`` pair.

    Returns:
        Tuple ``(engine, session_factory)``.
    """
    global _engine, _SessionLocal

    if _engine is not None:
        return _engine, _SessionLocal

    database_url = get_database_url()
    from src.config.config import global_config

    config = global_config.database

    # Engine parameters shared by every backend.
    engine_kwargs: dict[str, Any] = {
        "echo": False,  # SQL logging is off in production
        "future": True,
    }

    if config.database_type == "mysql":
        # MySQL connection-pool settings; the async engine uses its default pool class.
        engine_kwargs.update(
            {
                "pool_size": config.connection_pool_size,
                "max_overflow": config.connection_pool_size * 2,
                "pool_timeout": config.connection_timeout,
                "pool_recycle": 3600,  # recycle connections after one hour
                "pool_pre_ping": True,  # ping before handing out a connection
                "connect_args": {
                    "autocommit": config.mysql_autocommit,
                    "charset": config.mysql_charset,
                    "connect_timeout": config.connection_timeout,
                },
            }
        )
    else:
        # SQLite settings; aiosqlite does not take the pool parameters above.
        engine_kwargs.update(
            {
                "connect_args": {
                    "check_same_thread": False,
                    "timeout": 60,  # longer lock-wait timeout
                },
            }
        )

    # The globals are assigned before the migration runs.
    # NOTE(review): presumably so the migration can obtain this engine through
    # get_engine() without re-entering initialisation -- confirm in db_migration.
    _engine = create_async_engine(database_url, **engine_kwargs)
    _SessionLocal = async_sessionmaker(bind=_engine, class_=AsyncSession, expire_on_commit=False)

    # The migration routine handles table creation and adding columns.
    from src.common.database.db_migration import check_and_migrate_database

    await check_and_migrate_database()

    # For SQLite, enable WAL mode to improve concurrency.
    if config.database_type == "sqlite":
        await enable_sqlite_wal_mode(_engine)

    logger.info(f"SQLAlchemy异步数据库初始化成功: {config.database_type}")
    return _engine, _SessionLocal


@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
    """Async context manager yielding a database session.

    The session is obtained through the connection pool manager's
    ``get_session`` using the session factory from ``initialize_database``.

    Yields:
        An ``AsyncSession``.

    Raises:
        RuntimeError: if the session factory is missing after initialisation.
        Exception: any error raised by ``initialize_database`` is logged and
            re-raised; this function never yields ``None``.
    """
    try:
        _, SessionLocal = await initialize_database()
        if not SessionLocal:
            raise RuntimeError("数据库会话工厂 (_SessionLocal) 未初始化。")
    except Exception as e:
        logger.error(f"数据库初始化失败,无法创建会话: {e}")
        raise

    # Obtain the session through the connection pool manager.
    pool_manager = get_connection_pool_manager()

    async with pool_manager.get_session(SessionLocal) as session:
        # For SQLite, set the per-connection PRAGMAs at the start of the session.
        from src.config.config import global_config

        if global_config.database.database_type == "sqlite":
            try:
                await session.execute(text("PRAGMA busy_timeout = 60000"))
                await session.execute(text("PRAGMA foreign_keys = ON"))
            except Exception as e:
                # Best effort: a PRAGMA failure must not prevent using the session.
                logger.debug(f"设置 SQLite PRAGMA 时出错(可能是复用连接): {e}")

        yield session


async def get_engine():
    """Return the async database engine, initialising it on first use."""
    engine, _ = await initialize_database()
    return engine


class PermissionNodes(Base):
    """Permission node model (table ``permission_nodes``)."""

    __tablename__ = "permission_nodes"

    id = Column(Integer, primary_key=True, autoincrement=True)
    node_name = Column(get_string_field(255), nullable=False, unique=True, index=True)  # permission node name
    description = Column(Text, nullable=False)  # permission description
    plugin_name = Column(get_string_field(100), nullable=False, index=True)  # owning plugin
    default_granted = Column(Boolean, default=False, nullable=False)  # whether granted by default
    # NOTE(review): utcnow() returns a naive UTC datetime and is deprecated since
    # Python 3.12, while most models in this file default to local-time now().
    # Left unchanged so stored values keep their current meaning.
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # creation time

    __table_args__ = (
        Index("idx_permission_plugin", "plugin_name"),
        Index("idx_permission_node", "node_name"),
    )


class UserPermissions(Base):
    """User permission model (table ``user_permissions``)."""

    __tablename__ = "user_permissions"

    id = Column(Integer, primary_key=True, autoincrement=True)
    platform = Column(get_string_field(50), nullable=False, index=True)  # platform type
    user_id = Column(get_string_field(100), nullable=False, index=True)  # user ID
    permission_node = Column(get_string_field(255), nullable=False, index=True)  # permission node name
    granted = Column(Boolean, default=True, nullable=False)  # whether granted
    granted_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # grant time (naive UTC)
    granted_by = Column(get_string_field(100), nullable=True)  # grantor information

    __table_args__ = (
        Index("idx_user_platform_id", "platform", "user_id"),
        Index("idx_user_permission", "platform", "user_id", "permission_node"),
        Index("idx_permission_granted", "permission_node", "granted"),
    )


class UserRelationships(Base):
    """User relationship model: stores relationship data between a user and the bot."""

    __tablename__ = "user_relationships"

    id = Column(Integer, primary_key=True, autoincrement=True)
    user_id = Column(get_string_field(100), nullable=False, unique=True, index=True)  # user ID
    user_name = Column(get_string_field(100), nullable=True)  # user name
    relationship_text = Column(Text, nullable=True)  # relationship impression description
    relationship_score = Column(Float, nullable=False, default=0.3)  # relationship score (0-1)
    last_updated = Column(Float, nullable=False, default=time.time)  # last update time (epoch seconds)
    created_at = Column(DateTime, default=datetime.datetime.utcnow, nullable=False)  # creation time (naive UTC)

    __table_args__ = (
        Index("idx_user_relationship_id", "user_id"),
        Index("idx_relationship_score", "relationship_score"),
        Index("idx_relationship_updated", "last_updated"),
    )