diff --git a/pyproject.toml b/pyproject.toml index d5d481934..cbf1dd913 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,9 @@ dependencies = [ "rjieba>=0.1.13", "fastmcp>=2.13.0", "mofox-wire", - "jinja2>=3.1.0" + "jinja2>=3.1.0", + "psycopg2-binary", + "PyMySQL" ] [[tool.uv.index]] diff --git a/requirements.txt b/requirements.txt index 2cc36ddf6..d508e6435 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,10 @@ aiosqlite aiofiles aiomysql +asyncpg +psycopg[binary] +psycopg2-binary +PyMySQL APScheduler aiohttp aiohttp-cors diff --git a/scripts/check_expression_database.py b/scripts/check_expression_database.py deleted file mode 100644 index d1e8a47b6..000000000 --- a/scripts/check_expression_database.py +++ /dev/null @@ -1,117 +0,0 @@ -""" -检查表达方式数据库状态的诊断脚本 -""" -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到路径 -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from sqlalchemy import func, select - -from src.common.database.compatibility import get_db_session -from src.common.database.core.models import Expression - - -async def check_database(): - """检查表达方式数据库状态""" - - print("=" * 60) - print("表达方式数据库诊断报告") - print("=" * 60) - - async with get_db_session() as session: - # 1. 统计总数 - total_count = await session.execute(select(func.count()).select_from(Expression)) - total = total_count.scalar() - print(f"\n📊 总表达方式数量: {total}") - - if total == 0: - print("\n⚠️ 数据库为空!") - print("\n可能的原因:") - print("1. 还没有进行过表达学习") - print("2. 配置中禁用了表达学习") - print("3. 学习过程中发生了错误") - print("\n建议:") - print("- 检查 bot_config.toml 中的 [expression] 配置") - print("- 查看日志中是否有表达学习相关的错误") - print("- 确认聊天流的 learn_expression 配置为 true") - return - - # 2. 按 chat_id 统计 - print("\n📝 按聊天流统计:") - chat_counts = await session.execute( - select(Expression.chat_id, func.count()) - .group_by(Expression.chat_id) - ) - for chat_id, count in chat_counts: - print(f" - {chat_id}: {count} 个表达方式") - - # 3. 按 type 统计 - print("\n📝 按类型统计:") - type_counts = await session.execute( - select(Expression.type, func.count()) - .group_by(Expression.type) - ) - for expr_type, count in type_counts: - print(f" - {expr_type}: {count} 个") - - # 4. 检查 situation 和 style 字段是否有空值 - print("\n🔍 字段完整性检查:") - null_situation = await session.execute( - select(func.count()) - .select_from(Expression) - .where(Expression.situation is None) - ) - null_style = await session.execute( - select(func.count()) - .select_from(Expression) - .where(Expression.style is None) - ) - - null_sit_count = null_situation.scalar() - null_sty_count = null_style.scalar() - - print(f" - situation 为空: {null_sit_count} 个") - print(f" - style 为空: {null_sty_count} 个") - - if null_sit_count > 0 or null_sty_count > 0: - print(" ⚠️ 发现空值!这会导致匹配失败") - - # 5. 显示一些样例数据 - print("\n📋 样例数据 (前10条):") - samples = await session.execute( - select(Expression) - .limit(10) - ) - - for i, expr in enumerate(samples.scalars(), 1): - print(f"\n [{i}] Chat: {expr.chat_id}") - print(f" Type: {expr.type}") - print(f" Situation: {expr.situation}") - print(f" Style: {expr.style}") - print(f" Count: {expr.count}") - - # 6. 
检查 style 字段的唯一值 - print("\n📋 Style 字段样例 (前20个):") - unique_styles = await session.execute( - select(Expression.style) - .distinct() - .limit(20) - ) - - styles = list(unique_styles.scalars()) - for style in styles: - print(f" - {style}") - - print(f"\n (共 {len(styles)} 个不同的 style)") - - print("\n" + "=" * 60) - print("诊断完成") - print("=" * 60) - - -if __name__ == "__main__": - asyncio.run(check_database()) diff --git a/scripts/check_style_field.py b/scripts/check_style_field.py deleted file mode 100644 index 980f3a07a..000000000 --- a/scripts/check_style_field.py +++ /dev/null @@ -1,65 +0,0 @@ -""" -检查数据库中 style 字段的内容特征 -""" -import asyncio -import sys -from pathlib import Path - -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from sqlalchemy import select - -from src.common.database.compatibility import get_db_session -from src.common.database.core.models import Expression - - -async def analyze_style_fields(): - """分析 style 字段的内容""" - - print("=" * 60) - print("Style 字段内容分析") - print("=" * 60) - - async with get_db_session() as session: - # 获取所有表达方式 - result = await session.execute(select(Expression).limit(30)) - expressions = result.scalars().all() - - print(f"\n总共检查 {len(expressions)} 条记录\n") - - # 按类型分类 - style_examples = [ - { - "situation": expr.situation, - "style": expr.style, - "length": len(expr.style) if expr.style else 0 - } - for expr in expressions if expr.type == "style" - ] - - print("📋 Style 类型样例 (前15条):") - print("="*60) - for i, ex in enumerate(style_examples[:15], 1): - print(f"\n[{i}]") - print(f" Situation: {ex['situation']}") - print(f" Style: {ex['style']}") - print(f" 长度: {ex['length']} 字符") - - # 判断是具体表达还是风格描述 - if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]): - style_type = "✓ 风格描述" - elif ex["length"] <= 10: - style_type = "? 
可能是具体表达(较短)" - else: - style_type = "✗ 具体表达内容" - - print(f" 类型判断: {style_type}") - - print("\n" + "="*60) - print("分析完成") - print("="*60) - - -if __name__ == "__main__": - asyncio.run(analyze_style_fields()) diff --git a/scripts/cleanup_models.py b/scripts/cleanup_models.py deleted file mode 100644 index e02e8ce6b..000000000 --- a/scripts/cleanup_models.py +++ /dev/null @@ -1,49 +0,0 @@ -#!/usr/bin/env python3 -"""清理 core/models.py,只保留模型定义""" - -import os - -# 文件路径 -models_file = os.path.join( - os.path.dirname(os.path.dirname(__file__)), - "src", - "common", - "database", - "core", - "models.py" -) - -print(f"正在清理文件: {models_file}") - -# 读取文件 -with open(models_file, encoding="utf-8") as f: - lines = f.readlines() - -# 找到最后一个模型类的结束位置(MonthlyPlan的 __table_args__ 结束) -# 我们要保留到第593行(包含) -keep_lines = [] -found_end = False - -for i, line in enumerate(lines, 1): - keep_lines.append(line) - - # 检查是否到达 MonthlyPlan 的 __table_args__ 结束 - if i > 580 and line.strip() == ")": - # 再检查前一行是否有 Index 相关内容 - if "idx_monthlyplan" in "".join(lines[max(0, i-5):i]): - print(f"找到模型定义结束位置: 第 {i} 行") - found_end = True - break - -if not found_end: - print("❌ 未找到模型定义结束标记") - exit(1) - -# 写回文件 -with open(models_file, "w", encoding="utf-8") as f: - f.writelines(keep_lines) - -print("✅ 文件清理完成") -print(f"保留行数: {len(keep_lines)}") -print(f"原始行数: {len(lines)}") -print(f"删除行数: {len(lines) - len(keep_lines)}") diff --git a/scripts/debug_mcp_tools.py b/scripts/debug_mcp_tools.py deleted file mode 100644 index 27ba5b2e5..000000000 --- a/scripts/debug_mcp_tools.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -调试 MCP 工具列表获取 - -直接测试 MCP 客户端是否能获取工具 -""" - -import asyncio - -from fastmcp.client import Client, StreamableHttpTransport - - -async def test_direct_connection(): - """直接连接 MCP 服务器并获取工具列表""" - print("=" * 60) - print("直接测试 MCP 服务器连接") - print("=" * 60) - - url = "http://localhost:8000/mcp" - print(f"\n连接到: {url}") - - try: - # 创建传输层 - transport = StreamableHttpTransport(url) - print("✓ 传输层创建成功") - - # 创建客户端 - async with Client(transport) as client: - print("✓ 客户端连接成功") - - # 获取工具列表 - print("\n正在获取工具列表...") - tools_result = await client.list_tools() - - print(f"\n获取结果类型: {type(tools_result)}") - print(f"结果内容: {tools_result}") - - # 检查是否有 tools 属性 - if hasattr(tools_result, "tools"): - tools = tools_result.tools - print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具") - - for i, tool in enumerate(tools, 1): - print(f"\n工具 {i}:") - print(f" 名称: {tool.name}") - print(f" 描述: {tool.description}") - if hasattr(tool, "inputSchema"): - print(f" 参数 Schema: {tool.inputSchema}") - else: - print("\n✗ 结果中没有 tools 属性") - print(f"可用属性: {dir(tools_result)}") - - except Exception as e: - print(f"\n✗ 连接失败: {e}") - import traceback - traceback.print_exc() - - -if __name__ == "__main__": - asyncio.run(test_direct_connection()) diff --git a/scripts/debug_style_learner.py b/scripts/debug_style_learner.py deleted file mode 100644 index 1c0937ece..000000000 --- a/scripts/debug_style_learner.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -检查 StyleLearner 模型状态的诊断脚本 -""" -import sys -from pathlib import Path - -# 添加项目根目录到路径 -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from src.chat.express.style_learner import style_learner_manager -from src.common.logger import get_logger - -logger = get_logger("debug_style_learner") - - -def check_style_learner_status(chat_id: str): - """检查指定 chat_id 的 StyleLearner 状态""" - - print("=" * 60) - print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}") - print("=" * 60) - - # 获取 learner - learner = 
style_learner_manager.get_learner(chat_id) - - # 1. 基本信息 - print("\n📊 基本信息:") - print(f" Chat ID: {learner.chat_id}") - print(f" 风格数量: {len(learner.style_to_id)}") - print(f" 下一个ID: {learner.next_style_id}") - print(f" 最大风格数: {learner.max_styles}") - - # 2. 学习统计 - print("\n📈 学习统计:") - print(f" 总样本数: {learner.learning_stats['total_samples']}") - print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}") - - # 3. 风格列表(前20个) - print("\n📋 已学习的风格 (前20个):") - all_styles = learner.get_all_styles() - if not all_styles: - print(" ⚠️ 没有任何风格!模型尚未训练") - else: - for i, style in enumerate(all_styles[:20], 1): - style_id = learner.style_to_id.get(style) - situation = learner.id_to_situation.get(style_id, "N/A") - print(f" [{i}] {style}") - print(f" (ID: {style_id}, Situation: {situation})") - - # 4. 测试预测 - print("\n🔮 测试预测功能:") - if not all_styles: - print(" ⚠️ 无法测试,模型没有训练数据") - else: - test_situations = [ - "表示惊讶", - "讨论游戏", - "表达赞同" - ] - - for test_sit in test_situations: - print(f"\n 测试输入: '{test_sit}'") - best_style, scores = learner.predict_style(test_sit, top_k=3) - - if best_style: - print(f" ✓ 最佳匹配: {best_style}") - print(" Top 3:") - for style, score in list(scores.items())[:3]: - print(f" - {style}: {score:.4f}") - else: - print(" ✗ 预测失败") - - print("\n" + "=" * 60) - print("诊断完成") - print("=" * 60) - - -if __name__ == "__main__": - # 从诊断报告中看到的 chat_id - test_chat_ids = [ - "52fb94af9f500a01e023ea780e43606e", # 有78个表达方式 - "46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式 - ] - - for chat_id in test_chat_ids: - check_style_learner_status(chat_id) - print("\n") diff --git a/scripts/deduplicate_memories.py b/scripts/deduplicate_memories.py deleted file mode 100644 index e44970d12..000000000 --- a/scripts/deduplicate_memories.py +++ /dev/null @@ -1,403 +0,0 @@ -""" -记忆去重工具 - -功能: -1. 扫描所有标记为"相似"关系的记忆边 -2. 对相似记忆进行去重(保留重要性高的,删除另一个) -3. 支持干运行模式(预览不执行) -4. 提供详细的去重报告 - -使用方法: - # 预览模式(不实际删除) - python scripts/deduplicate_memories.py --dry-run - - # 执行去重 - python scripts/deduplicate_memories.py - - # 指定相似度阈值 - python scripts/deduplicate_memories.py --threshold 0.9 - - # 指定数据目录 - python scripts/deduplicate_memories.py --data-dir data/memory_graph -""" -import argparse -import asyncio -import sys -from datetime import datetime -from pathlib import Path - -import numpy as np - -sys.path.insert(0, str(Path(__file__).parent.parent)) - -from src.common.logger import get_logger -from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager - -logger = get_logger(__name__) - - -class MemoryDeduplicator: - """记忆去重器""" - - def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85): - self.data_dir = data_dir - self.dry_run = dry_run - self.threshold = threshold - self.manager = None - - # 统计信息 - self.stats = { - "total_memories": 0, - "similar_pairs": 0, - "duplicates_found": 0, - "duplicates_removed": 0, - "errors": 0, - } - - async def initialize(self): - """初始化记忆管理器""" - logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...") - self.manager = await initialize_memory_manager(data_dir=self.data_dir) - if not self.manager: - raise RuntimeError("记忆管理器初始化失败") - - self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories()) - logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆") - - async def find_similar_pairs(self) -> list[tuple[str, str, float]]: - """ - 查找所有相似的记忆对(通过向量相似度计算) - - Returns: - [(memory_id_1, memory_id_2, similarity), ...] 
- """ - logger.info("正在扫描相似记忆对...") - similar_pairs = [] - seen_pairs = set() # 避免重复 - - # 获取所有记忆 - all_memories = self.manager.graph_store.get_all_memories() - total_memories = len(all_memories) - - logger.info(f"开始计算 {total_memories} 条记忆的相似度...") - - # 两两比较记忆的相似度 - for i, memory_i in enumerate(all_memories): - # 每处理10条记忆让出控制权 - if i % 10 == 0: - await asyncio.sleep(0) - if i > 0: - logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)") - - # 获取记忆i的向量(从主题节点) - vector_i = None - for node in memory_i.nodes: - if node.embedding is not None: - vector_i = node.embedding - break - - if vector_i is None: - continue - - # 与后续记忆比较 - for j in range(i + 1, total_memories): - memory_j = all_memories[j] - - # 获取记忆j的向量 - vector_j = None - for node in memory_j.nodes: - if node.embedding is not None: - vector_j = node.embedding - break - - if vector_j is None: - continue - - # 计算余弦相似度 - similarity = self._cosine_similarity(vector_i, vector_j) - - # 只保存满足阈值的相似对 - if similarity >= self.threshold: - pair_key = tuple(sorted([memory_i.id, memory_j.id])) - if pair_key not in seen_pairs: - seen_pairs.add(pair_key) - similar_pairs.append((memory_i.id, memory_j.id, similarity)) - - self.stats["similar_pairs"] = len(similar_pairs) - logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold})") - - return similar_pairs - - def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float: - """计算余弦相似度""" - try: - vec1_norm = np.linalg.norm(vec1) - vec2_norm = np.linalg.norm(vec2) - - if vec1_norm == 0 or vec2_norm == 0: - return 0.0 - - similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm) - return float(similarity) - except Exception as e: - logger.error(f"计算余弦相似度失败: {e}") - return 0.0 - - def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]: - """ - 决定保留哪个记忆,删除哪个 - - 优先级: - 1. 重要性更高的 - 2. 激活度更高的 - 3. 
创建时间更早的 - - Returns: - (keep_id, remove_id) - """ - mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1) - mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2) - - if not mem1 or not mem2: - logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}") - return None, None - - # 比较重要性 - if mem1.importance > mem2.importance: - return mem_id_1, mem_id_2 - elif mem1.importance < mem2.importance: - return mem_id_2, mem_id_1 - - # 重要性相同,比较激活度 - if mem1.activation > mem2.activation: - return mem_id_1, mem_id_2 - elif mem1.activation < mem2.activation: - return mem_id_2, mem_id_1 - - # 激活度也相同,保留更早创建的 - if mem1.created_at < mem2.created_at: - return mem_id_1, mem_id_2 - else: - return mem_id_2, mem_id_1 - - async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool: - """ - 去重一对相似记忆 - - Returns: - 是否成功去重 - """ - keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2) - - if not keep_id or not remove_id: - self.stats["errors"] += 1 - return False - - keep_mem = self.manager.graph_store.get_memory_by_id(keep_id) - remove_mem = self.manager.graph_store.get_memory_by_id(remove_id) - - logger.info("") - logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):") - logger.info(f" 保留: {keep_id}") - logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}") - logger.info(f" - 重要性: {keep_mem.importance:.2f}") - logger.info(f" - 激活度: {keep_mem.activation:.2f}") - logger.info(f" - 创建时间: {keep_mem.created_at}") - logger.info(f" 删除: {remove_id}") - logger.info(f" - 主题: {remove_mem.metadata.get('topic', 'N/A')}") - logger.info(f" - 重要性: {remove_mem.importance:.2f}") - logger.info(f" - 激活度: {remove_mem.activation:.2f}") - logger.info(f" - 创建时间: {remove_mem.created_at}") - - if self.dry_run: - logger.info(" [预览模式] 不执行实际删除") - self.stats["duplicates_found"] += 1 - return True - - try: - # 增强保留记忆的属性 - keep_mem.importance = min(1.0, keep_mem.importance + 0.05) - keep_mem.activation = min(1.0, keep_mem.activation + 0.05) - - # 累加访问次数 - if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"): - keep_mem.access_count += remove_mem.access_count - - # 删除相似记忆 - await self.manager.delete_memory(remove_id) - - self.stats["duplicates_removed"] += 1 - logger.info(" ✅ 删除成功") - - # 让出控制权 - await asyncio.sleep(0) - - return True - - except Exception as e: - logger.error(f" ❌ 删除失败: {e}") - self.stats["errors"] += 1 - return False - - async def run(self): - """执行去重""" - start_time = datetime.now() - - print("="*70) - print("记忆去重工具") - print("="*70) - print(f"数据目录: {self.data_dir}") - print(f"相似度阈值: {self.threshold}") - print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}") - print("="*70) - print() - - # 初始化 - await self.initialize() - - # 查找相似对 - similar_pairs = await self.find_similar_pairs() - - if not similar_pairs: - logger.info("未找到需要去重的相似记忆对") - print() - print("="*70) - print("未找到需要去重的记忆") - print("="*70) - return - - # 去重处理 - logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...") - print() - - processed_pairs = set() # 避免重复处理 - - for mem_id_1, mem_id_2, similarity in similar_pairs: - # 检查是否已处理(可能一个记忆已被删除) - pair_key = tuple(sorted([mem_id_1, mem_id_2])) - if pair_key in processed_pairs: - continue - - # 检查记忆是否仍存在 - if not self.manager.graph_store.get_memory_by_id(mem_id_1): - logger.debug(f"记忆 {mem_id_1} 已不存在,跳过") - continue - if not self.manager.graph_store.get_memory_by_id(mem_id_2): - logger.debug(f"记忆 {mem_id_2} 已不存在,跳过") - continue - - # 执行去重 - success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity) - 
- if success: - processed_pairs.add(pair_key) - - # 保存数据(如果不是干运行) - if not self.dry_run: - logger.info("正在保存数据...") - await self.manager.persistence.save_graph_store(self.manager.graph_store) - logger.info("✅ 数据已保存") - - # 统计报告 - elapsed = (datetime.now() - start_time).total_seconds() - - print() - print("="*70) - print("去重报告") - print("="*70) - print(f"总记忆数: {self.stats['total_memories']}") - print(f"相似记忆对: {self.stats['similar_pairs']}") - print(f"发现重复: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}") - print(f"{'预览通过' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}") - print(f"错误数: {self.stats['errors']}") - print(f"耗时: {elapsed:.2f}秒") - - if self.dry_run: - print() - print("⚠️ 这是预览模式,未实际删除任何记忆") - print("💡 要执行实际删除,请运行: python scripts/deduplicate_memories.py") - else: - print() - print("✅ 去重完成!") - final_count = len(self.manager.graph_store.get_all_memories()) - print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)") - - print("="*70) - - async def cleanup(self): - """清理资源""" - if self.manager: - await shutdown_memory_manager() - - -async def main(): - """主函数""" - parser = argparse.ArgumentParser( - description="记忆去重工具 - 对标记为相似的记忆进行一键去重", - formatter_class=argparse.RawDescriptionHelpFormatter, - epilog=""" -示例: - # 预览模式(推荐先运行) - python scripts/deduplicate_memories.py --dry-run - - # 执行去重 - python scripts/deduplicate_memories.py - - # 指定相似度阈值(只处理相似度>=0.9的记忆对) - python scripts/deduplicate_memories.py --threshold 0.9 - - # 指定数据目录 - python scripts/deduplicate_memories.py --data-dir data/memory_graph - - # 组合使用 - python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test - """ - ) - - parser.add_argument( - "--dry-run", - action="store_true", - help="预览模式,不实际删除记忆(推荐先运行此模式)" - ) - - parser.add_argument( - "--threshold", - type=float, - default=0.85, - help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85)" - ) - - parser.add_argument( - "--data-dir", - type=str, - default="data/memory_graph", - help="记忆数据目录(默认: data/memory_graph)" - ) - - args = parser.parse_args() - - # 创建去重器 - deduplicator = MemoryDeduplicator( - data_dir=args.data_dir, - dry_run=args.dry_run, - threshold=args.threshold - ) - - try: - # 执行去重 - await deduplicator.run() - except KeyboardInterrupt: - print("\n\n⚠️ 用户中断操作") - except Exception as e: - logger.error(f"执行失败: {e}") - print(f"\n❌ 执行失败: {e}") - return 1 - finally: - # 清理资源 - await deduplicator.cleanup() - - return 0 - - -if __name__ == "__main__": - sys.exit(asyncio.run(main())) diff --git a/scripts/expression_stats.py b/scripts/expression_stats.py deleted file mode 100644 index abf5eb870..000000000 --- a/scripts/expression_stats.py +++ /dev/null @@ -1,195 +0,0 @@ -import os -import sys -import time - -# Add project root to Python path -from src.common.database.database_model import ChatStreams, Expression - -project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -sys.path.insert(0, project_root) - - -def get_chat_name(chat_id: str) -> str: - """Get chat name from chat_id by querying ChatStreams table directly""" - try: - # 直接从数据库查询ChatStreams表 - chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id) - if chat_stream is None: - return f"未知聊天 ({chat_id})" - - # 如果有群组信息,显示群组名称 - if chat_stream.group_name: - return f"{chat_stream.group_name} ({chat_id})" - # 如果是私聊,显示用户昵称 - elif chat_stream.user_nickname: - return f"{chat_stream.user_nickname}的私聊 ({chat_id})" - else: - return f"未知聊天 
({chat_id})" - except Exception: - return f"查询失败 ({chat_id})" - - -def calculate_time_distribution(expressions) -> dict[str, int]: - """Calculate distribution of last active time in days""" - now = time.time() - distribution = { - "0-1天": 0, - "1-3天": 0, - "3-7天": 0, - "7-14天": 0, - "14-30天": 0, - "30-60天": 0, - "60-90天": 0, - "90+天": 0, - } - for expr in expressions: - diff_days = (now - expr.last_active_time) / (24 * 3600) - if diff_days < 1: - distribution["0-1天"] += 1 - elif diff_days < 3: - distribution["1-3天"] += 1 - elif diff_days < 7: - distribution["3-7天"] += 1 - elif diff_days < 14: - distribution["7-14天"] += 1 - elif diff_days < 30: - distribution["14-30天"] += 1 - elif diff_days < 60: - distribution["30-60天"] += 1 - elif diff_days < 90: - distribution["60-90天"] += 1 - else: - distribution["90+天"] += 1 - return distribution - - -def calculate_count_distribution(expressions) -> dict[str, int]: - """Calculate distribution of count values""" - distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0} - for expr in expressions: - cnt = expr.count - if cnt < 1: - distribution["0-1"] += 1 - elif cnt < 2: - distribution["1-2"] += 1 - elif cnt < 3: - distribution["2-3"] += 1 - elif cnt < 4: - distribution["3-4"] += 1 - elif cnt < 5: - distribution["4-5"] += 1 - elif cnt < 10: - distribution["5-10"] += 1 - else: - distribution["10+"] += 1 - return distribution - - -def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]: - """Get top N most used expressions for a specific chat_id""" - return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n) - - -def show_overall_statistics(expressions, total: int) -> None: - """Show overall statistics""" - time_dist = calculate_time_distribution(expressions) - count_dist = calculate_count_distribution(expressions) - - print("\n=== 总体统计 ===") - print(f"总表达式数量: {total}") - - print("\n上次激活时间分布:") - for period, count in time_dist.items(): - print(f"{period}: {count} ({count / total * 100:.2f}%)") - - print("\ncount分布:") - for range_, count in count_dist.items(): - print(f"{range_}: {count} ({count / total * 100:.2f}%)") - - -def show_chat_statistics(chat_id: str, chat_name: str) -> None: - """Show statistics for a specific chat""" - chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id)) - chat_total = len(chat_exprs) - - print(f"\n=== {chat_name} ===") - print(f"表达式数量: {chat_total}") - - if chat_total == 0: - print("该聊天没有表达式数据") - return - - # Time distribution for this chat - time_dist = calculate_time_distribution(chat_exprs) - print("\n上次激活时间分布:") - for period, count in time_dist.items(): - if count > 0: - print(f"{period}: {count} ({count / chat_total * 100:.2f}%)") - - # Count distribution for this chat - count_dist = calculate_count_distribution(chat_exprs) - print("\ncount分布:") - for range_, count in count_dist.items(): - if count > 0: - print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)") - - # Top expressions - print("\nTop 10使用最多的表达式:") - top_exprs = get_top_expressions_by_chat(chat_id, 10) - for i, expr in enumerate(top_exprs, 1): - print(f"{i}. 
[{expr.type}] Count: {expr.count}") - print(f" Situation: {expr.situation}") - print(f" Style: {expr.style}") - print() - - -def interactive_menu() -> None: - """Interactive menu for expression statistics""" - # Get all expressions - expressions = list(Expression.select()) - if not expressions: - print("数据库中没有找到表达式") - return - - total = len(expressions) - - # Get unique chat_ids and their names - chat_ids = list({expr.chat_id for expr in expressions}) - chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids] - chat_info.sort(key=lambda x: x[1]) # Sort by chat name - - while True: - print("\n" + "=" * 50) - print("表达式统计分析") - print("=" * 50) - print("0. 显示总体统计") - - for i, (chat_id, chat_name) in enumerate(chat_info, 1): - chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id) - print(f"{i}. {chat_name} ({chat_count}个表达式)") - - print("q. 退出") - - choice = input("\n请选择要查看的统计 (输入序号): ").strip() - - if choice.lower() == "q": - print("再见!") - break - - try: - choice_num = int(choice) - if choice_num == 0: - show_overall_statistics(expressions, total) - elif 1 <= choice_num <= len(chat_info): - chat_id, chat_name = chat_info[choice_num - 1] - show_chat_statistics(chat_id, chat_name) - else: - print("无效的选择,请重新输入") - except ValueError: - print("请输入有效的数字") - - input("\n按回车键继续...") - - -if __name__ == "__main__": - interactive_menu() diff --git a/scripts/extract_models.py b/scripts/extract_models.py deleted file mode 100644 index c97ca163c..000000000 --- a/scripts/extract_models.py +++ /dev/null @@ -1,66 +0,0 @@ -#!/usr/bin/env python3 -"""提取models.py中的模型定义""" - -import re - -# 读取原始文件 -with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f: - content = f.read() - -# 找到get_string_field函数的开始和结束 -get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数") -get_string_field_end = content.find("\n\nclass ChatStreams(Base):") -get_string_field = content[get_string_field_start:get_string_field_end] - -# 找到第一个class定义开始 -first_class_pos = content.find("class ChatStreams(Base):") - -# 找到所有class定义,直到遇到非class的def -# 简单策略:找到所有以"class "开头且继承Base的类 -classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)" -matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL)) - -if matches: - # 取最后一个匹配的结束位置 - models_content = content[first_class_pos:first_class_pos + matches[-1].end()] -else: - # 备用方案:从第一个class到文件的85%位置 - models_end = int(len(content) * 0.85) - models_content = content[first_class_pos:models_end] - -# 创建新文件内容 -header = '''"""SQLAlchemy数据库模型定义 - -本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。 -引擎和会话管理已移至core/engine.py和core/session.py。 - -所有模型使用统一的类型注解风格: - field_name: Mapped[PyType] = mapped_column(Type, ...) 
- -这样IDE/Pylance能正确推断实例属性类型。 -""" - -import datetime -import time - -from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import Mapped, mapped_column - -# 创建基类 -Base = declarative_base() - - -''' - -new_content = header + get_string_field + "\n\n" + models_content - -# 写入新文件 -with open("src/common/database/core/models.py", "w", encoding="utf-8") as f: - f.write(new_content) - -print("✅ Models file rewritten successfully") -print(f"File size: {len(new_content)} characters") -pattern = r"^class \w+\(Base\):" -model_count = len(re.findall(pattern, models_content, re.MULTILINE)) -print(f"Number of model classes: {model_count}") diff --git a/scripts/generate_missing_embeddings.py b/scripts/generate_missing_embeddings.py deleted file mode 100644 index 951db9cde..000000000 --- a/scripts/generate_missing_embeddings.py +++ /dev/null @@ -1,267 +0,0 @@ -""" -为现有节点生成嵌入向量 - -批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量 - -使用场景: -1. 历史记忆节点没有嵌入向量 -2. 嵌入生成器之前未配置,现在需要补充生成 -3. 向量索引损坏需要重建 - -使用方法: - python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50] - -参数说明: - --node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT - --batch-size: 批量处理大小,默认为 50 -""" - -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到路径 -sys.path.insert(0, str(Path(__file__).parent.parent)) - - -async def generate_missing_embeddings( - target_node_types: list[str] | None = None, - batch_size: int = 50, -): - """ - 为缺失嵌入向量的节点生成嵌入 - - Args: - target_node_types: 需要处理的节点类型列表(如 ["主题", "客体"]) - batch_size: 批处理大小 - """ - from src.common.logger import get_logger - from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager - from src.memory_graph.models import NodeType - - logger = get_logger("generate_missing_embeddings") - - if target_node_types is None: - target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value] - - print(f"\n{'='*80}") - print("🔧 为节点生成嵌入向量") - print(f"{'='*80}\n") - print(f"目标节点类型: {', '.join(target_node_types)}") - print(f"批处理大小: {batch_size}\n") - - # 1. 初始化记忆管理器 - print("🔧 正在初始化记忆管理器...") - await initialize_memory_manager() - manager = get_memory_manager() - - if manager is None: - print("❌ 记忆管理器初始化失败") - return - - print("✅ 记忆管理器已初始化\n") - - # 2. 获取已索引的节点ID - print("🔍 检查现有向量索引...") - existing_node_ids = set() - try: - vector_count = manager.vector_store.collection.count() - if vector_count > 0: - # 分批获取所有已索引的ID - batch_size_check = 1000 - for offset in range(0, vector_count, batch_size_check): - limit = min(batch_size_check, vector_count - offset) - result = manager.vector_store.collection.get( - limit=limit, - offset=offset, - ) - if result and "ids" in result: - existing_node_ids.update(result["ids"]) - - print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n") - except Exception as e: - logger.warning(f"获取已索引节点ID失败: {e}") - print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n") - - # 3. 
收集需要生成嵌入的节点 - print("🔍 扫描需要生成嵌入的节点...") - all_memories = manager.graph_store.get_all_memories() - - nodes_to_process = [] - total_target_nodes = 0 - type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types} - - for memory in all_memories: - for node in memory.nodes: - if node.node_type.value in target_node_types: - total_target_nodes += 1 - type_stats[node.node_type.value]["total"] += 1 - - # 检查是否已在向量索引中 - if node.id in existing_node_ids: - type_stats[node.node_type.value]["already_indexed"] += 1 - continue - - if not node.has_embedding(): - nodes_to_process.append({ - "node": node, - "memory_id": memory.id, - }) - type_stats[node.node_type.value]["need_emb"] += 1 - - print("\n📊 扫描结果:") - for node_type in target_node_types: - stats = type_stats[node_type] - already_ok = stats["already_indexed"] - coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0 - print(f" - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, " - f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)") - - print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n") - - if len(nodes_to_process) == 0: - print("✅ 所有节点已有嵌入向量,无需生成") - return - - # 3. 批量生成嵌入 - print("🚀 开始生成嵌入向量...\n") - - total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size - success_count = 0 - failed_count = 0 - indexed_count = 0 - - for i in range(0, len(nodes_to_process), batch_size): - batch = nodes_to_process[i : i + batch_size] - batch_num = i // batch_size + 1 - - print(f"📦 批次 {batch_num}/{total_batches} ({len(batch)} 个节点)...") - - try: - # 提取文本内容 - texts = [item["node"].content for item in batch] - - # 批量生成嵌入 - embeddings = await manager.embedding_generator.generate_batch(texts) - - # 为节点设置嵌入并索引 - batch_nodes_for_index = [] - - for j, (item, embedding) in enumerate(zip(batch, embeddings)): - node = item["node"] - - if embedding is not None: - # 设置嵌入向量 - node.embedding = embedding - batch_nodes_for_index.append(node) - success_count += 1 - else: - failed_count += 1 - logger.warning(f" ⚠️ 节点 {node.id[:8]}... '{node.content[:30]}' 嵌入生成失败") - - # 批量索引到向量数据库 - if batch_nodes_for_index: - try: - await manager.vector_store.add_nodes_batch(batch_nodes_for_index) - indexed_count += len(batch_nodes_for_index) - print(f" ✅ 成功: {len(batch_nodes_for_index)}/{len(batch)} 个节点已生成并索引") - except Exception as e: - # 如果批量失败,尝试逐个添加(跳过重复) - logger.warning(f" 批量索引失败,尝试逐个添加: {e}") - individual_success = 0 - for node in batch_nodes_for_index: - try: - await manager.vector_store.add_node(node) - individual_success += 1 - indexed_count += 1 - except Exception as e2: - if "Expected IDs to be unique" in str(e2): - logger.debug(f" 跳过已存在节点: {node.id}") - else: - logger.error(f" 节点 {node.id} 索引失败: {e2}") - print(f" ⚠️ 逐个索引: {individual_success}/{len(batch_nodes_for_index)} 个成功") - - except Exception as e: - failed_count += len(batch) - logger.error(f"批次 {batch_num} 处理失败") - print(f" ❌ 批次处理失败: {e}") - - # 显示进度 - total_processed = min(i + batch_size, len(nodes_to_process)) - progress = total_processed / len(nodes_to_process) * 100 - print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n") - - # 4. 保存图数据(更新节点的 embedding 字段) - print("💾 保存图数据...") - try: - await manager.persistence.save_graph_store(manager.graph_store) - print("✅ 图数据已保存\n") - except Exception as e: - logger.error("保存图数据失败") - print(f"❌ 保存失败: {e}\n") - - # 5. 
验证结果 - print("🔍 验证向量索引...") - final_vector_count = manager.vector_store.collection.count() - stats = manager.graph_store.get_statistics() - total_nodes = stats["total_nodes"] - - print(f"\n{'='*80}") - print("📊 生成完成") - print(f"{'='*80}") - print(f"处理节点数: {len(nodes_to_process)}") - print(f"成功生成: {success_count}") - print(f"失败数量: {failed_count}") - print(f"成功索引: {indexed_count}") - print(f"向量索引节点数: {final_vector_count}") - print(f"图存储节点数: {total_nodes}") - print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n") - - # 6. 测试搜索 - print("🧪 测试搜索功能...") - test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"] - - for query in test_queries: - results = await manager.search_memories(query=query, top_k=3) - if results: - print(f"\n✅ 查询 '{query}' 找到 {len(results)} 条记忆:") - for i, memory in enumerate(results[:2], 1): - subject_node = memory.get_subject_node() - # 获取主题节点(遍历所有节点找TOPIC类型) - from src.memory_graph.models import NodeType - topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC] - subject = subject_node.content if subject_node else "?" - topic = topic_nodes[0].content if topic_nodes else "?" - print(f" {i}. {subject} - {topic} (重要性: {memory.importance:.2f})") - else: - print(f"\n⚠️ 查询 '{query}' 返回 0 条结果") - - -async def main(): - import argparse - - parser = argparse.ArgumentParser(description="为节点生成嵌入向量") - parser.add_argument( - "--node-types", - type=str, - default="主题,客体", - help="需要生成嵌入的节点类型,逗号分隔(默认:主题,客体)", - ) - parser.add_argument( - "--batch-size", - type=int, - default=50, - help="批处理大小(默认:50)", - ) - - args = parser.parse_args() - - target_types = [t.strip() for t in args.node_types.split(",")] - await generate_missing_embeddings( - target_node_types=target_types, - batch_size=args.batch_size, - ) - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/scripts/migrate_database.py b/scripts/migrate_database.py new file mode 100644 index 000000000..442e994dd --- /dev/null +++ b/scripts/migrate_database.py @@ -0,0 +1,1051 @@ +#!/usr/bin/env python3 +"""数据库迁移脚本 + +支持在不同数据库之间迁移数据: +- SQLite <-> MySQL +- SQLite <-> PostgreSQL +- MySQL <-> PostgreSQL + +使用方法: + python scripts/migrate_database.py --help + python scripts/migrate_database.py --source sqlite --target postgresql + python scripts/migrate_database.py --source mysql --target postgresql --batch-size 5000 + +注意事项: +1. 迁移前请备份源数据库 +2. 目标数据库应该是空的或不存在的(脚本会自动创建表) +3. 
迁移过程可能需要较长时间,请耐心等待
+
+实现细节:
+- 使用 SQLAlchemy 进行数据库连接和元数据管理
+- 采用流式迁移,避免一次性加载过多数据
+- 支持 SQLite、MySQL、PostgreSQL 之间的互相迁移
+"""
+
+from __future__ import annotations
+
+import argparse
+import logging
+import os
+import sys
+from getpass import getpass
+
+# =============================================================================
+# 设置日志
+# =============================================================================
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="[%(levelname)s] %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+# =============================================================================
+# 导入第三方库(延迟导入以便友好报错)
+# =============================================================================
+
+try:
+    import tomllib
+except ImportError:
+    tomllib = None
+
+from typing import Any, Iterable
+
+from sqlalchemy import (
+    create_engine,
+    MetaData,
+    Table,
+    inspect,
+    text,
+)
+from sqlalchemy.engine import Engine, Connection
+from sqlalchemy.exc import SQLAlchemyError
+
+# ====== 为了在 Windows 上更友好地输出中文,提前设置环境 ======
+# 有些 Windows 终端默认编码不是 UTF-8,这里做个兼容
+if os.name == "nt":
+    try:
+        import ctypes
+
+        ctypes.windll.kernel32.SetConsoleOutputCP(65001)
+    except Exception:
+        pass
+
+
+# =============================================================================
+# 配置相关工具
+# =============================================================================
+
+
+def get_project_root() -> str:
+    """获取项目根目录(当前脚本的上级目录)"""
+    return os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+
+PROJECT_ROOT = get_project_root()
+
+
+def load_bot_config() -> dict:
+    """加载 config/bot_config.toml 配置文件
+
+    返回:
+        dict: 配置字典,如果文件不存在或解析失败,则返回空字典
+    """
+    config_path = os.path.join(PROJECT_ROOT, "config", "bot_config.toml")
+    if not os.path.exists(config_path):
+        logger.warning("配置文件不存在: %s", config_path)
+        return {}
+
+    if tomllib is None:
+        logger.warning("当前 Python 版本不支持 tomllib,请使用 Python 3.11+ 或手动安装 tomli")
+        return {}
+
+    try:
+        with open(config_path, "rb") as f:
+            config = tomllib.load(f)
+        return config
+    except Exception as e:
+        logger.error("解析配置文件失败: %s", e)
+        return {}
+
+
+def get_database_config_from_toml(db_type: str) -> dict | None:
+    """从 bot_config.toml 中读取数据库配置
+
+    Args:
+        db_type: 数据库类型,支持 "sqlite"、"mysql"、"postgresql"
+
+    Returns:
+        dict: 数据库配置字典,如果对应配置不存在则返回 None
+    """
+    config_data = load_bot_config()
+    if not config_data:
+        return None
+
+    # 兼容旧结构和新结构
+    # 旧结构: 顶层直接有 db_type 相关字段
+    # 新结构: 在 [database] 下有 db_type 相关字段
+    db_config = config_data.get("database", {})
+
+    if db_type == "sqlite":
+        sqlite_path = (
+            db_config.get("sqlite_path")
+            or config_data.get("sqlite_path")
+            or "data/MaiBot.db"
+        )
+        if not os.path.isabs(sqlite_path):
+            sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path)
+        return {"path": sqlite_path}
+
+    elif db_type == "mysql":
+        return {
+            "host": db_config.get("mysql_host") or config_data.get("mysql_host") or "localhost",
+            "port": db_config.get("mysql_port") or config_data.get("mysql_port") or 3306,
+            "database": db_config.get("mysql_database") or config_data.get("mysql_database") or "maibot",
+            "user": db_config.get("mysql_user") or config_data.get("mysql_user") or "root",
+            "password": db_config.get("mysql_password") or config_data.get("mysql_password") or "",
+            "charset": db_config.get("mysql_charset") or config_data.get("mysql_charset") or "utf8mb4",
+        }
+
+    elif db_type == "postgresql":
+        return {
+            "host": db_config.get("postgresql_host") or config_data.get("postgresql_host") or "localhost",
+            "port": db_config.get("postgresql_port") or config_data.get("postgresql_port") or 5432,
+            "database": db_config.get("postgresql_database") or config_data.get("postgresql_database") or "maibot",
+            "user": db_config.get("postgresql_user") or config_data.get("postgresql_user") or "postgres",
+            "password": db_config.get("postgresql_password") or config_data.get("postgresql_password") or "",
+            "schema": db_config.get("postgresql_schema") or config_data.get("postgresql_schema") or "public",
+        }
+
+    return None
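+# 返回值示意(编辑补充,仅为示例,具体取值取决于 bot_config.toml 的实际内容):
+#   get_database_config_from_toml("postgresql")
+#   -> {"host": "localhost", "port": 5432, "database": "maibot",
+#       "user": "postgres", "password": "", "schema": "public"}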
+
+
+# =============================================================================
+# 数据库连接相关
+# =============================================================================
+
+
+def create_sqlite_engine(sqlite_path: str) -> Engine:
+    """创建 SQLite 引擎"""
+    if not os.path.isabs(sqlite_path):
+        sqlite_path = os.path.join(PROJECT_ROOT, sqlite_path)
+
+    # 确保目录存在
+    os.makedirs(os.path.dirname(sqlite_path), exist_ok=True)
+
+    url = f"sqlite:///{sqlite_path}"
+    logger.info("使用 SQLite 数据库: %s", sqlite_path)
+    return create_engine(url, future=True)
+
+
+def create_mysql_engine(
+    host: str,
+    port: int,
+    database: str,
+    user: str,
+    password: str,
+    charset: str = "utf8mb4",
+) -> Engine:
+    """创建 MySQL 引擎"""
+    # 延迟导入 pymysql,以便友好提示
+    try:
+        import pymysql  # noqa: F401
+    except ImportError:
+        logger.error("需要安装 pymysql 才能连接 MySQL: pip install pymysql")
+        raise
+
+    url = f"mysql+pymysql://{user}:{password}@{host}:{port}/{database}?charset={charset}"
+    logger.info("使用 MySQL 数据库: %s@%s:%s/%s", user, host, port, database)
+    return create_engine(url, future=True)
+
+
+def create_postgresql_engine(
+    host: str,
+    port: int,
+    database: str,
+    user: str,
+    password: str,
+    schema: str = "public",
+) -> Engine:
+    """创建 PostgreSQL 引擎"""
+    # 在导入 psycopg2 之前设置环境变量,解决 Windows 编码问题
+    # psycopg2 在 Windows 上连接时,如果客户端编码与服务器不一致可能会有问题
+    os.environ.setdefault("PGCLIENTENCODING", "utf-8")
+
+    # 延迟导入 psycopg2,以便友好提示
+    try:
+        import psycopg2  # noqa: F401
+    except ImportError:
+        logger.error("需要安装 psycopg2-binary 才能连接 PostgreSQL: pip install psycopg2-binary")
+        raise
+
+    url = f"postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}"
+    logger.info("使用 PostgreSQL 数据库: %s@%s:%s/%s (schema=%s)", user, host, port, database, schema)
+    # 通过 libpq 的 options 参数为连接池中的每个新连接设置 search_path;
+    # 在单个临时连接上执行 SET search_path 只对该连接生效,
+    # 之后从连接池取出的其他连接不会继承这一设置
+    return create_engine(
+        url,
+        future=True,
+        connect_args={"options": f"-csearch_path={schema}"},
+    )
+
+
+def create_engine_by_type(db_type: str, config: dict) -> Engine:
+    """根据数据库类型创建对应的 SQLAlchemy Engine
+
+    Args:
+        db_type: 数据库类型,支持 sqlite/mysql/postgresql
+        config: 配置字典
+
+    Returns:
+        Engine: SQLAlchemy 引擎实例
+    """
+    db_type = db_type.lower()
+    if db_type == "sqlite":
+        return create_sqlite_engine(config["path"])
+    elif db_type == "mysql":
+        return create_mysql_engine(
+            host=config["host"],
+            port=config["port"],
+            database=config["database"],
+            user=config["user"],
+            password=config["password"],
+            charset=config.get("charset", "utf8mb4"),
+        )
+    elif db_type == "postgresql":
+        return create_postgresql_engine(
+            host=config["host"],
+            port=config["port"],
+            database=config["database"],
+            user=config["user"],
+            password=config["password"],
+            schema=config.get("schema", "public"),
+        )
+    else:
+        raise ValueError(f"不支持的数据库类型: {db_type}")
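+# 用法示意(编辑补充): create_engine_by_type 只是按类型分派到上面的工厂函数,
+# 下面的路径与 SQL 仅为示例:
+#   engine = create_engine_by_type("sqlite", {"path": "data/MaiBot.db"})
+#   with engine.connect() as conn:
+#       conn.execute(text("SELECT 1"))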
+
+
+# =============================================================================
+# 工具函数
+# =============================================================================
+
+
+def chunked_iterable(iterable: Iterable, size: int) -> Iterable[list]:
+    """将可迭代对象分块
+
+    Args:
+        iterable: 可迭代对象
+        size: 每块大小
+
+    Yields:
+        list: 分块列表
+    """
+    chunk: list[Any] = []
+    for item in iterable:
+        chunk.append(item)
+        if len(chunk) >= size:
+            yield chunk
+            chunk = []
+    if chunk:
+        yield chunk
+
+
+def get_table_row_count(conn: Connection, table: Table) -> int:
+    """获取表的行数"""
+    try:
+        result = conn.execute(text(f"SELECT COUNT(*) FROM {table.name}"))
+        return int(result.scalar() or 0)
+    except SQLAlchemyError as e:
+        logger.warning("获取表行数失败 %s: %s", table.name, e)
+        return 0
+
+
+def copy_table_structure(source_table: Table, target_metadata: MetaData, target_engine: Engine) -> Table:
+    """在目标数据库中创建与源表结构相同的表
+
+    Args:
+        source_table: 源表对象
+        target_metadata: 目标元数据对象
+        target_engine: 目标数据库引擎
+
+    Returns:
+        Table: 目标表对象
+    """
+    # 用 to_metadata() 复制表结构,列和约束会一并复制
+    # (Column.copy() 在 SQLAlchemy 2.0 中已被移除)
+    target_table = source_table.to_metadata(target_metadata)
+    target_metadata.create_all(target_engine, tables=[target_table])
+    return target_table
+
+
+def migrate_table_data(
+    source_conn: Connection,
+    target_conn: Connection,
+    source_table: Table,
+    target_table: Table,
+    batch_size: int = 1000,
+) -> tuple[int, int]:
+    """迁移单个表的数据
+
+    Args:
+        source_conn: 源数据库连接
+        target_conn: 目标数据库连接
+        source_table: 源表对象
+        target_table: 目标表对象
+        batch_size: 每批次处理大小
+
+    Returns:
+        tuple[int, int]: (迁移行数, 错误数量)
+    """
+    total_rows = get_table_row_count(source_conn, source_table)
+    logger.info(
+        "开始迁移表: %s (共 %s 行)",
+        source_table.name,
+        total_rows if total_rows else "未知",
+    )
+
+    migrated_rows = 0
+    error_count = 0
+
+    # 使用流式查询,避免一次性加载太多数据
+    # 对于 SQLAlchemy 1.4/2.0 可以使用 yield_per
+    try:
+        select_stmt = source_table.select()
+        result = source_conn.execute(select_stmt)
+    except SQLAlchemyError as e:
+        logger.error("查询表 %s 失败: %s", source_table.name, e)
+        return 0, 1
+
+    columns = source_table.columns.keys()
+
+    def insert_batch(rows: list[dict]):
+        nonlocal migrated_rows, error_count
+        if not rows:
+            return
+        try:
+            target_conn.execute(target_table.insert(), rows)
+            # future 风格的连接不会自动提交,必须显式 commit,
+            # 否则连接关闭时本批数据会被回滚
+            target_conn.commit()
+            migrated_rows += len(rows)
+            logger.info("  已迁移 %d/%s 行", migrated_rows, total_rows or "?")
+        except SQLAlchemyError as e:
+            logger.error("写入表 %s 失败: %s", target_table.name, e)
+            # 回滚失败的事务,避免后续批次在失效事务上继续执行
+            target_conn.rollback()
+            error_count += len(rows)
+
+    batch: list[dict] = []
+    for row in result:
+        # SQLAlchemy 2.0 的 Row 不支持按列名下标,需通过 _mapping 取值
+        row_dict = {col: row._mapping[col] for col in columns}
+        batch.append(row_dict)
+        if len(batch) >= batch_size:
+            insert_batch(batch)
+            batch = []
+
+    if batch:
+        insert_batch(batch)
+
+    logger.info(
+        "完成迁移表: %s (成功: %d 行, 失败: %d 行)",
+        source_table.name,
+        migrated_rows,
+        error_count,
+    )
+
+    return migrated_rows, error_count
+
+
+def confirm_action(prompt: str, default: bool = False) -> bool:
+    """确认操作
+
+    Args:
+        prompt: 提示信息
+        default: 默认值
+
+    Returns:
+        bool: 用户是否确认
+    """
+    while True:
+        if default:
+            choice = input(f"{prompt} [Y/n]: ").strip().lower()
+            if choice == "":
+                return True
+        else:
+            choice = input(f"{prompt} [y/N]: ").strip().lower()
+            if choice == "":
+                return False
+
+        if choice in ("y", "yes"):
+            return True
+        elif choice in ("n", "no"):
+            return False
+        else:
+            print("请输入 y 或 n")
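+# 用法示意(编辑补充): chunked_iterable 把任意可迭代对象切成固定大小的块:
+#   >>> list(chunked_iterable(range(5), 2))
+#   [[0, 1], [2, 3], [4]]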
+
+
+# =============================================================================
+# 迁移器实现
+# =============================================================================
+
+
+class DatabaseMigrator:
+    """通用数据库迁移器"""
+
+    def __init__(
+        self,
+        source_type: str,
+        target_type: str,
+        batch_size: int = 1000,
+        source_config: dict | None = None,
+        target_config: dict | None = None,
+    ):
+        """初始化迁移器
+
+        Args:
+            source_type: 源数据库类型
+            target_type: 目标数据库类型
+            batch_size: 批量处理大小
+            source_config: 源数据库配置(可选,默认从配置文件读取)
+            target_config: 目标数据库配置(可选,需要手动指定)
+        """
+        self.source_type = source_type.lower()
+        self.target_type = target_type.lower()
+        self.batch_size = batch_size
+        self.source_config = source_config
+        self.target_config = target_config
+
+        self._validate_database_types()
+
+        self.source_engine: Any = None
+        self.target_engine: Any = None
+        self.metadata = MetaData()
+
+        # 统计信息
+        self.stats = {
+            "tables_migrated": 0,
+            "rows_migrated": 0,
+            "errors": [],
+            "start_time": None,
+            "end_time": None,
+        }
+
+    def _validate_database_types(self):
+        """验证数据库类型"""
+        supported_types = {"sqlite", "mysql", "postgresql"}
+        if self.source_type not in supported_types:
+            raise ValueError(f"不支持的源数据库类型: {self.source_type}")
+        if self.target_type not in supported_types:
+            raise ValueError(f"不支持的目标数据库类型: {self.target_type}")
+
+    def _load_source_config(self) -> dict:
+        """加载源数据库配置
+
+        如果初始化时提供了 source_config,则直接使用;
+        否则从 bot_config.toml 中读取。
+        """
+        if self.source_config:
+            logger.info("使用传入的源数据库配置")
+            return self.source_config
+
+        logger.info("未提供源数据库配置,尝试从 bot_config.toml 读取")
+        config = get_database_config_from_toml(self.source_type)
+        if not config:
+            raise ValueError("无法从配置文件中读取源数据库配置,请检查 config/bot_config.toml")
+
+        logger.info("成功从配置文件读取源数据库配置")
+        return config
+
+    def _load_target_config(self) -> dict:
+        """加载目标数据库配置
+
+        目标数据库配置必须通过初始化参数提供,或者通过命令行参数构建。
+        """
+        if not self.target_config:
+            raise ValueError("未提供目标数据库配置,请通过命令行参数指定或在交互模式中输入")
+        logger.info("使用传入的目标数据库配置")
+        return self.target_config
+
+    def _connect_databases(self):
+        """连接源数据库和目标数据库"""
+        # 源数据库配置
+        source_config = self._load_source_config()
+        # 目标数据库配置
+        target_config = self._load_target_config()
+
+        # 创建引擎
+        self.source_engine = create_engine_by_type(self.source_type, source_config)
+        self.target_engine = create_engine_by_type(self.target_type, target_config)
+
+        # 反射源数据库元数据
+        logger.info("正在反射源数据库元数据...")
+        self.metadata.reflect(bind=self.source_engine)
+        logger.info("发现 %d 张表: %s", len(self.metadata.tables), ", ".join(self.metadata.tables.keys()))
+
+    def _get_tables_in_dependency_order(self) -> list[Table]:
+        """获取按依赖顺序排序的表列表
+
+        为了避免外键约束问题,创建表时需要按照依赖顺序,
+        例如先创建被引用的表,再创建引用它们的表。
+        """
+        inspector = inspect(self.source_engine)
+
+        # 构建依赖图:table -> set(dependent_tables)
+        dependencies: dict[str, set[str]] = {}
+        for table_name in self.metadata.tables:
+            dependencies[table_name] = set()
+
+        for table_name, table in self.metadata.tables.items():
+            fks = inspector.get_foreign_keys(table_name)
+            for fk in fks:
+                # 被引用的表
+                referred_table = fk["referred_table"]
+                if referred_table in dependencies:
+                    dependencies[table_name].add(referred_table)
+
+        # 拓扑排序
+        sorted_tables: list[Table] = []
+        visited: set[str] = set()
+        temp_mark: set[str] = set()
+
+        def visit(table_name: str):
+            if table_name in visited:
+                return
+            if table_name in temp_mark:
+                logger.warning("检测到循环依赖,表: %s", table_name)
+                return
+            temp_mark.add(table_name)
+            for dep in dependencies[table_name]:
+                visit(dep)
+            temp_mark.remove(table_name)
+            visited.add(table_name)
+            sorted_tables.append(self.metadata.tables[table_name])
+
+        for table_name in dependencies:
+            if table_name not in visited:
+                visit(table_name)
+
+        return sorted_tables
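+    # 示意(编辑补充,表名仅为假设): 若 messages.chat_id 外键引用 chat_streams,
+    # 则拓扑排序先返回 chat_streams 再返回 messages,
+    # 保证建表和插入数据时被引用的表总是先就绪。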
logger.info("目标数据库中当前存在的表: %s", ", ".join(existing_tables)) + if confirm_action("是否删除目标数据库中已有的所有表?此操作不可恢复!", default=False): + with conn.begin(): + for table_name in existing_tables: + try: + logger.info("删除目标数据库中表: %s", table_name) + conn.execute(text(f"DROP TABLE IF EXISTS {table_name} CASCADE")) + except SQLAlchemyError as e: + logger.error("删除表 %s 失败: %s", table_name, e) + self.stats["errors"].append( + f"删除表 {table_name} 失败: {e}" + ) + else: + logger.info("用户选择保留目标数据库中已有的表,可能会与迁移数据发生冲突。") + + def migrate(self): + """执行迁移操作""" + import time + + self.stats["start_time"] = time.time() + + # 连接数据库 + self._connect_databases() + + # 获取表的依赖顺序 + tables = self._get_tables_in_dependency_order() + logger.info("按依赖顺序迁移表: %s", ", ".join(t.name for t in tables)) + + # 删除目标库中已有表(可选) + with self.target_engine.connect() as target_conn: + self._drop_target_tables(target_conn) + + # 开始迁移 + with self.source_engine.connect() as source_conn, self.target_engine.connect() as target_conn: + for source_table in tables: + try: + # 在目标库中创建表结构 + target_table = copy_table_structure(source_table, MetaData(), self.target_engine) + + # 迁移数据 + migrated_rows, error_count = migrate_table_data( + source_conn, + target_conn, + source_table, + target_table, + batch_size=self.batch_size, + ) + + self.stats["tables_migrated"] += 1 + self.stats["rows_migrated"] += migrated_rows + if error_count > 0: + self.stats["errors"].append( + f"表 {source_table.name} 迁移失败 {error_count} 行" + ) + + except Exception as e: + logger.error("迁移表 %s 时发生错误: %s", source_table.name, e) + self.stats["errors"].append(f"表 {source_table.name} 迁移失败: {e}") + + self.stats["end_time"] = time.time() + + def print_summary(self): + """打印迁移总结""" + import time + + duration = None + if self.stats["start_time"] is not None and self.stats["end_time"] is not None: + duration = self.stats["end_time"] - self.stats["start_time"] + + print("\n" + "=" * 60) + print("迁移完成!") + print(f" 迁移表数量: {self.stats['tables_migrated']}") + print(f" 迁移行数量: {self.stats['rows_migrated']}") + if duration is not None: + print(f" 总耗时: {duration:.2f} 秒") + if self.stats["errors"]: + print(" ⚠️ 发生错误:") + for err in self.stats["errors"]: + print(f" - {err}") + else: + print(" 没有发生错误 🎉") + print("=" * 60 + "\n") + + def run(self): + """运行迁移并打印总结""" + self.migrate() + self.print_summary() + return self.stats + + +# ============================================================================= +# 命令行参数解析 +# ============================================================================= + + +def parse_args(): + """解析命令行参数""" + parser = argparse.ArgumentParser( + description="数据库迁移工具 - 在 SQLite、MySQL、PostgreSQL 之间迁移数据", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog="""示例: + # 从 SQLite 迁移到 PostgreSQL + python scripts/migrate_database.py \ + --source sqlite \ + --target postgresql \ + --target-host localhost \ + --target-port 5432 \ + --target-database maibot \ + --target-user postgres \ + --target-password your_password + + # 从 SQLite 迁移到 MySQL + python scripts/migrate_database.py \ + --source sqlite \ + --target mysql \ + --target-host localhost \ + --target-port 3306 \ + --target-database maibot \ + --target-user root \ + --target-password your_password + + # 使用交互式向导模式(推荐) + python scripts/migrate_database.py + python scripts/migrate_database.py --interactive + """, + ) + + # 基本参数 + parser.add_argument( + "--source", + type=str, + choices=["sqlite", "mysql", "postgresql"], + help="源数据库类型(不指定时,在交互模式中选择)", + ) + parser.add_argument( + "--target", + type=str, + choices=["sqlite", 
"mysql", "postgresql"], + help="目标数据库类型(不指定时,在交互模式中选择)", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1000, + help="批量处理大小(默认: 1000)", + ) + + parser.add_argument( + "--interactive", + action="store_true", + help="启用交互式向导模式(推荐:直接运行脚本或加上此参数)", + ) + + # 源数据库参数(可选,默认从 bot_config.toml 读取) + source_group = parser.add_argument_group("源数据库配置(可选,默认从 bot_config.toml 读取)") + source_group.add_argument("--source-path", type=str, help="SQLite 数据库路径") + source_group.add_argument("--source-host", type=str, help="MySQL/PostgreSQL 主机") + source_group.add_argument("--source-port", type=int, help="MySQL/PostgreSQL 端口") + source_group.add_argument("--source-database", type=str, help="数据库名") + source_group.add_argument("--source-user", type=str, help="用户名") + source_group.add_argument("--source-password", type=str, help="密码") + + # 目标数据库参数 + target_group = parser.add_argument_group("目标数据库配置") + target_group.add_argument("--target-path", type=str, help="SQLite 数据库路径") + target_group.add_argument("--target-host", type=str, help="MySQL/PostgreSQL 主机") + target_group.add_argument("--target-port", type=int, help="MySQL/PostgreSQL 端口") + target_group.add_argument("--target-database", type=str, help="数据库名") + target_group.add_argument("--target-user", type=str, help="用户名") + target_group.add_argument("--target-password", type=str, help="密码") + target_group.add_argument("--target-schema", type=str, default="public", help="PostgreSQL schema") + target_group.add_argument("--target-charset", type=str, default="utf8mb4", help="MySQL 字符集") + + return parser.parse_args() + + +def build_config_from_args(args, prefix: str, db_type: str) -> dict | None: + """从命令行参数构建配置 + + Args: + args: 命令行参数 + prefix: 参数前缀 ("source" 或 "target") + db_type: 数据库类型 + + Returns: + 配置字典或 None + """ + if db_type == "sqlite": + path = getattr(args, f"{prefix}_path", None) + if path: + return {"path": path} + return None + + elif db_type in ("mysql", "postgresql"): + host = getattr(args, f"{prefix}_host", None) + if not host: + return None + + config = { + "host": host, + "port": getattr(args, f"{prefix}_port") or (3306 if db_type == "mysql" else 5432), + "database": getattr(args, f"{prefix}_database") or "maibot", + "user": getattr(args, f"{prefix}_user") or ("root" if db_type == "mysql" else "postgres"), + "password": getattr(args, f"{prefix}_password") or "", + } + + if db_type == "mysql": + config["charset"] = getattr(args, f"{prefix}_charset", "utf8mb4") + elif db_type == "postgresql": + config["schema"] = getattr(args, f"{prefix}_schema", "public") + + return config + + return None + + +def _ask_choice(prompt: str, options: list[str], default_index: int | None = None) -> str: + """在控制台中让用户从多个选项中选择一个""" + while True: + print() + print(prompt) + for i, opt in enumerate(options, start=1): + default_mark = "" + if default_index is not None and i - 1 == default_index: + default_mark = " (默认)" + print(f" {i}) {opt}{default_mark}") + ans = input("请输入选项编号: ").strip() + if not ans and default_index is not None: + return options[default_index] + if ans.isdigit(): + idx = int(ans) + if 1 <= idx <= len(options): + return options[idx - 1] + print("❌ 无效的选择,请重新输入。") + + +def _ask_int(prompt: str, default: int | None = None) -> int: + """在控制台中输入正整数""" + while True: + suffix = f" (默认 {default})" if default is not None else "" + raw = input(f"{prompt}{suffix}: ").strip() + if not raw and default is not None: + return default + try: + value = int(raw) + if value <= 0: + raise ValueError() + return value + except ValueError: + print("❌ 
请输入一个大于 0 的整数。") + + +def _ask_str( + prompt: str, + default: str | None = None, + allow_empty: bool = False, + is_password: bool = False, +) -> str: + """在控制台中输入字符串,可选默认值/密码输入""" + while True: + suffix = f" (默认: {default})" if default is not None else "" + full_prompt = f"{prompt}{suffix}: " + raw = getpass(full_prompt) if is_password else input(full_prompt) + raw = raw.strip() + if not raw: + if default is not None: + return default + if allow_empty: + return "" + print("❌ 输入不能为空,请重新输入。") + continue + return raw + + +def interactive_setup() -> dict: + """交互式向导,返回用于初始化 DatabaseMigrator 的参数字典""" + print("=" * 60) + print("🌟 数据库迁移向导") + print("只需回答几个问题,我会帮你构造迁移配置。") + print("=" * 60) + + db_types = ["sqlite", "mysql", "postgresql"] + + # 选择源数据库 + source_type = _ask_choice("请选择【源数据库类型】:", db_types, default_index=0) + + # 选择目标数据库(不能与源相同) + while True: + default_idx = 2 if len(db_types) >= 3 else 0 + target_type = _ask_choice("请选择【目标数据库类型】:", db_types, default_index=default_idx) + if target_type != source_type: + break + print("❌ 目标数据库不能和源数据库相同,请重新选择。") + + # 批量大小 + batch_size = _ask_int("请输入批量大小 batch-size", default=1000) + + # 源数据库配置:默认使用 bot_config.toml + print() + print("源数据库配置:") + print(" 默认会从 config/bot_config.toml 中读取对应配置。") + use_default_source = input("是否使用配置文件中的【源数据库】配置? [Y/n]: ").strip().lower() + if use_default_source in ("", "y", "yes"): + source_config = None # 让 DatabaseMigrator 自己去读配置 + else: + # 简单交互式配置源数据库 + print("请手动输入源数据库连接信息:") + if source_type == "sqlite": + source_path = _ask_str("源 SQLite 文件路径", default="data/MaiBot.db") + source_config = {"path": source_path} + else: + port_default = 3306 if source_type == "mysql" else 5432 + user_default = "root" if source_type == "mysql" else "postgres" + host = _ask_str("源数据库 host", default="localhost") + port = _ask_int("源数据库 port", default=port_default) + database = _ask_str("源数据库名", default="maibot") + user = _ask_str("源数据库用户名", default=user_default) + password = _ask_str("源数据库密码(输入时不回显)", default="", is_password=True) + source_config = { + "host": host, + "port": port, + "database": database, + "user": user, + "password": password, + } + if source_type == "mysql": + source_config["charset"] = _ask_str("源数据库字符集", default="utf8mb4") + elif source_type == "postgresql": + source_config["schema"] = _ask_str("源数据库 schema", default="public") + + # 目标数据库配置(必须显式确认) + print() + print("目标数据库配置:") + if target_type == "sqlite": + target_path = _ask_str( + "目标 SQLite 文件路径(若不存在会自动创建)", + default="data/MaiBot_target.db", + ) + target_config = {"path": target_path} + else: + port_default = 3306 if target_type == "mysql" else 5432 + user_default = "root" if target_type == "mysql" else "postgres" + host = _ask_str("目标数据库 host", default="localhost") + port = _ask_int("目标数据库 port", default=port_default) + database = _ask_str("目标数据库名", default="maibot") + user = _ask_str("目标数据库用户名", default=user_default) + password = _ask_str("目标数据库密码(输入时不回显)", default="", is_password=True) + + target_config = { + "host": host, + "port": port, + "database": database, + "user": user, + "password": password, + } + if target_type == "mysql": + target_config["charset"] = _ask_str("目标数据库字符集", default="utf8mb4") + elif target_type == "postgresql": + target_config["schema"] = _ask_str("目标数据库 schema", default="public") + + print() + print("=" * 60) + print("迁移配置确认:") + print(f" 源数据库类型: {source_type}") + print(f" 目标数据库类型: {target_type}") + print(f" 批量大小: {batch_size}") + print("⚠️ 请确认目标数据库为空或可以被覆盖,并且已备份源数据库。") + confirm = input("是否开始迁移?[Y/n]: ").strip().lower() + if confirm 
not in ("", "y", "yes"): + print("已取消迁移。") + sys.exit(0) + + return { + "source_type": source_type, + "target_type": target_type, + "batch_size": batch_size, + "source_config": source_config, + "target_config": target_config, + } + + +def main(): + """主函数""" + args = parse_args() + + # 如果没有任何参数,或者显式指定 --interactive,则进入交互模式 + if args.interactive or len(sys.argv) == 1: + params = interactive_setup() + try: + migrator = DatabaseMigrator(**params) + stats = migrator.run() + if stats["errors"]: + sys.exit(1) + return + except KeyboardInterrupt: + print("\n迁移被用户中断") + sys.exit(130) + except Exception as e: + print(f"迁移失败: {e}") + sys.exit(1) + + # 非交互模式:保持原有行为,但如果没给 source/target,就提示错误 + if not args.source or not args.target: + print("错误: 非交互模式下必须指定 --source 和 --target。") + print("你也可以直接运行脚本或添加 --interactive 使用交互式向导。") + sys.exit(2) + + # 构建配置 + source_config = build_config_from_args(args, "source", args.source) + target_config = build_config_from_args(args, "target", args.target) + + # 验证目标配置 + if target_config is None: + if args.target == "sqlite": + if not args.target_path: + print("错误: 目标数据库为 SQLite 时,必须指定 --target-path(或使用交互模式)") + sys.exit(1) + target_config = {"path": args.target_path} + else: + if not args.target_host: + print(f"错误: 目标数据库为 {args.target} 时,必须指定 --target-host(或使用交互模式)") + sys.exit(1) + + try: + migrator = DatabaseMigrator( + source_type=args.source, + target_type=args.target, + batch_size=args.batch_size, + source_config=source_config, + target_config=target_config, + ) + + stats = migrator.run() + + # 如果有错误,返回非零退出码 + if stats["errors"]: + sys.exit(1) + + except KeyboardInterrupt: + print("\n迁移被用户中断") + sys.exit(130) + except Exception as e: + print(f"迁移失败: {e}") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/scripts/rebuild_metadata_index.py b/scripts/rebuild_metadata_index.py deleted file mode 100644 index 5e2c6c87e..000000000 --- a/scripts/rebuild_metadata_index.py +++ /dev/null @@ -1,140 +0,0 @@ -#!/usr/bin/env python -""" -从现有ChromaDB数据重建JSON元数据索引 -""" - -import asyncio -import os -import sys - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry -from src.chat.memory_system.memory_system import MemorySystem -from src.common.logger import get_logger - -logger = get_logger(__name__) - - -async def rebuild_metadata_index(): - """从ChromaDB重建元数据索引""" - print("=" * 80) - print("重建JSON元数据索引") - print("=" * 80) - - # 初始化记忆系统 - print("\n🔧 初始化记忆系统...") - ms = MemorySystem() - await ms.initialize() - print("✅ 记忆系统已初始化") - - if not hasattr(ms.unified_storage, "metadata_index"): - print("❌ 元数据索引管理器未初始化") - return - - # 获取所有记忆 - print("\n📥 从ChromaDB获取所有记忆...") - from src.common.vector_db import vector_db_service - - try: - # 获取集合中的所有记忆ID - collection_name = ms.unified_storage.config.memory_collection - result = vector_db_service.get( - collection_name=collection_name, include=["documents", "metadatas", "embeddings"] - ) - - if not result or not result.get("ids"): - print("❌ ChromaDB中没有找到记忆数据") - return - - ids = result["ids"] - metadatas = result.get("metadatas", []) - - print(f"✅ 找到 {len(ids)} 条记忆") - - # 重建元数据索引 - print("\n🔨 开始重建元数据索引...") - entries = [] - success_count = 0 - - for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1): - try: - # 从ChromaDB元数据重建索引条目 - import orjson - - entry = MemoryMetadataIndexEntry( - memory_id=memory_id, - user_id=metadata.get("user_id", "unknown"), - 
memory_type=metadata.get("memory_type", "general"), - subjects=orjson.loads(metadata.get("subjects", "[]")), - objects=[metadata.get("object")] if metadata.get("object") else [], - keywords=orjson.loads(metadata.get("keywords", "[]")), - tags=orjson.loads(metadata.get("tags", "[]")), - importance=2, # 默认NORMAL - confidence=2, # 默认MEDIUM - created_at=metadata.get("created_at", 0.0), - access_count=metadata.get("access_count", 0), - chat_id=metadata.get("chat_id"), - content_preview=None, - ) - - # 尝试解析importance和confidence的枚举名称 - if "importance" in metadata: - imp_str = metadata["importance"] - if imp_str == "LOW": - entry.importance = 1 - elif imp_str == "NORMAL": - entry.importance = 2 - elif imp_str == "HIGH": - entry.importance = 3 - elif imp_str == "CRITICAL": - entry.importance = 4 - - if "confidence" in metadata: - conf_str = metadata["confidence"] - if conf_str == "LOW": - entry.confidence = 1 - elif conf_str == "MEDIUM": - entry.confidence = 2 - elif conf_str == "HIGH": - entry.confidence = 3 - elif conf_str == "VERIFIED": - entry.confidence = 4 - - entries.append(entry) - success_count += 1 - - if i % 100 == 0: - print(f" 处理进度: {i}/{len(ids)} ({success_count} 成功)") - - except Exception as e: - logger.warning(f"处理记忆 {memory_id} 失败: {e}") - continue - - print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据") - - # 批量更新索引 - print("\n💾 保存元数据索引...") - ms.unified_storage.metadata_index.batch_add_or_update(entries) - ms.unified_storage.metadata_index.save() - - # 显示统计信息 - stats = ms.unified_storage.metadata_index.get_stats() - print("\n📊 重建后的索引统计:") - print(f" - 总记忆数: {stats['total_memories']}") - print(f" - 主语数量: {stats['subjects_count']}") - print(f" - 关键词数量: {stats['keywords_count']}") - print(f" - 标签数量: {stats['tags_count']}") - print(" - 类型分布:") - for mtype, count in stats["types"].items(): - print(f" - {mtype}: {count}") - - print("\n✅ 元数据索引重建完成!") - - except Exception as e: - logger.error(f"重建索引失败: {e}") - print(f"❌ 重建索引失败: {e}") - - -if __name__ == "__main__": - asyncio.run(rebuild_metadata_index()) diff --git a/scripts/test_mcp_integration.py b/scripts/test_mcp_integration.py deleted file mode 100644 index b5cfa2b28..000000000 --- a/scripts/test_mcp_integration.py +++ /dev/null @@ -1,190 +0,0 @@ -""" -MCP 集成测试脚本 - -测试 MCP 客户端连接、工具列表获取和工具调用功能 -""" - -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到 Python 路径 -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - -from src.common.logger import get_logger -from src.plugin_system.core.component_registry import ComponentRegistry -from src.plugin_system.core.mcp_client_manager import MCPClientManager - -logger = get_logger("test_mcp_integration") - - -async def test_mcp_client_manager(): - """测试 MCPClientManager 基本功能""" - print("\n" + "="*60) - print("测试 1: MCPClientManager 连接和工具列表") - print("="*60) - - try: - # 初始化 MCP 客户端管理器 - manager = MCPClientManager() - await manager.initialize() - - print("\n✓ MCP 客户端管理器初始化成功") - print(f"已连接服务器数量: {len(manager.clients)}") - - # 获取所有工具 - tools = await manager.get_all_tools() - print(f"\n获取到 {len(tools)} 个 MCP 工具:") - - for tool in tools: - print(f"\n 工具: {tool}") - # 注意: 这里 tool 是字符串形式的工具名称 - # 如果需要工具详情,需要从其他地方获取 - - return manager, tools - - except Exception as e: - print(f"\n✗ 测试失败: {e}") - logger.exception("MCPClientManager 测试失败") - return None, [] - - -async def test_tool_call(manager: MCPClientManager, tools): - """测试工具调用""" - print("\n" + "="*60) - print("测试 2: MCP 工具调用") - print("="*60) - - if not tools: - print("\n⚠ 没有可用的工具进行测试") - return - - 
try: - # 工具列表测试已在第一个测试中完成 - print("\n✓ 工具列表获取成功") - print(f"可用工具数量: {len(tools)}") - - except Exception as e: - print(f"\n✗ 工具调用测试失败: {e}") - logger.exception("工具调用测试失败") - - -async def test_component_registry_integration(): - """测试 ComponentRegistry 集成""" - print("\n" + "="*60) - print("测试 3: ComponentRegistry MCP 工具集成") - print("="*60) - - try: - registry = ComponentRegistry() - - # 加载 MCP 工具 - await registry.load_mcp_tools() - - # 获取 MCP 工具 - mcp_tools = registry.get_mcp_tools() - print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具") - - for tool in mcp_tools: - print(f"\n 工具: {tool.name}") - print(f" 描述: {tool.description}") - print(f" 参数数量: {len(tool.parameters)}") - - # 测试 is_mcp_tool 方法 - is_mcp = registry.is_mcp_tool(tool.name) - print(f" is_mcp_tool 检测: {'✓' if is_mcp else '✗'}") - - return mcp_tools - - except Exception as e: - print(f"\n✗ ComponentRegistry 集成测试失败: {e}") - logger.exception("ComponentRegistry 集成测试失败") - return [] - - -async def test_tool_execution(mcp_tools): - """测试通过适配器执行工具""" - print("\n" + "="*60) - print("测试 4: MCPToolAdapter 工具执行") - print("="*60) - - if not mcp_tools: - print("\n⚠ 没有可用的 MCP 工具进行测试") - return - - try: - # 选择第一个工具测试 - test_tool = mcp_tools[0] - print(f"\n测试工具: {test_tool.name}") - - # 构建测试参数 - test_args = {} - for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters: - if is_required: - # 根据类型提供默认值 - from src.llm_models.payload_content.tool_option import ToolParamType - - if param_type == ToolParamType.STRING: - test_args[param_name] = "test_value" - elif param_type == ToolParamType.INTEGER: - test_args[param_name] = 1 - elif param_type == ToolParamType.FLOAT: - test_args[param_name] = 1.0 - elif param_type == ToolParamType.BOOLEAN: - test_args[param_name] = True - - print(f"测试参数: {test_args}") - - # 执行工具 - result = await test_tool.execute(test_args) - - if result: - print("\n✓ 工具执行成功") - print(f"结果类型: {result.get('type')}") - print(f"结果内容: {result.get('content', '')[:200]}...") # 只显示前200字符 - else: - print("\n✗ 工具执行失败,返回 None") - - except Exception as e: - print(f"\n✗ 工具执行测试失败: {e}") - logger.exception("工具执行测试失败") - - -async def main(): - """主测试流程""" - print("\n" + "="*60) - print("MCP 集成测试") - print("="*60) - - try: - # 测试 1: MCPClientManager 基本功能 - manager, tools = await test_mcp_client_manager() - - if manager: - # 测试 2: 工具调用 - await test_tool_call(manager, tools) - - # 测试 3: ComponentRegistry 集成 - mcp_tools = await test_component_registry_integration() - - # 测试 4: 工具执行 - await test_tool_execution(mcp_tools) - - # 关闭连接 - await manager.close() - print("\n✓ MCP 客户端连接已关闭") - - print("\n" + "="*60) - print("测试完成") - print("="*60 + "\n") - - except KeyboardInterrupt: - print("\n\n测试被用户中断") - except Exception as e: - print(f"\n测试过程中发生错误: {e}") - logger.exception("测试失败") - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/scripts/test_three_tier_memory.py b/scripts/test_three_tier_memory.py deleted file mode 100644 index 951135733..000000000 --- a/scripts/test_three_tier_memory.py +++ /dev/null @@ -1,292 +0,0 @@ -""" -三层记忆系统测试脚本 -用于验证系统各组件是否正常工作 -""" - -import asyncio -import sys -from pathlib import Path - -# 添加项目根目录到路径 -project_root = Path(__file__).parent.parent -sys.path.insert(0, str(project_root)) - - -async def test_perceptual_memory(): - """测试感知记忆层""" - print("\n" + "=" * 60) - print("测试1: 感知记忆层") - print("=" * 60) - - from src.memory_graph.three_tier.perceptual_manager import get_perceptual_manager - - manager = get_perceptual_manager() - await manager.initialize() - - # 添加测试消息 - 
test_messages = [ - ("user1", "今天天气真好", 1700000000.0), - ("user2", "是啊,适合出去玩", 1700000001.0), - ("user1", "我们去公园吧", 1700000002.0), - ("user2", "好主意!", 1700000003.0), - ("user1", "带上野餐垫", 1700000004.0), - ] - - for sender, content, timestamp in test_messages: - message = { - "message_id": f"msg_{timestamp}", - "sender": sender, - "content": content, - "timestamp": timestamp, - "platform": "test", - "stream_id": "test_stream", - } - await manager.add_message(message) - - print(f"✅ 成功添加 {len(test_messages)} 条消息") - - # 测试TopK召回 - results = await manager.recall_blocks("公园野餐", top_k=2) - print(f"✅ TopK召回返回 {len(results)} 个块") - - if results: - print(f" 第一个块包含 {len(results[0].messages)} 条消息") - - # 获取统计信息 - stats = manager.get_statistics() # 不是async方法 - print(f"✅ 统计信息: {stats}") - - return True - - -async def test_short_term_memory(): - """测试短期记忆层""" - print("\n" + "=" * 60) - print("测试2: 短期记忆层") - print("=" * 60) - - from src.memory_graph.three_tier.models import MemoryBlock - from src.memory_graph.three_tier.short_term_manager import get_short_term_manager - - manager = get_short_term_manager() - await manager.initialize() - - # 创建测试块 - test_block = MemoryBlock( - id="test_block_1", - messages=[ - { - "message_id": "msg1", - "sender": "user1", - "content": "我明天要参加一个重要的面试", - "timestamp": 1700000000.0, - "platform": "test", - } - ], - combined_text="我明天要参加一个重要的面试", - recall_count=3, - ) - - # 从感知块转换为短期记忆 - try: - await manager.add_from_block(test_block) - print("✅ 成功将感知块转换为短期记忆") - except Exception as e: - print(f"⚠️ 转换失败(可能需要LLM): {e}") - return False - - # 测试搜索 - results = await manager.search_memories("面试", top_k=3) - print(f"✅ 搜索返回 {len(results)} 条记忆") - - # 获取统计 - stats = manager.get_statistics() - print(f"✅ 统计信息: {stats}") - - return True - - -async def test_long_term_memory(): - """测试长期记忆层""" - print("\n" + "=" * 60) - print("测试3: 长期记忆层") - print("=" * 60) - - from src.memory_graph.three_tier.long_term_manager import get_long_term_manager - - manager = get_long_term_manager() - await manager.initialize() - - print("✅ 长期记忆管理器初始化成功") - print(" (需要现有记忆图系统支持)") - - # 获取统计 - stats = manager.get_statistics() - print(f"✅ 统计信息: {stats}") - - return True - - -async def test_unified_manager(): - """测试统一管理器""" - print("\n" + "=" * 60) - print("测试4: 统一管理器") - print("=" * 60) - - from src.memory_graph.three_tier.unified_manager import UnifiedMemoryManager - - manager = UnifiedMemoryManager() - await manager.initialize() - - # 添加测试消息 - message = { - "message_id": "unified_test_1", - "sender": "user1", - "content": "这是一条测试消息", - "timestamp": 1700000000.0, - "platform": "test", - "stream_id": "test_stream", - } - await manager.add_message(message) - - print("✅ 通过统一接口添加消息成功") - - # 测试搜索 - results = await manager.search_memories("测试") - print(f"✅ 统一搜索返回结果:") - print(f" 感知块: {len(results.get('perceptual_blocks', []))}") - print(f" 短期记忆: {len(results.get('short_term_memories', []))}") - print(f" 长期记忆: {len(results.get('long_term_memories', []))}") - - # 获取统计 - stats = manager.get_statistics() # 不是async方法 - print(f"✅ 综合统计:") - print(f" 感知层: {stats.get('perceptual', {})}") - print(f" 短期层: {stats.get('short_term', {})}") - print(f" 长期层: {stats.get('long_term', {})}") - - return True - - -async def test_configuration(): - """测试配置加载""" - print("\n" + "=" * 60) - print("测试5: 配置系统") - print("=" * 60) - - from src.config.config import global_config - - if not hasattr(global_config, "three_tier_memory"): - print("❌ 配置类中未找到 three_tier_memory 字段") - return False - - config = global_config.three_tier_memory - - if 
config is None: - print("⚠️ 三层记忆配置为 None(可能未在 bot_config.toml 中配置)") - print(" 请在 bot_config.toml 中添加 [three_tier_memory] 配置") - return False - - print(f"✅ 配置加载成功") - print(f" 启用状态: {config.enable}") - print(f" 数据目录: {config.data_dir}") - print(f" 感知层最大块数: {config.perceptual_max_blocks}") - print(f" 短期层最大记忆数: {config.short_term_max_memories}") - print(f" 激活阈值: {config.activation_threshold}") - - return True - - -async def test_integration(): - """测试系统集成""" - print("\n" + "=" * 60) - print("测试6: 系统集成") - print("=" * 60) - - # 首先需要确保配置启用 - from src.config.config import global_config - - if not global_config.three_tier_memory or not global_config.three_tier_memory.enable: - print("⚠️ 配置未启用,跳过集成测试") - return False - - # 测试单例模式 - from src.memory_graph.three_tier.manager_singleton import ( - get_unified_memory_manager, - initialize_unified_memory_manager, - ) - - # 初始化 - await initialize_unified_memory_manager() - manager = get_unified_memory_manager() - - if manager is None: - print("❌ 统一管理器初始化失败") - return False - - print("✅ 单例模式正常工作") - - # 测试多次获取 - manager2 = get_unified_memory_manager() - if manager is not manager2: - print("❌ 单例模式失败(返回不同实例)") - return False - - print("✅ 单例一致性验证通过") - - return True - - -async def run_all_tests(): - """运行所有测试""" - print("\n" + "🔬" * 30) - print("三层记忆系统集成测试") - print("🔬" * 30) - - tests = [ - ("配置系统", test_configuration), - ("感知记忆层", test_perceptual_memory), - ("短期记忆层", test_short_term_memory), - ("长期记忆层", test_long_term_memory), - ("统一管理器", test_unified_manager), - ("系统集成", test_integration), - ] - - results = [] - - for name, test_func in tests: - try: - result = await test_func() - results.append((name, result)) - except Exception as e: - print(f"\n❌ 测试 {name} 失败: {e}") - import traceback - - traceback.print_exc() - results.append((name, False)) - - # 打印测试总结 - print("\n" + "=" * 60) - print("测试总结") - print("=" * 60) - - passed = sum(1 for _, result in results if result) - total = len(results) - - for name, result in results: - status = "✅ 通过" if result else "❌ 失败" - print(f"{status} - {name}") - - print(f"\n总计: {passed}/{total} 测试通过") - - if passed == total: - print("\n🎉 所有测试通过!三层记忆系统工作正常。") - else: - print("\n⚠️ 部分测试失败,请查看上方详细信息。") - - return passed == total - - -if __name__ == "__main__": - success = asyncio.run(run_all_tests()) - sys.exit(0 if success else 1) diff --git a/scripts/update_database_imports.py b/scripts/update_database_imports.py deleted file mode 100644 index 15736e641..000000000 --- a/scripts/update_database_imports.py +++ /dev/null @@ -1,185 +0,0 @@ -"""批量更新数据库导入语句的脚本 - -将旧的数据库导入路径更新为新的重构后的路径: -- sqlalchemy_models -> core, core.models -- sqlalchemy_database_api -> compatibility -- database.database -> core -""" - -import re -from pathlib import Path - -# 定义导入映射规则 -IMPORT_MAPPINGS = { - # 模型导入 - r"from src\.common\.database\.sqlalchemy_models import (.+)": - r"from src.common.database.core.models import \1", - - # API导入 - 需要特殊处理 - r"from src\.common\.database\.sqlalchemy_database_api import (.+)": - r"from src.common.database.compatibility import \1", - - # get_db_session 从 sqlalchemy_database_api 导入 - r"from src\.common\.database\.sqlalchemy_database_api import get_db_session": - r"from src.common.database.core import get_db_session", - - # get_db_session 从 sqlalchemy_models 导入 - r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)": - lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}" - if "get_db_session" in m.group(0) else m.group(0), - - # get_engine 导入 - r"from 
src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)": - lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}", - - # Base 导入 - r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)": - lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}", - - # initialize_database 导入 - r"from src\.common\.database\.sqlalchemy_models import initialize_database": - r"from src.common.database.core import check_and_migrate_database as initialize_database", - - # database.py 导入 - r"from src\.common\.database\.database import stop_database": - r"from src.common.database.core import close_engine as stop_database", - - r"from src\.common\.database\.database import initialize_sql_database": - r"from src.common.database.core import check_and_migrate_database as initialize_sql_database", -} - -# 需要排除的文件 -EXCLUDE_PATTERNS = [ - "**/database_refactoring_plan.md", # 文档文件 - "**/old/**", # 旧文件目录 - "**/sqlalchemy_*.py", # 旧的数据库文件本身 - "**/database.py", # 旧的database文件 - "**/db_*.py", # 旧的db文件 -] - - -def should_exclude(file_path: Path) -> bool: - """检查文件是否应该被排除""" - for pattern in EXCLUDE_PATTERNS: - if file_path.match(pattern): - return True - return False - - -def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]: - """更新单个文件中的导入语句 - - Args: - file_path: 文件路径 - dry_run: 是否只是预览而不实际修改 - - Returns: - (修改次数, 修改详情列表) - """ - try: - content = file_path.read_text(encoding="utf-8") - original_content = content - changes = [] - - # 应用每个映射规则 - for pattern, replacement in IMPORT_MAPPINGS.items(): - matches = list(re.finditer(pattern, content)) - for match in matches: - old_line = match.group(0) - - # 处理函数类型的替换 - if callable(replacement): - new_line_result = replacement(match) - new_line = new_line_result if isinstance(new_line_result, str) else old_line - else: - new_line = re.sub(pattern, replacement, old_line) - - if old_line != new_line and isinstance(new_line, str): - content = content.replace(old_line, new_line, 1) - changes.append(f" - {old_line}") - changes.append(f" + {new_line}") - - # 如果有修改且不是dry_run,写回文件 - if content != original_content: - if not dry_run: - file_path.write_text(content, encoding="utf-8") - return len(changes) // 2, changes - - return 0, [] - - except Exception as e: - print(f"❌ 处理文件 {file_path} 时出错: {e}") - return 0, [] - - -def main(): - """主函数""" - print("🔍 搜索需要更新导入的文件...") - - # 获取项目根目录 - root_dir = Path(__file__).parent.parent - - # 搜索所有Python文件 - all_python_files = list(root_dir.rglob("*.py")) - - # 过滤掉排除的文件 - target_files = [f for f in all_python_files if not should_exclude(f)] - - print(f"📊 找到 {len(target_files)} 个Python文件需要检查") - print("\n" + "="*80) - - # 第一遍:预览模式 - print("\n🔍 预览模式 - 检查需要更新的文件...\n") - - files_to_update = [] - for file_path in target_files: - count, changes = update_imports_in_file(file_path, dry_run=True) - if count > 0: - files_to_update.append((file_path, count, changes)) - - if not files_to_update: - print("✅ 没有文件需要更新!") - return - - print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n") - - total_changes = 0 - for file_path, count, changes in files_to_update: - rel_path = file_path.relative_to(root_dir) - print(f"\n📄 {rel_path} ({count} 处修改)") - for change in changes[:10]: # 最多显示前5对修改 - print(change) - if len(changes) > 10: - print(f" ... 
还有 {len(changes) - 10} 行") - total_changes += count - - print("\n" + "="*80) - print("\n📊 统计:") - print(f" - 需要更新的文件: {len(files_to_update)}") - print(f" - 总修改次数: {total_changes}") - - # 询问是否继续 - print("\n" + "="*80) - response = input("\n是否执行更新?(yes/no): ").strip().lower() - - if response != "yes": - print("❌ 已取消更新") - return - - # 第二遍:实际更新 - print("\n✨ 开始更新文件...\n") - - success_count = 0 - for file_path, _, _ in files_to_update: - count, _ = update_imports_in_file(file_path, dry_run=False) - if count > 0: - rel_path = file_path.relative_to(root_dir) - print(f"✅ {rel_path} ({count} 处修改)") - success_count += 1 - - print("\n" + "="*80) - print(f"\n🎉 完成!成功更新 {success_count} 个文件") - - -if __name__ == "__main__": - main() diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index adea3a607..0d379c01a 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -238,6 +238,14 @@ class BatchDatabaseWriter: stmt = stmt.on_duplicate_key_update( **{key: value for key, value in update_data.items() if key != "stream_id"} ) + elif global_config.database.database_type == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update( + index_elements=[ChatStreams.stream_id], + set_=update_data + ) else: # 默认使用SQLite语法 from sqlalchemy.dialects.sqlite import insert as sqlite_insert @@ -264,6 +272,14 @@ class BatchDatabaseWriter: stmt = stmt.on_duplicate_key_update( **{key: value for key, value in update_data.items() if key != "stream_id"} ) + elif global_config.database.database_type == "postgresql": + from sqlalchemy.dialects.postgresql import insert as pg_insert + + stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data) + stmt = stmt.on_conflict_do_update( + index_elements=[ChatStreams.stream_id], + set_=update_data + ) else: from sqlalchemy.dialects.sqlite import insert as sqlite_insert diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 9a624ffda..de04fbc7e 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -4,6 +4,7 @@ import time from rich.traceback import install from sqlalchemy.dialects.mysql import insert as mysql_insert +from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.sqlite import insert as sqlite_insert from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo @@ -663,6 +664,13 @@ class ChatManager: stmt = stmt.on_duplicate_key_update( **{key: value for key, value in fields_to_save.items() if key != "stream_id"} ) + elif global_config.database.database_type == "postgresql": + stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) + # PostgreSQL 需要使用 constraint 参数或正确的 index_elements + stmt = stmt.on_conflict_do_update( + index_elements=[ChatStreams.stream_id], + set_=fields_to_save + ) else: stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save) stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save) diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 1f8b8766e..7836c8423 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -142,8 +142,11 @@ class MessageStorageBatcher: return 
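Note: the same three-way MySQL/PostgreSQL/SQLite upsert branch now appears in both `batch_database_writer.py` and `chat_stream.py`. A possible consolidation, sketched here purely as an illustration (the helper name is hypothetical and not part of this diff):

```python
from src.common.database.core.models import ChatStreams
from src.config.config import global_config


def build_chat_stream_upsert(stream_id: str, update_data: dict):
    """Dialect-dispatched INSERT ... ON CONFLICT / ON DUPLICATE KEY (sketch)."""
    db_type = global_config.database.database_type
    if db_type == "mysql":
        from sqlalchemy.dialects.mysql import insert as mysql_insert

        stmt = mysql_insert(ChatStreams).values(stream_id=stream_id, **update_data)
        return stmt.on_duplicate_key_update(
            **{k: v for k, v in update_data.items() if k != "stream_id"}
        )
    if db_type == "postgresql":
        from sqlalchemy.dialects.postgresql import insert as pg_insert

        stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
        return stmt.on_conflict_do_update(
            index_elements=[ChatStreams.stream_id], set_=update_data
        )
    from sqlalchemy.dialects.sqlite import insert as sqlite_insert

    stmt = sqlite_insert(ChatStreams).values(stream_id=stream_id, **update_data)
    return stmt.on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
```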
None # 将ORM对象转换为字典(只包含列字段) + # 排除 id 字段,让数据库自动生成(对于 PostgreSQL SERIAL 类型尤其重要) message_dict = {} for column in Messages.__table__.columns: + if column.name == "id": + continue # 跳过自增主键,让数据库自动生成 message_dict[column.name] = getattr(message_obj, column.name) return message_dict diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 6701665d8..35234c352 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -1143,7 +1143,6 @@ class Prompt: Returns: str: 构建好的跨群聊上下文字符串。 """ - logger.info(f"Building cross context with target_user_info: {target_user_info}") if not global_config.cross_context.enable: return "" diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py index e1df6fd2e..73fd5335c 100644 --- a/src/common/database/connection_pool_manager.py +++ b/src/common/database/connection_pool_manager.py @@ -169,6 +169,11 @@ class ConnectionPoolManager: self, session_factory: async_sessionmaker[AsyncSession] ) -> ConnectionInfo | None: """获取可复用的连接""" + # 导入方言适配器获取 ping 查询 + from src.common.database.core.dialect_adapter import DialectAdapter + + ping_query = DialectAdapter.get_ping_query() + async with self._lock: # 清理过期连接 await self._cleanup_expired_connections_locked() @@ -178,8 +183,8 @@ class ConnectionPoolManager: if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle): # 验证连接是否仍然有效 try: - # 执行一个简单的查询来验证连接 - await connection_info.session.execute(text("SELECT 1")) + # 执行 ping 查询来验证连接 + await connection_info.session.execute(text(ping_query)) return connection_info except Exception as e: logger.debug(f"连接验证失败,将移除: {e}") diff --git a/src/common/database/core/__init__.py b/src/common/database/core/__init__.py index 8f83149db..2e457f37a 100644 --- a/src/common/database/core/__init__.py +++ b/src/common/database/core/__init__.py @@ -5,8 +5,22 @@ - 会话管理 - 模型定义 - 数据库迁移 +- 方言适配 + +支持的数据库: +- SQLite (默认) +- MySQL +- PostgreSQL """ +from .dialect_adapter import ( + DatabaseDialect, + DialectAdapter, + DialectConfig, + get_dialect_adapter, + get_indexed_string_field, + get_text_field, +) from .engine import close_engine, get_engine, get_engine_info from .migration import check_and_migrate_database, create_all_tables, drop_all_tables from .models import ( @@ -50,6 +64,10 @@ __all__ = [ "BotPersonalityInterests", "CacheEntries", "ChatStreams", + # Dialect Adapter + "DatabaseDialect", + "DialectAdapter", + "DialectConfig", "Emoji", "Expression", "GraphEdges", @@ -77,10 +95,13 @@ __all__ = [ # Session "get_db_session", "get_db_session_direct", + "get_dialect_adapter", # Engine "get_engine", "get_engine_info", + "get_indexed_string_field", "get_session_factory", "get_string_field", + "get_text_field", "reset_session_factory", ] diff --git a/src/common/database/core/dialect_adapter.py b/src/common/database/core/dialect_adapter.py new file mode 100644 index 000000000..e99eb47ae --- /dev/null +++ b/src/common/database/core/dialect_adapter.py @@ -0,0 +1,230 @@ +"""数据库方言适配器 + +提供跨数据库兼容性支持,处理不同数据库之间的差异: +- SQLite: 轻量级本地数据库 +- MySQL: 高性能关系型数据库 +- PostgreSQL: 功能丰富的开源数据库 + +主要职责: +1. 提供数据库特定的类型映射 +2. 处理方言特定的查询语法 +3. 
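Note: with the hard-coded `SELECT 1` replaced by `DialectAdapter.get_ping_query()`, connection validation stays backend-agnostic. The pattern reduced to a sketch (the helper name is hypothetical):

```python
from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncSession

from src.common.database.core.dialect_adapter import DialectAdapter


async def session_is_alive(session: AsyncSession) -> bool:
    """Cheap liveness probe using the dialect-specific ping query."""
    try:
        await session.execute(text(DialectAdapter.get_ping_query()))
        return True
    except Exception:
        return False
```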
提供数据库特定的优化配置 +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from sqlalchemy import String, Text +from sqlalchemy.types import TypeEngine + + +class DatabaseDialect(Enum): + """数据库方言枚举""" + + SQLITE = "sqlite" + MYSQL = "mysql" + POSTGRESQL = "postgresql" + + +@dataclass +class DialectConfig: + """方言配置""" + + dialect: DatabaseDialect + # 连接验证查询 + ping_query: str + # 是否支持 RETURNING 子句 + supports_returning: bool + # 是否支持原生 JSON 类型 + supports_native_json: bool + # 是否支持数组类型 + supports_arrays: bool + # 是否需要指定字符串长度用于索引 + requires_length_for_index: bool + # 默认字符串长度(用于索引列) + default_string_length: int + # 事务隔离级别 + isolation_level: str + # 额外的引擎参数 + engine_kwargs: dict[str, Any] = field(default_factory=dict) + + +# 预定义的方言配置 +DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = { + DatabaseDialect.SQLITE: DialectConfig( + dialect=DatabaseDialect.SQLITE, + ping_query="SELECT 1", + supports_returning=True, # SQLite 3.35+ 支持 + supports_native_json=False, + supports_arrays=False, + requires_length_for_index=False, + default_string_length=255, + isolation_level="SERIALIZABLE", + engine_kwargs={ + "connect_args": { + "check_same_thread": False, + "timeout": 60, + } + }, + ), + DatabaseDialect.MYSQL: DialectConfig( + dialect=DatabaseDialect.MYSQL, + ping_query="SELECT 1", + supports_returning=False, # MySQL 8.0.21+ 有限支持 + supports_native_json=True, # MySQL 5.7+ + supports_arrays=False, + requires_length_for_index=True, # MySQL 索引需要指定长度 + default_string_length=255, + isolation_level="READ COMMITTED", + engine_kwargs={ + "pool_pre_ping": True, + "pool_recycle": 3600, + }, + ), + DatabaseDialect.POSTGRESQL: DialectConfig( + dialect=DatabaseDialect.POSTGRESQL, + ping_query="SELECT 1", + supports_returning=True, + supports_native_json=True, + supports_arrays=True, + requires_length_for_index=False, + default_string_length=255, + isolation_level="READ COMMITTED", + engine_kwargs={ + "pool_pre_ping": True, + "pool_recycle": 3600, + }, + ), +} + + +class DialectAdapter: + """数据库方言适配器 + + 根据当前配置的数据库类型,提供相应的类型映射和查询支持 + """ + + _current_dialect: DatabaseDialect | None = None + _config: DialectConfig | None = None + + @classmethod + def initialize(cls, db_type: str) -> None: + """初始化适配器 + + Args: + db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql") + """ + try: + cls._current_dialect = DatabaseDialect(db_type.lower()) + cls._config = DIALECT_CONFIGS[cls._current_dialect] + except ValueError: + raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql") + + @classmethod + def get_dialect(cls) -> DatabaseDialect: + """获取当前数据库方言""" + if cls._current_dialect is None: + # 延迟初始化:从配置获取 + from src.config.config import global_config + + if global_config is None: + raise RuntimeError("配置尚未初始化,无法获取数据库方言") + cls.initialize(global_config.database.database_type) + return cls._current_dialect # type: ignore + + @classmethod + def get_config(cls) -> DialectConfig: + """获取当前方言配置""" + if cls._config is None: + cls.get_dialect() # 触发初始化 + return cls._config # type: ignore + + @classmethod + def get_string_type(cls, max_length: int = 255, indexed: bool = False) -> TypeEngine: + """获取适合当前数据库的字符串类型 + + Args: + max_length: 最大长度 + indexed: 是否用于索引 + + Returns: + SQLAlchemy 类型 + """ + config = cls.get_config() + + # MySQL 索引列需要指定长度 + if config.requires_length_for_index and indexed: + return String(max_length) + + # SQLite 和 PostgreSQL 可以使用 Text + if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL): + return Text() if not indexed else String(max_length) 
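Note: the adapter is initialized once (from `get_engine()`), after which per-backend capabilities are looked up in `DIALECT_CONFIGS`. For example, under PostgreSQL:

```python
from src.common.database.core.dialect_adapter import DatabaseDialect, DialectAdapter

DialectAdapter.initialize("postgresql")
cfg = DialectAdapter.get_config()

assert DialectAdapter.get_dialect() is DatabaseDialect.POSTGRESQL
assert cfg.ping_query == "SELECT 1"
assert cfg.supports_returning and cfg.supports_native_json and cfg.supports_arrays
assert not cfg.requires_length_for_index  # unlike MySQL
```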
+ + # MySQL 使用 VARCHAR + return String(max_length) + + @classmethod + def get_ping_query(cls) -> str: + """获取连接验证查询""" + return cls.get_config().ping_query + + @classmethod + def supports_returning(cls) -> bool: + """是否支持 RETURNING 子句""" + return cls.get_config().supports_returning + + @classmethod + def supports_native_json(cls) -> bool: + """是否支持原生 JSON 类型""" + return cls.get_config().supports_native_json + + @classmethod + def get_engine_kwargs(cls) -> dict[str, Any]: + """获取引擎额外参数""" + return cls.get_config().engine_kwargs.copy() + + @classmethod + def is_sqlite(cls) -> bool: + """是否为 SQLite""" + return cls.get_dialect() == DatabaseDialect.SQLITE + + @classmethod + def is_mysql(cls) -> bool: + """是否为 MySQL""" + return cls.get_dialect() == DatabaseDialect.MYSQL + + @classmethod + def is_postgresql(cls) -> bool: + """是否为 PostgreSQL""" + return cls.get_dialect() == DatabaseDialect.POSTGRESQL + + +def get_dialect_adapter() -> type[DialectAdapter]: + """获取方言适配器类""" + return DialectAdapter + + +def get_indexed_string_field(max_length: int = 255) -> TypeEngine: + """获取用于索引的字符串字段类型 + + 这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型 + + Args: + max_length: 最大长度(对于 MySQL 是必需的) + + Returns: + SQLAlchemy 类型 + """ + return DialectAdapter.get_string_type(max_length, indexed=True) + + +def get_text_field() -> TypeEngine: + """获取文本字段类型 + + 用于不需要索引的大文本字段 + + Returns: + SQLAlchemy Text 类型 + """ + return Text() diff --git a/src/common/database/core/engine.py b/src/common/database/core/engine.py index 131587b2b..a235449b0 100644 --- a/src/common/database/core/engine.py +++ b/src/common/database/core/engine.py @@ -1,6 +1,11 @@ """数据库引擎管理 单一职责:创建和管理SQLAlchemy异步引擎 + +支持的数据库类型: +- SQLite: 轻量级本地数据库,使用 aiosqlite 驱动 +- MySQL: 高性能关系型数据库,使用 aiomysql 驱动 +- PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动 """ import asyncio @@ -13,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine from src.common.logger import get_logger from ..utils.exceptions import DatabaseInitializationError +from .dialect_adapter import DialectAdapter logger = get_logger("database.engine") @@ -52,79 +58,27 @@ async def get_engine() -> AsyncEngine: config = global_config.database db_type = config.database_type + # 初始化方言适配器 + DialectAdapter.initialize(db_type) + logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...") - # 构建数据库URL和引擎参数 + # 根据数据库类型构建URL和引擎参数 if db_type == "mysql": - # MySQL配置 - encoded_user = quote_plus(config.mysql_user) - encoded_password = quote_plus(config.mysql_password) - - if config.mysql_unix_socket: - # Unix socket连接 - encoded_socket = quote_plus(config.mysql_unix_socket) - url = ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@/{config.mysql_database}" - f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" - ) - else: - # TCP连接 - url = ( - f"mysql+aiomysql://{encoded_user}:{encoded_password}" - f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - f"?charset={config.mysql_charset}" - ) - - engine_kwargs = { - "echo": False, - "future": True, - "pool_size": config.connection_pool_size, - "max_overflow": config.connection_pool_size * 2, - "pool_timeout": config.connection_timeout, - "pool_recycle": 3600, - "pool_pre_ping": True, - "connect_args": { - "autocommit": config.mysql_autocommit, - "charset": config.mysql_charset, - "connect_timeout": config.connection_timeout, - }, - } - - logger.info( - f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" - ) - + url, engine_kwargs = _build_mysql_config(config) + elif db_type == "postgresql": + url, 
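Note: `get_indexed_string_field()` and `get_text_field()` are intended for model definitions. A sketch with a hypothetical table; with `indexed=True`, `get_string_type()` resolves to `String(max_length)` on all three backends, while unindexed text stays `Text`:

```python
from sqlalchemy.orm import Mapped, mapped_column

from src.common.database.core.dialect_adapter import (
    get_indexed_string_field,
    get_text_field,
)
from src.common.database.core.models import Base


class ExampleRecord(Base):  # hypothetical table, for illustration only
    __tablename__ = "example_records"

    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    # Indexed lookup key: VARCHAR(64) on SQLite, MySQL and PostgreSQL
    lookup_key: Mapped[str] = mapped_column(get_indexed_string_field(64), index=True)
    # Large unindexed payload: plain TEXT on every backend
    payload: Mapped[str] = mapped_column(get_text_field())
```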
engine_kwargs = _build_postgresql_config(config) else: - # SQLite配置 - if not os.path.isabs(config.sqlite_path): - ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) - db_path = os.path.join(ROOT_PATH, config.sqlite_path) - else: - db_path = config.sqlite_path - - # 确保数据库目录存在 - os.makedirs(os.path.dirname(db_path), exist_ok=True) - - url = f"sqlite+aiosqlite:///{db_path}" - - engine_kwargs = { - "echo": False, - "future": True, - "connect_args": { - "check_same_thread": False, - "timeout": 60, - }, - } - - logger.info(f"SQLite配置: {db_path}") + url, engine_kwargs = _build_sqlite_config(config) # 创建异步引擎 _engine = create_async_engine(url, **engine_kwargs) - # SQLite特定优化 + # 数据库特定优化 if db_type == "sqlite": await _enable_sqlite_optimizations(_engine) + elif db_type == "postgresql": + await _enable_postgresql_optimizations(_engine) logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功") return _engine @@ -134,6 +88,141 @@ async def get_engine() -> AsyncEngine: raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e +def _build_sqlite_config(config) -> tuple[str, dict]: + """构建 SQLite 配置 + + Args: + config: 数据库配置对象 + + Returns: + (url, engine_kwargs) 元组 + """ + if not os.path.isabs(config.sqlite_path): + ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..")) + db_path = os.path.join(ROOT_PATH, config.sqlite_path) + else: + db_path = config.sqlite_path + + # 确保数据库目录存在 + os.makedirs(os.path.dirname(db_path), exist_ok=True) + + url = f"sqlite+aiosqlite:///{db_path}" + + engine_kwargs = { + "echo": False, + "future": True, + "connect_args": { + "check_same_thread": False, + "timeout": 60, + }, + } + + logger.info(f"SQLite配置: {db_path}") + return url, engine_kwargs + + +def _build_mysql_config(config) -> tuple[str, dict]: + """构建 MySQL 配置 + + Args: + config: 数据库配置对象 + + Returns: + (url, engine_kwargs) 元组 + """ + encoded_user = quote_plus(config.mysql_user) + encoded_password = quote_plus(config.mysql_password) + + if config.mysql_unix_socket: + # Unix socket连接 + encoded_socket = quote_plus(config.mysql_unix_socket) + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@/{config.mysql_database}" + f"?unix_socket={encoded_socket}&charset={config.mysql_charset}" + ) + else: + # TCP连接 + url = ( + f"mysql+aiomysql://{encoded_user}:{encoded_password}" + f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + f"?charset={config.mysql_charset}" + ) + + engine_kwargs = { + "echo": False, + "future": True, + "pool_size": config.connection_pool_size, + "max_overflow": config.connection_pool_size * 2, + "pool_timeout": config.connection_timeout, + "pool_recycle": 3600, + "pool_pre_ping": True, + "connect_args": { + "autocommit": config.mysql_autocommit, + "charset": config.mysql_charset, + "connect_timeout": config.connection_timeout, + }, + } + + logger.info( + f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}" + ) + return url, engine_kwargs + + +def _build_postgresql_config(config) -> tuple[str, dict]: + """构建 PostgreSQL 配置 + + Args: + config: 数据库配置对象 + + Returns: + (url, engine_kwargs) 元组 + """ + encoded_user = quote_plus(config.postgresql_user) + encoded_password = quote_plus(config.postgresql_password) + + # 构建基本 URL + url = ( + f"postgresql+asyncpg://{encoded_user}:{encoded_password}" + f"@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}" + ) + + # SSL 配置 + connect_args = {} + if config.postgresql_ssl_mode != "disable": + ssl_config 
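Note: `quote_plus()` matters whenever credentials contain URL metacharacters; without it, a password such as `p@ss:word` would corrupt the DSN:

```python
from urllib.parse import quote_plus

user, password = "postgres", "p@ss:word"
url = (
    f"postgresql+asyncpg://{quote_plus(user)}:{quote_plus(password)}"
    f"@localhost:5432/maibot"
)
# -> postgresql+asyncpg://postgres:p%40ss%3Aword@localhost:5432/maibot
```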
= {"ssl": config.postgresql_ssl_mode} + if config.postgresql_ssl_ca: + ssl_config["ssl_ca"] = config.postgresql_ssl_ca + if config.postgresql_ssl_cert: + ssl_config["ssl_cert"] = config.postgresql_ssl_cert + if config.postgresql_ssl_key: + ssl_config["ssl_key"] = config.postgresql_ssl_key + connect_args.update(ssl_config) + + # 设置 schema(如果不是 public) + if config.postgresql_schema and config.postgresql_schema != "public": + connect_args["server_settings"] = {"search_path": config.postgresql_schema} + + engine_kwargs = { + "echo": False, + "future": True, + "pool_size": config.connection_pool_size, + "max_overflow": config.connection_pool_size * 2, + "pool_timeout": config.connection_timeout, + "pool_recycle": 3600, + "pool_pre_ping": True, + } + + if connect_args: + engine_kwargs["connect_args"] = connect_args + + logger.info( + f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}" + ) + return url, engine_kwargs + + async def close_engine(): """关闭数据库引擎 @@ -181,6 +270,33 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine): logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置") +async def _enable_postgresql_optimizations(engine: AsyncEngine): + """启用PostgreSQL性能优化 + + 优化项: + - 设置合适的 work_mem + - 启用 JIT 编译(如果可用) + - 设置合适的 statement_timeout + + Args: + engine: SQLAlchemy异步引擎 + """ + try: + async with engine.begin() as conn: + # 设置会话级别的参数 + # work_mem: 排序和哈希操作的内存(64MB) + await conn.execute(text("SET work_mem = '64MB'")) + # 设置语句超时(5分钟) + await conn.execute(text("SET statement_timeout = '300000'")) + # 启用自动 EXPLAIN(可选,用于调试) + # await conn.execute(text("SET auto_explain.log_min_duration = '1000'")) + + logger.info("✅ PostgreSQL性能优化已启用") + + except Exception as e: + logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置") + + async def get_engine_info() -> dict: """获取引擎信息(用于监控和调试) diff --git a/src/common/database/core/migration.py b/src/common/database/core/migration.py index 35d25e05e..b86c69e53 100644 --- a/src/common/database/core/migration.py +++ b/src/common/database/core/migration.py @@ -99,12 +99,17 @@ async def check_and_migrate_database(existing_engine=None): def add_columns_sync(conn): dialect = conn.dialect - compiler = dialect.ddl_compiler(dialect, None) - + for column_name in missing_columns: column = table.c[column_name] - column_type = compiler.get_column_specification(column) - sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}" + + # 获取列类型的 SQL 表示 + # 使用 compile 方法获取正确的类型字符串 + type_compiler = dialect.type_compiler(dialect) + column_type_sql = column.type.compile(dialect=dialect) + + # 构建 ALTER TABLE 语句 + sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type_sql}" if column.default: # 手动处理不同方言的默认值 @@ -114,26 +119,18 @@ async def check_and_migrate_database(existing_engine=None): ): # SQLite 将布尔值存储为 0 或 1 default_value = "1" if default_arg else "0" - elif hasattr(compiler, "render_literal_value"): - try: - # 尝试使用 render_literal_value - default_value = compiler.render_literal_value( - default_arg, column.type - ) - except AttributeError: - # 如果失败,则回退到简单的字符串转换 - default_value = ( - f"'{default_arg}'" - if isinstance(default_arg, str) - else str(default_arg) - ) + elif dialect.name == "mysql" and isinstance(default_arg, bool): + # MySQL 也使用 1/0 表示布尔值 + default_value = "1" if default_arg else "0" + elif isinstance(default_arg, bool): + # PostgreSQL 使用 TRUE/FALSE + default_value = "TRUE" if default_arg else "FALSE" + elif isinstance(default_arg, str): + default_value = 
f"'{default_arg}'" + elif default_arg is None: + default_value = "NULL" else: - # 对于没有 render_literal_value 的旧版或特定方言 - default_value = ( - f"'{default_arg}'" - if isinstance(default_arg, str) - else str(default_arg) - ) + default_value = str(default_arg) sql += f" DEFAULT {default_value}" diff --git a/src/common/database/core/models.py b/src/common/database/core/models.py index 202eb9dbb..89d5d6f68 100644 --- a/src/common/database/core/models.py +++ b/src/common/database/core/models.py @@ -3,6 +3,11 @@ 本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。 引擎和会话管理已移至core/engine.py和core/session.py。 +支持的数据库类型: +- SQLite: 使用 Text 类型 +- MySQL: 使用 VARCHAR(max_length) 用于索引字段 +- PostgreSQL: 使用 Text 类型(PostgreSQL 的 Text 类型性能与 VARCHAR 相当) + 所有模型使用统一的类型注解风格: field_name: Mapped[PyType] = mapped_column(Type, ...) @@ -20,16 +25,34 @@ from sqlalchemy.orm import Mapped, mapped_column Base = declarative_base() -# MySQL兼容的字段类型辅助函数 +# 数据库兼容的字段类型辅助函数 def get_string_field(max_length=255, **kwargs): """ - 根据数据库类型返回合适的字符串字段 - MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text + 根据数据库类型返回合适的字符串字段类型 + + 对于需要索引的字段: + - MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度 + - PostgreSQL: 可以使用 Text,但为了兼容性使用 VARCHAR + - SQLite: 可以使用 Text,无长度限制 + + Args: + max_length: 最大长度(对于 MySQL 是必需的) + **kwargs: 传递给 String/Text 的额外参数 + + Returns: + SQLAlchemy 类型 """ from src.config.config import global_config - if global_config.database.database_type == "mysql": + db_type = global_config.database.database_type + + # MySQL 索引需要指定长度的 VARCHAR + if db_type == "mysql": return String(max_length, **kwargs) + # PostgreSQL 可以使用 Text,但为了跨数据库迁移兼容性,使用 VARCHAR + elif db_type == "postgresql": + return String(max_length, **kwargs) + # SQLite 使用 Text(无长度限制) else: return Text(**kwargs) @@ -477,7 +500,7 @@ class BanUser(Base): __tablename__ = "ban_users" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - platform: Mapped[str] = mapped_column(Text, nullable=False) + platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False) # 使用有限长度,以便创建索引 user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True) violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True) reason: Mapped[str] = mapped_column(Text, nullable=False) diff --git a/src/common/database/core/session.py b/src/common/database/core/session.py index 90e3f634c..b033088f9 100644 --- a/src/common/database/core/session.py +++ b/src/common/database/core/session.py @@ -1,6 +1,11 @@ """数据库会话管理 单一职责:提供数据库会话工厂和上下文管理器 + +支持的数据库类型: +- SQLite: 设置 PRAGMA 参数优化并发 +- MySQL: 无特殊会话设置 +- PostgreSQL: 可选设置 schema 搜索路径 """ import asyncio @@ -53,12 +58,43 @@ async def get_session_factory() -> async_sessionmaker: return _session_factory +async def _apply_session_settings(session: AsyncSession, db_type: str) -> None: + """应用数据库特定的会话设置 + + Args: + session: 数据库会话 + db_type: 数据库类型 + """ + try: + if db_type == "sqlite": + # SQLite 特定的 PRAGMA 设置 + await session.execute(text("PRAGMA busy_timeout = 60000")) + await session.execute(text("PRAGMA foreign_keys = ON")) + elif db_type == "postgresql": + # PostgreSQL 特定设置(如果需要) + # 可以设置 schema 搜索路径等 + from src.config.config import global_config + + schema = global_config.database.postgresql_schema + if schema and schema != "public": + await session.execute(text(f"SET search_path TO {schema}")) + # MySQL 通常不需要会话级别的特殊设置 + except Exception: + # 复用连接时设置可能已存在,忽略错误 + pass + + @asynccontextmanager async def get_db_session() -> AsyncGenerator[AsyncSession, None]: """获取数据库会话上下文管理器 这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。 + 
支持的数据库: + - SQLite: 自动设置 busy_timeout 和外键约束 + - MySQL: 直接使用,无特殊设置 + - PostgreSQL: 支持自定义 schema + 使用示例: async with get_db_session() as session: result = await session.execute(select(User)) @@ -75,16 +111,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]: # 使用连接池管理器(透明复用连接) async with pool_manager.get_session(session_factory) as session: - # 为SQLite设置特定的PRAGMA + # 获取数据库类型并应用特定设置 from src.config.config import global_config - if global_config.database.database_type == "sqlite": - try: - await session.execute(text("PRAGMA busy_timeout = 60000")) - await session.execute(text("PRAGMA foreign_keys = ON")) - except Exception: - # 复用连接时PRAGMA可能已设置,忽略错误 - pass + await _apply_session_settings(session, global_config.database.database_type) yield session @@ -103,6 +133,11 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]: async with session_factory() as session: try: + # 应用数据库特定设置 + from src.config.config import global_config + + await _apply_session_settings(session, global_config.database.database_type) + yield session except Exception: await session.rollback() diff --git a/src/common/database/optimization/batch_scheduler.py b/src/common/database/optimization/batch_scheduler.py index 08c272172..821d0b1af 100644 --- a/src/common/database/optimization/batch_scheduler.py +++ b/src/common/database/optimization/batch_scheduler.py @@ -373,8 +373,14 @@ class AdaptiveBatchScheduler: """批量执行插入操作""" async with get_db_session() as session: try: - # 收集数据 - all_data = [op.data for op in operations if op.data] + # 收集数据,并过滤掉 id=None 的情况(让数据库自动生成) + all_data = [] + for op in operations: + if op.data: + # 过滤掉 id 为 None 的键,让数据库自动生成主键 + filtered_data = {k: v for k, v in op.data.items() if not (k == "id" and v is None)} + all_data.append(filtered_data) + if not all_data: return diff --git a/src/config/official_configs.py b/src/config/official_configs.py index fd0fe50f6..795d5751d 100644 --- a/src/config/official_configs.py +++ b/src/config/official_configs.py @@ -16,8 +16,10 @@ from src.config.config_base import ValidatedConfigBase class DatabaseConfig(ValidatedConfigBase): """数据库配置类""" - database_type: Literal["sqlite", "mysql"] = Field(default="sqlite", description="数据库类型") + database_type: Literal["sqlite", "mysql", "postgresql"] = Field(default="sqlite", description="数据库类型") sqlite_path: str = Field(default="data/MaiBot.db", description="SQLite数据库文件路径") + + # MySQL 配置 mysql_host: str = Field(default="localhost", description="MySQL服务器地址") mysql_port: int = Field(default=3306, ge=1, le=65535, description="MySQL服务器端口") mysql_database: str = Field(default="maibot", description="MySQL数据库名") @@ -33,6 +35,22 @@ class DatabaseConfig(ValidatedConfigBase): mysql_ssl_key: str = Field(default="", description="SSL密钥路径") mysql_autocommit: bool = Field(default=True, description="自动提交事务") mysql_sql_mode: str = Field(default="TRADITIONAL", description="SQL模式") + + # PostgreSQL 配置 + postgresql_host: str = Field(default="localhost", description="PostgreSQL服务器地址") + postgresql_port: int = Field(default=5432, ge=1, le=65535, description="PostgreSQL服务器端口") + postgresql_database: str = Field(default="maibot", description="PostgreSQL数据库名") + postgresql_user: str = Field(default="postgres", description="PostgreSQL用户名") + postgresql_password: str = Field(default="", description="PostgreSQL密码") + postgresql_schema: str = Field(default="public", description="PostgreSQL模式名") + postgresql_ssl_mode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"] = Field( + 
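Note: dropping `id=None` before the bulk insert is what lets PostgreSQL fall back to its sequence; an explicit NULL primary key would otherwise be sent verbatim. In miniature (field names illustrative):

```python
op_data = {"id": None, "stream_id": "s1", "processed_plain_text": "hi"}
filtered = {k: v for k, v in op_data.items() if not (k == "id" and v is None)}
# -> {"stream_id": "s1", "processed_plain_text": "hi"}; the database
#    (SERIAL / AUTOINCREMENT) now assigns the primary key itself.
```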
default="prefer", description="PostgreSQL SSL模式" + ) + postgresql_ssl_ca: str = Field(default="", description="PostgreSQL SSL CA证书路径") + postgresql_ssl_cert: str = Field(default="", description="PostgreSQL SSL客户端证书路径") + postgresql_ssl_key: str = Field(default="", description="PostgreSQL SSL密钥路径") + + # 通用连接池配置 connection_pool_size: int = Field(default=10, ge=1, description="连接池大小") connection_timeout: int = Field(default=10, ge=1, description="连接超时时间") diff --git a/template/bot_config_template.toml b/template/bot_config_template.toml index 775ef6c36..d7b24eaef 100644 --- a/template/bot_config_template.toml +++ b/template/bot_config_template.toml @@ -1,5 +1,5 @@ [inner] -version = "7.8.3" +version = "7.9.0" #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- #如果你想要修改配置文件,请递增version的值 @@ -12,7 +12,7 @@ version = "7.8.3" #----以上是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读---- [database]# 数据库配置 -database_type = "sqlite" # 数据库类型,支持 "sqlite" 或 "mysql" +database_type = "sqlite" # 数据库类型,支持 "sqlite"、"mysql" 或 "postgresql" # SQLite 配置(当 database_type = "sqlite" 时使用) sqlite_path = "data/MaiBot.db" # SQLite数据库文件路径 @@ -36,8 +36,22 @@ mysql_ssl_key = "" # SSL客户端密钥路径 mysql_autocommit = true # 自动提交事务 mysql_sql_mode = "TRADITIONAL" # SQL模式 -# 连接池配置 -connection_pool_size = 10 # 连接池大小(仅MySQL有效) +# PostgreSQL 配置(当 database_type = "postgresql" 时使用) +postgresql_host = "localhost" # PostgreSQL服务器地址 +postgresql_port = 5432 # PostgreSQL服务器端口 +postgresql_database = "maibot" # PostgreSQL数据库名 +postgresql_user = "postgres" # PostgreSQL用户名 +postgresql_password = "" # PostgreSQL密码 +postgresql_schema = "public" # PostgreSQL模式名(schema) + +# PostgreSQL SSL 配置 +postgresql_ssl_mode = "prefer" # SSL模式: disable, allow, prefer, require, verify-ca, verify-full +postgresql_ssl_ca = "" # SSL CA证书路径 +postgresql_ssl_cert = "" # SSL客户端证书路径 +postgresql_ssl_key = "" # SSL客户端密钥路径 + +# 连接池配置(MySQL 和 PostgreSQL 有效) +connection_pool_size = 10 # 连接池大小 connection_timeout = 10 # 连接超时时间(秒) # 批量动作记录存储配置