feat: update bot configuration and add database migration script
- Bump the version in bot_config_template.toml to 7.9.0
- Extend the database configuration options to support PostgreSQL
- Introduce a new script for migrating data between SQLite, MySQL, and PostgreSQL
- Implement a dialect adapter that handles database-specific behavior and configuration
@@ -79,7 +79,9 @@ dependencies = [
     "rjieba>=0.1.13",
     "fastmcp>=2.13.0",
     "mofox-wire",
-    "jinja2>=3.1.0"
+    "jinja2>=3.1.0",
+    "psycopg2-binary",
+    "PyMySQL"
 ]

 [[tool.uv.index]]
@@ -1,6 +1,10 @@
 aiosqlite
 aiofiles
 aiomysql
+asyncpg
+psycopg[binary]
+psycopg2-binary
+PyMySQL
 APScheduler
 aiohttp
 aiohttp-cors
@@ -1,117 +0,0 @@
"""
Diagnostic script for checking the state of the expression database
"""
import asyncio
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from sqlalchemy import func, select

from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression


async def check_database():
    """Check the state of the expression database"""

    print("=" * 60)
    print("Expression database diagnostic report")
    print("=" * 60)

    async with get_db_session() as session:
        # 1. Count the total
        total_count = await session.execute(select(func.count()).select_from(Expression))
        total = total_count.scalar()
        print(f"\n📊 Total expressions: {total}")

        if total == 0:
            print("\n⚠️ The database is empty!")
            print("\nPossible causes:")
            print("1. No expression learning has taken place yet")
            print("2. Expression learning is disabled in the configuration")
            print("3. An error occurred during learning")
            print("\nSuggestions:")
            print("- Check the [expression] section in bot_config.toml")
            print("- Look for expression-learning errors in the logs")
            print("- Confirm the chat stream's learn_expression option is true")
            return

        # 2. Count by chat_id
        print("\n📝 Per chat stream:")
        chat_counts = await session.execute(
            select(Expression.chat_id, func.count())
            .group_by(Expression.chat_id)
        )
        for chat_id, count in chat_counts:
            print(f"  - {chat_id}: {count} expressions")

        # 3. Count by type
        print("\n📝 Per type:")
        type_counts = await session.execute(
            select(Expression.type, func.count())
            .group_by(Expression.type)
        )
        for expr_type, count in type_counts:
            print(f"  - {expr_type}: {count}")

        # 4. Check the situation and style fields for NULL values.
        # Use .is_(None): a plain `is None` comparison on a column is a Python
        # identity check and would never translate into a SQL IS NULL filter.
        print("\n🔍 Field integrity check:")
        null_situation = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.situation.is_(None))
        )
        null_style = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.style.is_(None))
        )

        null_sit_count = null_situation.scalar()
        null_sty_count = null_style.scalar()

        print(f"  - NULL situation: {null_sit_count}")
        print(f"  - NULL style: {null_sty_count}")

        if null_sit_count > 0 or null_sty_count > 0:
            print("  ⚠️ NULL values found! These cause matching to fail")

        # 5. Show some sample data
        print("\n📋 Sample data (first 10 rows):")
        samples = await session.execute(
            select(Expression)
            .limit(10)
        )

        for i, expr in enumerate(samples.scalars(), 1):
            print(f"\n  [{i}] Chat: {expr.chat_id}")
            print(f"      Type: {expr.type}")
            print(f"      Situation: {expr.situation}")
            print(f"      Style: {expr.style}")
            print(f"      Count: {expr.count}")

        # 6. Check the distinct values of the style field
        print("\n📋 Style samples (first 20):")
        unique_styles = await session.execute(
            select(Expression.style)
            .distinct()
            .limit(20)
        )

        styles = list(unique_styles.scalars())
        for style in styles:
            print(f"  - {style}")

        print(f"\n  ({len(styles)} distinct styles in total)")

    print("\n" + "=" * 60)
    print("Diagnosis complete")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(check_database())
@@ -1,65 +0,0 @@
"""
Inspect the characteristics of the style field contents in the database
"""
import asyncio
import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from sqlalchemy import select

from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression


async def analyze_style_fields():
    """Analyze the contents of the style field"""

    print("=" * 60)
    print("Style field content analysis")
    print("=" * 60)

    async with get_db_session() as session:
        # Fetch all expressions
        result = await session.execute(select(Expression).limit(30))
        expressions = result.scalars().all()

        print(f"\nInspecting {len(expressions)} records in total\n")

        # Group by type
        style_examples = [
            {
                "situation": expr.situation,
                "style": expr.style,
                "length": len(expr.style) if expr.style else 0
            }
            for expr in expressions if expr.type == "style"
        ]

        print("📋 Samples of type 'style' (first 15):")
        print("=" * 60)
        for i, ex in enumerate(style_examples[:15], 1):
            print(f"\n[{i}]")
            print(f"  Situation: {ex['situation']}")
            print(f"  Style: {ex['style']}")
            print(f"  Length: {ex['length']} characters")

            # Decide whether this is a concrete expression or a style description
            # (the keyword list matches Chinese style-description terms in the data)
            if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
                style_type = "✓ style description"
            elif ex["length"] <= 10:
                style_type = "? possibly a concrete expression (short)"
            else:
                style_type = "✗ concrete expression content"

            print(f"  Classification: {style_type}")

    print("\n" + "=" * 60)
    print("Analysis complete")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(analyze_style_fields())
@@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""Clean up core/models.py so that only the model definitions remain"""

import os

# File path
models_file = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "src",
    "common",
    "database",
    "core",
    "models.py"
)

print(f"Cleaning file: {models_file}")

# Read the file
with open(models_file, encoding="utf-8") as f:
    lines = f.readlines()

# Find where the last model class ends (the end of MonthlyPlan's __table_args__);
# we want to keep everything up to and including line 593
keep_lines = []
found_end = False

for i, line in enumerate(lines, 1):
    keep_lines.append(line)

    # Check whether we have reached the end of MonthlyPlan's __table_args__
    if i > 580 and line.strip() == ")":
        # Also check whether the preceding lines mention the Index
        if "idx_monthlyplan" in "".join(lines[max(0, i - 5):i]):
            print(f"Found the end of the model definitions at line {i}")
            found_end = True
            break

if not found_end:
    print("❌ End marker of the model definitions not found")
    exit(1)

# Write the file back
with open(models_file, "w", encoding="utf-8") as f:
    f.writelines(keep_lines)

print("✅ File cleanup complete")
print(f"Lines kept: {len(keep_lines)}")
print(f"Original lines: {len(lines)}")
print(f"Lines removed: {len(lines) - len(keep_lines)}")
@@ -1,59 +0,0 @@
"""
Debug fetching the MCP tool list

Directly tests whether the MCP client can retrieve tools
"""

import asyncio

from fastmcp.client import Client, StreamableHttpTransport


async def test_direct_connection():
    """Connect directly to the MCP server and fetch the tool list"""
    print("=" * 60)
    print("Direct MCP server connection test")
    print("=" * 60)

    url = "http://localhost:8000/mcp"
    print(f"\nConnecting to: {url}")

    try:
        # Create the transport layer
        transport = StreamableHttpTransport(url)
        print("✓ Transport created")

        # Create the client
        async with Client(transport) as client:
            print("✓ Client connected")

            # Fetch the tool list
            print("\nFetching the tool list...")
            tools_result = await client.list_tools()

            print(f"\nResult type: {type(tools_result)}")
            print(f"Result content: {tools_result}")

            # Check for a tools attribute
            if hasattr(tools_result, "tools"):
                tools = tools_result.tools
                print(f"\n✓ Found a tools attribute with {len(tools)} tools")

                for i, tool in enumerate(tools, 1):
                    print(f"\nTool {i}:")
                    print(f"  Name: {tool.name}")
                    print(f"  Description: {tool.description}")
                    if hasattr(tool, "inputSchema"):
                        print(f"  Parameter schema: {tool.inputSchema}")
            else:
                print("\n✗ The result has no tools attribute")
                print(f"Available attributes: {dir(tools_result)}")

    except Exception as e:
        print(f"\n✗ Connection failed: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(test_direct_connection())
@@ -1,88 +0,0 @@
"""
Diagnostic script for checking StyleLearner model state
"""
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.chat.express.style_learner import style_learner_manager
from src.common.logger import get_logger

logger = get_logger("debug_style_learner")


def check_style_learner_status(chat_id: str):
    """Check the StyleLearner state for the given chat_id"""

    print("=" * 60)
    print(f"StyleLearner diagnostics - Chat ID: {chat_id}")
    print("=" * 60)

    # Get the learner
    learner = style_learner_manager.get_learner(chat_id)

    # 1. Basic information
    print("\n📊 Basic information:")
    print(f"  Chat ID: {learner.chat_id}")
    print(f"  Number of styles: {len(learner.style_to_id)}")
    print(f"  Next ID: {learner.next_style_id}")
    print(f"  Max styles: {learner.max_styles}")

    # 2. Learning statistics
    print("\n📈 Learning statistics:")
    print(f"  Total samples: {learner.learning_stats['total_samples']}")
    print(f"  Last update: {learner.learning_stats.get('last_update', 'N/A')}")

    # 3. Style list (first 20)
    print("\n📋 Learned styles (first 20):")
    all_styles = learner.get_all_styles()
    if not all_styles:
        print("  ⚠️ No styles at all! The model has not been trained yet")
    else:
        for i, style in enumerate(all_styles[:20], 1):
            style_id = learner.style_to_id.get(style)
            situation = learner.id_to_situation.get(style_id, "N/A")
            print(f"  [{i}] {style}")
            print(f"      (ID: {style_id}, Situation: {situation})")

    # 4. Test prediction (Chinese test inputs kept as-is; the model is trained on Chinese data)
    print("\n🔮 Prediction test:")
    if not all_styles:
        print("  ⚠️ Cannot test, the model has no training data")
    else:
        test_situations = [
            "表示惊讶",
            "讨论游戏",
            "表达赞同"
        ]

        for test_sit in test_situations:
            print(f"\n  Test input: '{test_sit}'")
            best_style, scores = learner.predict_style(test_sit, top_k=3)

            if best_style:
                print(f"    ✓ Best match: {best_style}")
                print("    Top 3:")
                for style, score in list(scores.items())[:3]:
                    print(f"      - {style}: {score:.4f}")
            else:
                print("    ✗ Prediction failed")

    print("\n" + "=" * 60)
    print("Diagnosis complete")
    print("=" * 60)


if __name__ == "__main__":
    # chat_ids taken from the diagnostic report
    test_chat_ids = [
        "52fb94af9f500a01e023ea780e43606e",  # has 78 expressions
        "46c8714c8a9b7ee169941fe99fcde07d",  # has 22 expressions
    ]

    for chat_id in test_chat_ids:
        check_style_learner_status(chat_id)
        print("\n")
@@ -1,403 +0,0 @@
"""
Memory deduplication tool

Features:
1. Scan all memory edges marked as "similar"
2. Deduplicate similar memories (keep the more important one, delete the other)
3. Support a dry-run mode (preview without executing)
4. Produce a detailed deduplication report

Usage:
    # Preview mode (no actual deletion)
    python scripts/deduplicate_memories.py --dry-run

    # Run the deduplication
    python scripts/deduplicate_memories.py

    # Set the similarity threshold
    python scripts/deduplicate_memories.py --threshold 0.9

    # Set the data directory
    python scripts/deduplicate_memories.py --data-dir data/memory_graph
"""
import argparse
import asyncio
import sys
from datetime import datetime
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.common.logger import get_logger
from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager

logger = get_logger(__name__)


class MemoryDeduplicator:
    """Memory deduplicator"""

    def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
        self.data_dir = data_dir
        self.dry_run = dry_run
        self.threshold = threshold
        self.manager = None

        # Statistics
        self.stats = {
            "total_memories": 0,
            "similar_pairs": 0,
            "duplicates_found": 0,
            "duplicates_removed": 0,
            "errors": 0,
        }

    async def initialize(self):
        """Initialize the memory manager"""
        logger.info(f"Initializing memory manager (data_dir={self.data_dir})...")
        self.manager = await initialize_memory_manager(data_dir=self.data_dir)
        if not self.manager:
            raise RuntimeError("Memory manager initialization failed")

        self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
        logger.info(f"✅ Memory manager initialized, {self.stats['total_memories']} memories in total")

    async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
        """
        Find all similar memory pairs (via vector similarity)

        Returns:
            [(memory_id_1, memory_id_2, similarity), ...]
        """
        logger.info("Scanning for similar memory pairs...")
        similar_pairs = []
        seen_pairs = set()  # avoid duplicates

        # Fetch all memories
        all_memories = self.manager.graph_store.get_all_memories()
        total_memories = len(all_memories)

        logger.info(f"Computing similarities across {total_memories} memories...")

        # Compare memories pairwise
        for i, memory_i in enumerate(all_memories):
            # Yield control every 10 memories
            if i % 10 == 0:
                await asyncio.sleep(0)
                if i > 0:
                    logger.info(f"Progress: {i}/{total_memories} ({i * 100 // total_memories}%)")

            # Get memory i's vector (from a topic node)
            vector_i = None
            for node in memory_i.nodes:
                if node.embedding is not None:
                    vector_i = node.embedding
                    break

            if vector_i is None:
                continue

            # Compare against the remaining memories
            for j in range(i + 1, total_memories):
                memory_j = all_memories[j]

                # Get memory j's vector
                vector_j = None
                for node in memory_j.nodes:
                    if node.embedding is not None:
                        vector_j = node.embedding
                        break

                if vector_j is None:
                    continue

                # Compute cosine similarity
                similarity = self._cosine_similarity(vector_i, vector_j)

                # Keep only pairs above the threshold
                if similarity >= self.threshold:
                    pair_key = tuple(sorted([memory_i.id, memory_j.id]))
                    if pair_key not in seen_pairs:
                        seen_pairs.add(pair_key)
                        similar_pairs.append((memory_i.id, memory_j.id, similarity))

        self.stats["similar_pairs"] = len(similar_pairs)
        logger.info(f"Found {len(similar_pairs)} similar memory pairs (threshold >= {self.threshold})")

        return similar_pairs

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Compute cosine similarity"""
        try:
            vec1_norm = np.linalg.norm(vec1)
            vec2_norm = np.linalg.norm(vec2)

            if vec1_norm == 0 or vec2_norm == 0:
                return 0.0

            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
            return float(similarity)
        except Exception as e:
            logger.error(f"Cosine similarity computation failed: {e}")
            return 0.0

    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
        """
        Decide which memory to keep and which to delete

        Priority:
        1. Higher importance
        2. Higher activation
        3. Earlier creation time

        Returns:
            (keep_id, remove_id)
        """
        mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
        mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)

        if not mem1 or not mem2:
            logger.warning(f"Memory not found: {mem_id_1} or {mem_id_2}")
            return None, None

        # Compare importance
        if mem1.importance > mem2.importance:
            return mem_id_1, mem_id_2
        elif mem1.importance < mem2.importance:
            return mem_id_2, mem_id_1

        # Same importance, compare activation
        if mem1.activation > mem2.activation:
            return mem_id_1, mem_id_2
        elif mem1.activation < mem2.activation:
            return mem_id_2, mem_id_1

        # Same activation too, keep the earlier one
        if mem1.created_at < mem2.created_at:
            return mem_id_1, mem_id_2
        else:
            return mem_id_2, mem_id_1

    async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
        """
        Deduplicate one similar pair

        Returns:
            whether the deduplication succeeded
        """
        keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)

        if not keep_id or not remove_id:
            self.stats["errors"] += 1
            return False

        keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
        remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)

        logger.info("")
        logger.info(f"{'[preview]' if self.dry_run else '[execute]'} Deduplicating similar pair (similarity={similarity:.3f}):")
        logger.info(f"  Keep: {keep_id}")
        logger.info(f"    - Topic: {keep_mem.metadata.get('topic', 'N/A')}")
        logger.info(f"    - Importance: {keep_mem.importance:.2f}")
        logger.info(f"    - Activation: {keep_mem.activation:.2f}")
        logger.info(f"    - Created at: {keep_mem.created_at}")
        logger.info(f"  Remove: {remove_id}")
        logger.info(f"    - Topic: {remove_mem.metadata.get('topic', 'N/A')}")
        logger.info(f"    - Importance: {remove_mem.importance:.2f}")
        logger.info(f"    - Activation: {remove_mem.activation:.2f}")
        logger.info(f"    - Created at: {remove_mem.created_at}")

        if self.dry_run:
            logger.info("  [dry run] Skipping actual deletion")
            self.stats["duplicates_found"] += 1
            return True

        try:
            # Boost the kept memory's attributes
            keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
            keep_mem.activation = min(1.0, keep_mem.activation + 0.05)

            # Accumulate access counts
            if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                keep_mem.access_count += remove_mem.access_count

            # Delete the similar memory
            await self.manager.delete_memory(remove_id)

            self.stats["duplicates_removed"] += 1
            logger.info("  ✅ Deleted")

            # Yield control
            await asyncio.sleep(0)

            return True

        except Exception as e:
            logger.error(f"  ❌ Deletion failed: {e}")
            self.stats["errors"] += 1
            return False

    async def run(self):
        """Run the deduplication"""
        start_time = datetime.now()

        print("=" * 70)
        print("Memory deduplication tool")
        print("=" * 70)
        print(f"Data directory: {self.data_dir}")
        print(f"Similarity threshold: {self.threshold}")
        print(f"Mode: {'preview (no deletion)' if self.dry_run else 'execute (will delete)'}")
        print("=" * 70)
        print()

        # Initialize
        await self.initialize()

        # Find similar pairs
        similar_pairs = await self.find_similar_pairs()

        if not similar_pairs:
            logger.info("No similar memory pairs to deduplicate")
            print()
            print("=" * 70)
            print("No memories need deduplication")
            print("=" * 70)
            return

        # Deduplicate
        logger.info(f"Starting deduplication {'preview' if self.dry_run else 'run'}...")
        print()

        processed_pairs = set()  # avoid processing a pair twice

        for mem_id_1, mem_id_2, similarity in similar_pairs:
            # Skip pairs already handled (one memory may already be deleted)
            pair_key = tuple(sorted([mem_id_1, mem_id_2]))
            if pair_key in processed_pairs:
                continue

            # Check that both memories still exist
            if not self.manager.graph_store.get_memory_by_id(mem_id_1):
                logger.debug(f"Memory {mem_id_1} no longer exists, skipping")
                continue
            if not self.manager.graph_store.get_memory_by_id(mem_id_2):
                logger.debug(f"Memory {mem_id_2} no longer exists, skipping")
                continue

            # Deduplicate the pair
            success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)

            if success:
                processed_pairs.add(pair_key)

        # Save the data (unless this is a dry run)
        if not self.dry_run:
            logger.info("Saving data...")
            await self.manager.persistence.save_graph_store(self.manager.graph_store)
            logger.info("✅ Data saved")

        # Report
        elapsed = (datetime.now() - start_time).total_seconds()

        print()
        print("=" * 70)
        print("Deduplication report")
        print("=" * 70)
        print(f"Total memories: {self.stats['total_memories']}")
        print(f"Similar pairs: {self.stats['similar_pairs']}")
        print(f"Duplicates found: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
        print(f"{'Previewed' if self.dry_run else 'Deleted'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
        print(f"Errors: {self.stats['errors']}")
        print(f"Elapsed: {elapsed:.2f}s")

        if self.dry_run:
            print()
            print("⚠️ This was a preview run; no memories were deleted")
            print("💡 To actually delete, run: python scripts/deduplicate_memories.py")
        else:
            print()
            print("✅ Deduplication complete!")
            final_count = len(self.manager.graph_store.get_all_memories())
            print(f"📊 Final memory count: {final_count} (removed {self.stats['total_memories'] - final_count})")

        print("=" * 70)

    async def cleanup(self):
        """Release resources"""
        if self.manager:
            await shutdown_memory_manager()


async def main():
    """Entry point"""
    parser = argparse.ArgumentParser(
        description="Memory deduplication tool - one-shot deduplication of memories marked as similar",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Preview mode (recommended first)
  python scripts/deduplicate_memories.py --dry-run

  # Run the deduplication
  python scripts/deduplicate_memories.py

  # Set the similarity threshold (only handle pairs with similarity >= 0.9)
  python scripts/deduplicate_memories.py --threshold 0.9

  # Set the data directory
  python scripts/deduplicate_memories.py --data-dir data/memory_graph

  # Combined
  python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
"""
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="preview mode, do not delete memories (recommended first)"
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.85,
        help="similarity threshold; only pairs with similarity >= this value are handled (default: 0.85)"
    )

    parser.add_argument(
        "--data-dir",
        type=str,
        default="data/memory_graph",
        help="memory data directory (default: data/memory_graph)"
    )

    args = parser.parse_args()

    # Create the deduplicator
    deduplicator = MemoryDeduplicator(
        data_dir=args.data_dir,
        dry_run=args.dry_run,
        threshold=args.threshold
    )

    try:
        # Run the deduplication
        await deduplicator.run()
    except KeyboardInterrupt:
        print("\n\n⚠️ Interrupted by user")
    except Exception as e:
        logger.error(f"Run failed: {e}")
        print(f"\n❌ Run failed: {e}")
        return 1
    finally:
        # Release resources
        await deduplicator.cleanup()

    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
@@ -1,195 +0,0 @@
import os
import sys
import time

# Add project root to Python path (must run before the project import below)
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import ChatStreams, Expression


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying ChatStreams table directly"""
    try:
        # Query the ChatStreams table directly
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"unknown chat ({chat_id})"

        # If there is group information, show the group name
        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        # For private chats, show the user's nickname
        elif chat_stream.user_nickname:
            return f"private chat with {chat_stream.user_nickname} ({chat_id})"
        else:
            return f"unknown chat ({chat_id})"
    except Exception:
        return f"lookup failed ({chat_id})"


def calculate_time_distribution(expressions) -> dict[str, int]:
    """Calculate distribution of last active time in days"""
    now = time.time()
    distribution = {
        "0-1 days": 0,
        "1-3 days": 0,
        "3-7 days": 0,
        "7-14 days": 0,
        "14-30 days": 0,
        "30-60 days": 0,
        "60-90 days": 0,
        "90+ days": 0,
    }
    for expr in expressions:
        diff_days = (now - expr.last_active_time) / (24 * 3600)
        if diff_days < 1:
            distribution["0-1 days"] += 1
        elif diff_days < 3:
            distribution["1-3 days"] += 1
        elif diff_days < 7:
            distribution["3-7 days"] += 1
        elif diff_days < 14:
            distribution["7-14 days"] += 1
        elif diff_days < 30:
            distribution["14-30 days"] += 1
        elif diff_days < 60:
            distribution["30-60 days"] += 1
        elif diff_days < 90:
            distribution["60-90 days"] += 1
        else:
            distribution["90+ days"] += 1
    return distribution


def calculate_count_distribution(expressions) -> dict[str, int]:
    """Calculate distribution of count values"""
    distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
    for expr in expressions:
        cnt = expr.count
        if cnt < 1:
            distribution["0-1"] += 1
        elif cnt < 2:
            distribution["1-2"] += 1
        elif cnt < 3:
            distribution["2-3"] += 1
        elif cnt < 4:
            distribution["3-4"] += 1
        elif cnt < 5:
            distribution["4-5"] += 1
        elif cnt < 10:
            distribution["5-10"] += 1
        else:
            distribution["10+"] += 1
    return distribution


def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
    """Get top N most used expressions for a specific chat_id"""
    return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)


def show_overall_statistics(expressions, total: int) -> None:
    """Show overall statistics"""
    time_dist = calculate_time_distribution(expressions)
    count_dist = calculate_count_distribution(expressions)

    print("\n=== Overall statistics ===")
    print(f"Total expressions: {total}")

    print("\nLast active time distribution:")
    for period, count in time_dist.items():
        print(f"{period}: {count} ({count / total * 100:.2f}%)")

    print("\nCount distribution:")
    for range_, count in count_dist.items():
        print(f"{range_}: {count} ({count / total * 100:.2f}%)")


def show_chat_statistics(chat_id: str, chat_name: str) -> None:
    """Show statistics for a specific chat"""
    chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
    chat_total = len(chat_exprs)

    print(f"\n=== {chat_name} ===")
    print(f"Expressions: {chat_total}")

    if chat_total == 0:
        print("This chat has no expression data")
        return

    # Time distribution for this chat
    time_dist = calculate_time_distribution(chat_exprs)
    print("\nLast active time distribution:")
    for period, count in time_dist.items():
        if count > 0:
            print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")

    # Count distribution for this chat
    count_dist = calculate_count_distribution(chat_exprs)
    print("\nCount distribution:")
    for range_, count in count_dist.items():
        if count > 0:
            print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")

    # Top expressions
    print("\nTop 10 most used expressions:")
    top_exprs = get_top_expressions_by_chat(chat_id, 10)
    for i, expr in enumerate(top_exprs, 1):
        print(f"{i}. [{expr.type}] Count: {expr.count}")
        print(f"   Situation: {expr.situation}")
        print(f"   Style: {expr.style}")
        print()


def interactive_menu() -> None:
    """Interactive menu for expression statistics"""
    # Get all expressions
    expressions = list(Expression.select())
    if not expressions:
        print("No expressions found in the database")
        return

    total = len(expressions)

    # Get unique chat_ids and their names
    chat_ids = list({expr.chat_id for expr in expressions})
    chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
    chat_info.sort(key=lambda x: x[1])  # Sort by chat name

    while True:
        print("\n" + "=" * 50)
        print("Expression statistics")
        print("=" * 50)
        print("0. Show overall statistics")

        for i, (chat_id, chat_name) in enumerate(chat_info, 1):
            chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
            print(f"{i}. {chat_name} ({chat_count} expressions)")

        print("q. Quit")

        choice = input("\nSelect a report (enter a number): ").strip()

        if choice.lower() == "q":
            print("Bye!")
            break

        try:
            choice_num = int(choice)
            if choice_num == 0:
                show_overall_statistics(expressions, total)
            elif 1 <= choice_num <= len(chat_info):
                chat_id, chat_name = chat_info[choice_num - 1]
                show_chat_statistics(chat_id, chat_name)
            else:
                print("Invalid choice, please try again")
        except ValueError:
            print("Please enter a valid number")

        input("\nPress Enter to continue...")


if __name__ == "__main__":
    interactive_menu()
@@ -1,66 +0,0 @@
#!/usr/bin/env python3
"""Extract the model definitions from models.py"""

import re

# Read the original file
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
    content = f.read()

# Find where the get_string_field helper starts and ends
# (the Chinese marker comment is a literal string in the source file being searched)
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end]

# Find where the first class definition starts
first_class_pos = content.find("class ChatStreams(Base):")

# Find all class definitions up to the first non-class def.
# Simple strategy: match every class that inherits from Base.
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))

if matches:
    # Use the end position of the last match
    models_content = content[first_class_pos:first_class_pos + matches[-1].end()]
else:
    # Fallback: from the first class to 85% of the file
    models_end = int(len(content) * 0.85)
    models_content = content[first_class_pos:models_end]

# Build the new file content
header = '''"""SQLAlchemy database model definitions

This file contains only the pure model definitions, written in the
SQLAlchemy 2.0 Mapped annotation style. Engine and session management
have moved to core/engine.py and core/session.py.

All models use the same annotation style:
    field_name: Mapped[PyType] = mapped_column(Type, ...)

so that IDE/Pylance can infer instance attribute types correctly.
"""

import datetime
import time

from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column

# Create the declarative base
Base = declarative_base()


'''

new_content = header + get_string_field + "\n\n" + models_content

# Write the new file
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
    f.write(new_content)

print("✅ Models file rewritten successfully")
print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f"Number of model classes: {model_count}")
@@ -1,267 +0,0 @@
"""
Generate embedding vectors for existing nodes

Batch-generates and indexes embeddings for graph-store nodes that lack them

Use cases:
1. Historical memory nodes have no embeddings
2. The embedding generator was previously unconfigured and embeddings must be backfilled
3. The vector index is corrupted and needs rebuilding

Usage:
    python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]

Arguments:
    --node-types: node types to embed, defaults to TOPIC,OBJECT
    --batch-size: batch size, defaults to 50
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the path
sys.path.insert(0, str(Path(__file__).parent.parent))


async def generate_missing_embeddings(
    target_node_types: list[str] | None = None,
    batch_size: int = 50,
):
    """
    Generate embeddings for nodes that are missing them

    Args:
        target_node_types: node types to process (e.g. ["主题", "客体"])
        batch_size: batch size
    """
    from src.common.logger import get_logger
    from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
    from src.memory_graph.models import NodeType

    logger = get_logger("generate_missing_embeddings")

    if target_node_types is None:
        target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]

    print(f"\n{'=' * 80}")
    print("🔧 Generating node embeddings")
    print(f"{'=' * 80}\n")
    print(f"Target node types: {', '.join(target_node_types)}")
    print(f"Batch size: {batch_size}\n")

    # 1. Initialize the memory manager
    print("🔧 Initializing memory manager...")
    await initialize_memory_manager()
    manager = get_memory_manager()

    if manager is None:
        print("❌ Memory manager initialization failed")
        return

    print("✅ Memory manager initialized\n")

    # 2. Collect the IDs of already-indexed nodes
    print("🔍 Checking the existing vector index...")
    existing_node_ids = set()
    try:
        vector_count = manager.vector_store.collection.count()
        if vector_count > 0:
            # Fetch all indexed IDs in batches
            batch_size_check = 1000
            for offset in range(0, vector_count, batch_size_check):
                limit = min(batch_size_check, vector_count - offset)
                result = manager.vector_store.collection.get(
                    limit=limit,
                    offset=offset,
                )
                if result and "ids" in result:
                    existing_node_ids.update(result["ids"])

        print(f"✅ Found {len(existing_node_ids)} indexed nodes\n")
    except Exception as e:
        logger.warning(f"Failed to fetch indexed node IDs: {e}")
        print("⚠️ Could not fetch indexed nodes; will try to skip duplicates\n")

    # 3. Collect the nodes that need embeddings
    print("🔍 Scanning for nodes that need embeddings...")
    all_memories = manager.graph_store.get_all_memories()

    nodes_to_process = []
    total_target_nodes = 0
    type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}

    for memory in all_memories:
        for node in memory.nodes:
            if node.node_type.value in target_node_types:
                total_target_nodes += 1
                type_stats[node.node_type.value]["total"] += 1

                # Skip nodes that are already in the vector index
                if node.id in existing_node_ids:
                    type_stats[node.node_type.value]["already_indexed"] += 1
                    continue

                if not node.has_embedding():
                    nodes_to_process.append({
                        "node": node,
                        "memory_id": memory.id,
                    })
                    type_stats[node.node_type.value]["need_emb"] += 1

    print("\n📊 Scan results:")
    for node_type in target_node_types:
        stats = type_stats[node_type]
        already_ok = stats["already_indexed"]
        coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
        print(f"  - {node_type}: {stats['total']} nodes, {stats['need_emb']} missing embeddings, "
              f"{already_ok} indexed (coverage: {coverage:.1f}%)")

    print(f"\n  Total: {total_target_nodes} target nodes, {len(nodes_to_process)} need embeddings\n")

    if len(nodes_to_process) == 0:
        print("✅ All nodes already have embeddings; nothing to do")
        return

    # 4. Generate embeddings in batches
    print("🚀 Generating embeddings...\n")

    total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
    success_count = 0
    failed_count = 0
    indexed_count = 0

    for i in range(0, len(nodes_to_process), batch_size):
        batch = nodes_to_process[i : i + batch_size]
        batch_num = i // batch_size + 1

        print(f"📦 Batch {batch_num}/{total_batches} ({len(batch)} nodes)...")

        try:
            # Extract the text content
            texts = [item["node"].content for item in batch]

            # Generate the embeddings in one batch
            embeddings = await manager.embedding_generator.generate_batch(texts)

            # Attach the embeddings and collect the nodes to index
            batch_nodes_for_index = []

            for item, embedding in zip(batch, embeddings):
                node = item["node"]

                if embedding is not None:
                    # Set the embedding vector
                    node.embedding = embedding
                    batch_nodes_for_index.append(node)
                    success_count += 1
                else:
                    failed_count += 1
                    logger.warning(f"  ⚠️ Node {node.id[:8]}... '{node.content[:30]}' embedding generation failed")

            # Batch-index into the vector database
            if batch_nodes_for_index:
                try:
                    await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
                    indexed_count += len(batch_nodes_for_index)
                    print(f"  ✅ OK: {len(batch_nodes_for_index)}/{len(batch)} nodes generated and indexed")
                except Exception as e:
                    # If the batch fails, fall back to adding one by one (skipping duplicates)
                    logger.warning(f"  Batch indexing failed, adding individually: {e}")
                    individual_success = 0
                    for node in batch_nodes_for_index:
                        try:
                            await manager.vector_store.add_node(node)
                            individual_success += 1
                            indexed_count += 1
                        except Exception as e2:
                            if "Expected IDs to be unique" in str(e2):
                                logger.debug(f"  Skipping existing node: {node.id}")
                            else:
                                logger.error(f"  Node {node.id} indexing failed: {e2}")
                    print(f"  ⚠️ Individual indexing: {individual_success}/{len(batch_nodes_for_index)} succeeded")

        except Exception as e:
            failed_count += len(batch)
            logger.error(f"Batch {batch_num} failed: {e}")
            print(f"  ❌ Batch failed: {e}")

        # Show progress
        total_processed = min(i + batch_size, len(nodes_to_process))
        progress = total_processed / len(nodes_to_process) * 100
        print(f"  📊 Overall progress: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")

    # 5. Save the graph data (persists the nodes' embedding fields)
    print("💾 Saving graph data...")
    try:
        await manager.persistence.save_graph_store(manager.graph_store)
        print("✅ Graph data saved\n")
    except Exception as e:
        logger.error(f"Saving graph data failed: {e}")
        print(f"❌ Save failed: {e}\n")

    # 6. Verify the result
    print("🔍 Verifying the vector index...")
    final_vector_count = manager.vector_store.collection.count()
    stats = manager.graph_store.get_statistics()
    total_nodes = stats["total_nodes"]

    print(f"\n{'=' * 80}")
    print("📊 Generation complete")
    print(f"{'=' * 80}")
    print(f"Nodes processed: {len(nodes_to_process)}")
    print(f"Generated: {success_count}")
    print(f"Failed: {failed_count}")
    print(f"Indexed: {indexed_count}")
    print(f"Nodes in vector index: {final_vector_count}")
    print(f"Nodes in graph store: {total_nodes}")
    print(f"Index coverage: {final_vector_count / total_nodes * 100:.1f}%\n")

    # 7. Test search (queries are names from the Chinese chat data, kept as-is)
    print("🧪 Testing search...")
    test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]

    for query in test_queries:
        results = await manager.search_memories(query=query, top_k=3)
        if results:
            print(f"\n✅ Query '{query}' found {len(results)} memories:")
            for i, memory in enumerate(results[:2], 1):
                subject_node = memory.get_subject_node()
                # Find the topic node (scan all nodes for the TOPIC type)
                from src.memory_graph.models import NodeType
                topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
                subject = subject_node.content if subject_node else "?"
                topic = topic_nodes[0].content if topic_nodes else "?"
                print(f"  {i}. {subject} - {topic} (importance: {memory.importance:.2f})")
        else:
            print(f"\n⚠️ Query '{query}' returned 0 results")


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="Generate node embeddings")
    parser.add_argument(
        "--node-types",
        type=str,
        default="主题,客体",
        help="node types to embed, comma-separated (default: 主题,客体)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="batch size (default: 50)",
    )

    args = parser.parse_args()

    target_types = [t.strip() for t in args.node_types.split(",")]
    await generate_missing_embeddings(
        target_node_types=target_types,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    asyncio.run(main())
scripts/migrate_database.py (new file, 1051 lines)
File diff suppressed because it is too large
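The suppressed scripts/migrate_database.py is the migration script named in the commit message, which also describes a dialect adapter for database-specific behavior. As a rough illustration of that idea — a minimal sketch only, not the contents of the suppressed file; the class, its methods, and the URL templates are illustrative assumptions — such an adapter could map a dialect name to a SQLAlchemy connection URL and engine options:

# Hypothetical sketch of a dialect adapter; not taken from migrate_database.py.
from sqlalchemy import create_engine
from sqlalchemy.engine import Engine


class DialectAdapter:
    """Hide database-specific connection details behind one interface."""

    URL_TEMPLATES = {
        "sqlite": "sqlite:///{database}",
        "mysql": "mysql+pymysql://{user}:{password}@{host}:{port}/{database}",
        "postgresql": "postgresql+psycopg2://{user}:{password}@{host}:{port}/{database}",
    }

    def __init__(self, dialect: str, **params):
        if dialect not in self.URL_TEMPLATES:
            raise ValueError(f"unsupported dialect: {dialect}")
        self.dialect = dialect
        self.url = self.URL_TEMPLATES[dialect].format(**params)

    def create_engine(self) -> Engine:
        # SQLite is a local file and needs no pool tuning; client/server
        # databases benefit from pre-ping to drop stale pooled connections.
        if self.dialect == "sqlite":
            return create_engine(self.url)
        return create_engine(self.url, pool_pre_ping=True)

The psycopg2-binary and PyMySQL dependencies added in pyproject.toml above match the drivers such URL templates would require.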
@@ -1,140 +0,0 @@
#!/usr/bin/env python
"""
Rebuild the JSON metadata index from existing ChromaDB data
"""

import asyncio
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.chat.memory_system.memory_system import MemorySystem
from src.common.logger import get_logger

logger = get_logger(__name__)


async def rebuild_metadata_index():
    """Rebuild the metadata index from ChromaDB"""
    print("=" * 80)
    print("Rebuilding the JSON metadata index")
    print("=" * 80)

    # Initialize the memory system
    print("\n🔧 Initializing the memory system...")
    ms = MemorySystem()
    await ms.initialize()
    print("✅ Memory system initialized")

    if not hasattr(ms.unified_storage, "metadata_index"):
        print("❌ Metadata index manager is not initialized")
        return

    # Fetch all memories
    print("\n📥 Fetching all memories from ChromaDB...")
    from src.common.vector_db import vector_db_service

    try:
        # Fetch every memory ID in the collection
        collection_name = ms.unified_storage.config.memory_collection
        result = vector_db_service.get(
            collection_name=collection_name, include=["documents", "metadatas", "embeddings"]
        )

        if not result or not result.get("ids"):
            print("❌ No memory data found in ChromaDB")
            return

        ids = result["ids"]
        metadatas = result.get("metadatas", [])

        print(f"✅ Found {len(ids)} memories")

        # Rebuild the metadata index
        print("\n🔨 Rebuilding the metadata index...")
        entries = []
        success_count = 0

        for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1):
            try:
                # Rebuild an index entry from the ChromaDB metadata
                import orjson

                entry = MemoryMetadataIndexEntry(
                    memory_id=memory_id,
                    user_id=metadata.get("user_id", "unknown"),
                    memory_type=metadata.get("memory_type", "general"),
                    subjects=orjson.loads(metadata.get("subjects", "[]")),
                    objects=[metadata.get("object")] if metadata.get("object") else [],
                    keywords=orjson.loads(metadata.get("keywords", "[]")),
                    tags=orjson.loads(metadata.get("tags", "[]")),
                    importance=2,  # default NORMAL
                    confidence=2,  # default MEDIUM
                    created_at=metadata.get("created_at", 0.0),
                    access_count=metadata.get("access_count", 0),
                    chat_id=metadata.get("chat_id"),
                    content_preview=None,
                )

                # Try to parse the importance and confidence enum names
                if "importance" in metadata:
                    imp_str = metadata["importance"]
                    if imp_str == "LOW":
                        entry.importance = 1
                    elif imp_str == "NORMAL":
                        entry.importance = 2
                    elif imp_str == "HIGH":
                        entry.importance = 3
                    elif imp_str == "CRITICAL":
                        entry.importance = 4

                if "confidence" in metadata:
                    conf_str = metadata["confidence"]
                    if conf_str == "LOW":
                        entry.confidence = 1
                    elif conf_str == "MEDIUM":
                        entry.confidence = 2
                    elif conf_str == "HIGH":
                        entry.confidence = 3
                    elif conf_str == "VERIFIED":
                        entry.confidence = 4

                entries.append(entry)
                success_count += 1

                if i % 100 == 0:
                    print(f"  Progress: {i}/{len(ids)} ({success_count} ok)")

            except Exception as e:
                logger.warning(f"Failed to process memory {memory_id}: {e}")
                continue

        print(f"\n✅ Parsed {success_count}/{len(ids)} memory metadata entries")

        # Batch-update the index
        print("\n💾 Saving the metadata index...")
        ms.unified_storage.metadata_index.batch_add_or_update(entries)
        ms.unified_storage.metadata_index.save()

        # Show statistics
        stats = ms.unified_storage.metadata_index.get_stats()
        print("\n📊 Index statistics after the rebuild:")
        print(f"  - Total memories: {stats['total_memories']}")
        print(f"  - Subjects: {stats['subjects_count']}")
        print(f"  - Keywords: {stats['keywords_count']}")
        print(f"  - Tags: {stats['tags_count']}")
        print("  - Type distribution:")
        for mtype, count in stats["types"].items():
            print(f"    - {mtype}: {count}")

        print("\n✅ Metadata index rebuild complete!")

    except Exception as e:
        logger.error(f"Index rebuild failed: {e}")
        print(f"❌ Index rebuild failed: {e}")


if __name__ == "__main__":
    asyncio.run(rebuild_metadata_index())
@@ -1,190 +0,0 @@
"""
MCP integration test script

Tests MCP client connection, tool-list retrieval, and tool invocation
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.common.logger import get_logger
from src.plugin_system.core.component_registry import ComponentRegistry
from src.plugin_system.core.mcp_client_manager import MCPClientManager

logger = get_logger("test_mcp_integration")


async def test_mcp_client_manager():
    """Test basic MCPClientManager functionality"""
    print("\n" + "=" * 60)
    print("Test 1: MCPClientManager connection and tool list")
    print("=" * 60)

    try:
        # Initialize the MCP client manager
        manager = MCPClientManager()
        await manager.initialize()

        print("\n✓ MCP client manager initialized")
        print(f"Connected servers: {len(manager.clients)}")

        # Fetch all tools
        tools = await manager.get_all_tools()
        print(f"\nGot {len(tools)} MCP tools:")

        for tool in tools:
            print(f"\n  Tool: {tool}")
            # Note: tool is the tool name as a string here;
            # tool details have to be fetched elsewhere

        return manager, tools

    except Exception as e:
        print(f"\n✗ Test failed: {e}")
        logger.exception("MCPClientManager test failed")
        return None, []


async def test_tool_call(manager: MCPClientManager, tools):
    """Test tool invocation"""
    print("\n" + "=" * 60)
    print("Test 2: MCP tool invocation")
    print("=" * 60)

    if not tools:
        print("\n⚠ No tools available to test")
        return

    try:
        # The tool-list retrieval already ran as part of test 1
        print("\n✓ Tool list retrieved")
        print(f"Available tools: {len(tools)}")

    except Exception as e:
        print(f"\n✗ Tool invocation test failed: {e}")
        logger.exception("Tool invocation test failed")


async def test_component_registry_integration():
    """Test the ComponentRegistry integration"""
    print("\n" + "=" * 60)
    print("Test 3: ComponentRegistry MCP tool integration")
    print("=" * 60)

    try:
        registry = ComponentRegistry()

        # Load the MCP tools
        await registry.load_mcp_tools()

        # Fetch the MCP tools
        mcp_tools = registry.get_mcp_tools()
        print(f"\n✓ ComponentRegistry loaded {len(mcp_tools)} MCP tools")

        for tool in mcp_tools:
            print(f"\n  Tool: {tool.name}")
            print(f"  Description: {tool.description}")
            print(f"  Parameter count: {len(tool.parameters)}")

            # Exercise the is_mcp_tool method
            is_mcp = registry.is_mcp_tool(tool.name)
            print(f"  is_mcp_tool check: {'✓' if is_mcp else '✗'}")

        return mcp_tools

    except Exception as e:
        print(f"\n✗ ComponentRegistry integration test failed: {e}")
        logger.exception("ComponentRegistry integration test failed")
        return []


async def test_tool_execution(mcp_tools):
    """Test executing a tool through the adapter"""
    print("\n" + "=" * 60)
    print("Test 4: MCPToolAdapter tool execution")
    print("=" * 60)

    if not mcp_tools:
        print("\n⚠ No MCP tools available to test")
        return

    try:
        # Pick the first tool
        test_tool = mcp_tools[0]
        print(f"\nTesting tool: {test_tool.name}")

        # Build test arguments
        test_args = {}
        for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters:
            if is_required:
                # Supply a default value per type
                from src.llm_models.payload_content.tool_option import ToolParamType

                if param_type == ToolParamType.STRING:
                    test_args[param_name] = "test_value"
                elif param_type == ToolParamType.INTEGER:
                    test_args[param_name] = 1
                elif param_type == ToolParamType.FLOAT:
                    test_args[param_name] = 1.0
                elif param_type == ToolParamType.BOOLEAN:
                    test_args[param_name] = True

        print(f"Test arguments: {test_args}")

        # Execute the tool
        result = await test_tool.execute(test_args)

        if result:
            print("\n✓ Tool executed")
            print(f"Result type: {result.get('type')}")
            print(f"Result content: {result.get('content', '')[:200]}...")  # first 200 chars only
        else:
            print("\n✗ Tool execution failed, returned None")

    except Exception as e:
        print(f"\n✗ Tool execution test failed: {e}")
        logger.exception("Tool execution test failed")


async def main():
    """Main test flow"""
    print("\n" + "=" * 60)
    print("MCP integration tests")
    print("=" * 60)

    try:
        # Test 1: basic MCPClientManager functionality
        manager, tools = await test_mcp_client_manager()

        if manager:
            # Test 2: tool invocation
            await test_tool_call(manager, tools)

            # Test 3: ComponentRegistry integration
            mcp_tools = await test_component_registry_integration()

            # Test 4: tool execution
            await test_tool_execution(mcp_tools)

            # Close the connection
            await manager.close()
            print("\n✓ MCP client connection closed")

        print("\n" + "=" * 60)
        print("Tests complete")
        print("=" * 60 + "\n")

    except KeyboardInterrupt:
        print("\n\nTests interrupted by user")
    except Exception as e:
        print(f"\nError during tests: {e}")
        logger.exception("Tests failed")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,292 +0,0 @@
"""
Three-tier memory system test script
Verifies that each component of the system works
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))


async def test_perceptual_memory():
    """Test the perceptual memory tier"""
    print("\n" + "=" * 60)
    print("Test 1: perceptual memory tier")
    print("=" * 60)

    from src.memory_graph.three_tier.perceptual_manager import get_perceptual_manager

    manager = get_perceptual_manager()
    await manager.initialize()

    # Add test messages (Chinese chat content kept as-is; it is the test data)
    test_messages = [
        ("user1", "今天天气真好", 1700000000.0),
        ("user2", "是啊,适合出去玩", 1700000001.0),
        ("user1", "我们去公园吧", 1700000002.0),
        ("user2", "好主意!", 1700000003.0),
        ("user1", "带上野餐垫", 1700000004.0),
    ]

    for sender, content, timestamp in test_messages:
        message = {
            "message_id": f"msg_{timestamp}",
            "sender": sender,
            "content": content,
            "timestamp": timestamp,
            "platform": "test",
            "stream_id": "test_stream",
        }
        await manager.add_message(message)

    print(f"✅ Added {len(test_messages)} messages")

    # Test top-k recall
    results = await manager.recall_blocks("公园野餐", top_k=2)
    print(f"✅ Top-k recall returned {len(results)} blocks")

    if results:
        print(f"  The first block contains {len(results[0].messages)} messages")

    # Get statistics
    stats = manager.get_statistics()  # not an async method
    print(f"✅ Statistics: {stats}")

    return True


async def test_short_term_memory():
    """Test the short-term memory tier"""
    print("\n" + "=" * 60)
    print("Test 2: short-term memory tier")
    print("=" * 60)

    from src.memory_graph.three_tier.models import MemoryBlock
    from src.memory_graph.three_tier.short_term_manager import get_short_term_manager

    manager = get_short_term_manager()
    await manager.initialize()

    # Create a test block
    test_block = MemoryBlock(
        id="test_block_1",
        messages=[
            {
                "message_id": "msg1",
                "sender": "user1",
                "content": "我明天要参加一个重要的面试",
                "timestamp": 1700000000.0,
                "platform": "test",
            }
        ],
        combined_text="我明天要参加一个重要的面试",
        recall_count=3,
    )

    # Convert the perceptual block into short-term memory
    try:
        await manager.add_from_block(test_block)
        print("✅ Converted the perceptual block into short-term memory")
    except Exception as e:
        print(f"⚠️ Conversion failed (may require an LLM): {e}")
        return False

    # Test search
    results = await manager.search_memories("面试", top_k=3)
    print(f"✅ Search returned {len(results)} memories")

    # Get statistics
    stats = manager.get_statistics()
    print(f"✅ Statistics: {stats}")

    return True


async def test_long_term_memory():
    """Test the long-term memory tier"""
    print("\n" + "=" * 60)
    print("Test 3: long-term memory tier")
    print("=" * 60)

    from src.memory_graph.three_tier.long_term_manager import get_long_term_manager

    manager = get_long_term_manager()
    await manager.initialize()

    print("✅ Long-term memory manager initialized")
    print("  (requires the existing memory-graph system)")

    # Get statistics
    stats = manager.get_statistics()
    print(f"✅ Statistics: {stats}")

    return True


async def test_unified_manager():
    """Test the unified manager"""
    print("\n" + "=" * 60)
    print("Test 4: unified manager")
    print("=" * 60)

    from src.memory_graph.three_tier.unified_manager import UnifiedMemoryManager

    manager = UnifiedMemoryManager()
    await manager.initialize()

    # Add a test message
    message = {
        "message_id": "unified_test_1",
        "sender": "user1",
        "content": "这是一条测试消息",
        "timestamp": 1700000000.0,
        "platform": "test",
        "stream_id": "test_stream",
    }
    await manager.add_message(message)

    print("✅ Added a message through the unified interface")

    # Test search
    results = await manager.search_memories("测试")
    print("✅ Unified search results:")
    print(f"  Perceptual blocks: {len(results.get('perceptual_blocks', []))}")
    print(f"  Short-term memories: {len(results.get('short_term_memories', []))}")
    print(f"  Long-term memories: {len(results.get('long_term_memories', []))}")

    # Get statistics
    stats = manager.get_statistics()  # not an async method
    print("✅ Combined statistics:")
    print(f"  Perceptual tier: {stats.get('perceptual', {})}")
    print(f"  Short-term tier: {stats.get('short_term', {})}")
    print(f"  Long-term tier: {stats.get('long_term', {})}")

    return True


async def test_configuration():
    """Test configuration loading"""
    print("\n" + "=" * 60)
    print("Test 5: configuration system")
    print("=" * 60)

    from src.config.config import global_config

    if not hasattr(global_config, "three_tier_memory"):
        print("❌ No three_tier_memory field on the config class")
        return False

    config = global_config.three_tier_memory

    if config is None:
        print("⚠️ Three-tier memory config is None (perhaps not set in bot_config.toml)")
        print("  Please add a [three_tier_memory] section to bot_config.toml")
        return False

    print("✅ Configuration loaded")
    print(f"  Enabled: {config.enable}")
    print(f"  Data directory: {config.data_dir}")
    print(f"  Perceptual tier max blocks: {config.perceptual_max_blocks}")
    print(f"  Short-term tier max memories: {config.short_term_max_memories}")
    print(f"  Activation threshold: {config.activation_threshold}")

    return True


async def test_integration():
    """Test system integration"""
    print("\n" + "=" * 60)
    print("Test 6: system integration")
    print("=" * 60)

    # The configuration must be enabled first
    from src.config.config import global_config

    if not global_config.three_tier_memory or not global_config.three_tier_memory.enable:
        print("⚠️ Configuration disabled, skipping the integration test")
        return False

    # Test the singleton pattern
    from src.memory_graph.three_tier.manager_singleton import (
        get_unified_memory_manager,
        initialize_unified_memory_manager,
    )

    # Initialize
    await initialize_unified_memory_manager()
    manager = get_unified_memory_manager()

    if manager is None:
        print("❌ Unified manager initialization failed")
        return False

    print("✅ Singleton works")

    # Fetch again
    manager2 = get_unified_memory_manager()
    if manager is not manager2:
        print("❌ Singleton broken (returned a different instance)")
        return False

    print("✅ Singleton consistency verified")

    return True


async def run_all_tests():
    """Run all tests"""
    print("\n" + "🔬" * 30)
    print("Three-tier memory system integration tests")
    print("🔬" * 30)

    tests = [
        ("configuration system", test_configuration),
        ("perceptual memory tier", test_perceptual_memory),
        ("short-term memory tier", test_short_term_memory),
        ("long-term memory tier", test_long_term_memory),
        ("unified manager", test_unified_manager),
        ("system integration", test_integration),
    ]

    results = []

    for name, test_func in tests:
        try:
            result = await test_func()
            results.append((name, result))
        except Exception as e:
            print(f"\n❌ Test {name} failed: {e}")
            import traceback

            traceback.print_exc()
            results.append((name, False))

    # Print the summary
    print("\n" + "=" * 60)
    print("Test summary")
    print("=" * 60)

    passed = sum(1 for _, result in results if result)
    total = len(results)

    for name, result in results:
        status = "✅ passed" if result else "❌ failed"
        print(f"{status} - {name}")

    print(f"\nTotal: {passed}/{total} tests passed")

    if passed == total:
        print("\n🎉 All tests passed! The three-tier memory system works.")
    else:
        print("\n⚠️ Some tests failed; see the details above.")

    return passed == total


if __name__ == "__main__":
    success = asyncio.run(run_all_tests())
    sys.exit(0 if success else 1)
@@ -1,185 +0,0 @@
"""批量更新数据库导入语句的脚本

将旧的数据库导入路径更新为新的重构后的路径:
- sqlalchemy_models -> core, core.models
- sqlalchemy_database_api -> compatibility
- database.database -> core
"""

import re
from pathlib import Path

# 定义导入映射规则
IMPORT_MAPPINGS = {
    # 模型导入
    r"from src\.common\.database\.sqlalchemy_models import (.+)":
        r"from src.common.database.core.models import \1",

    # API导入 - 需要特殊处理
    r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
        r"from src.common.database.compatibility import \1",

    # get_db_session 从 sqlalchemy_database_api 导入
    r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
        r"from src.common.database.core import get_db_session",

    # get_db_session 从 sqlalchemy_models 导入
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
        if "get_db_session" in m.group(0) else m.group(0),

    # get_engine 导入
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",

    # Base 导入
    r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
        lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",

    # initialize_database 导入
    r"from src\.common\.database\.sqlalchemy_models import initialize_database":
        r"from src.common.database.core import check_and_migrate_database as initialize_database",

    # database.py 导入
    r"from src\.common\.database\.database import stop_database":
        r"from src.common.database.core import close_engine as stop_database",

    r"from src\.common\.database\.database import initialize_sql_database":
        r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
}

# 需要排除的文件
EXCLUDE_PATTERNS = [
    "**/database_refactoring_plan.md",  # 文档文件
    "**/old/**",  # 旧文件目录
    "**/sqlalchemy_*.py",  # 旧的数据库文件本身
    "**/database.py",  # 旧的database文件
    "**/db_*.py",  # 旧的db文件
]


def should_exclude(file_path: Path) -> bool:
    """检查文件是否应该被排除"""
    for pattern in EXCLUDE_PATTERNS:
        if file_path.match(pattern):
            return True
    return False


def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
    """更新单个文件中的导入语句

    Args:
        file_path: 文件路径
        dry_run: 是否只是预览而不实际修改

    Returns:
        (修改次数, 修改详情列表)
    """
    try:
        content = file_path.read_text(encoding="utf-8")
        original_content = content
        changes = []

        # 应用每个映射规则
        for pattern, replacement in IMPORT_MAPPINGS.items():
            matches = list(re.finditer(pattern, content))
            for match in matches:
                old_line = match.group(0)

                # 处理函数类型的替换
                if callable(replacement):
                    new_line_result = replacement(match)
                    new_line = new_line_result if isinstance(new_line_result, str) else old_line
                else:
                    new_line = re.sub(pattern, replacement, old_line)

                if old_line != new_line and isinstance(new_line, str):
                    content = content.replace(old_line, new_line, 1)
                    changes.append(f"  - {old_line}")
                    changes.append(f"  + {new_line}")

        # 如果有修改且不是dry_run,写回文件
        if content != original_content:
            if not dry_run:
                file_path.write_text(content, encoding="utf-8")
            return len(changes) // 2, changes

        return 0, []

    except Exception as e:
        print(f"❌ 处理文件 {file_path} 时出错: {e}")
        return 0, []


def main():
    """主函数"""
    print("🔍 搜索需要更新导入的文件...")

    # 获取项目根目录
    root_dir = Path(__file__).parent.parent

    # 搜索所有Python文件
    all_python_files = list(root_dir.rglob("*.py"))

    # 过滤掉排除的文件
    target_files = [f for f in all_python_files if not should_exclude(f)]

    print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
    print("\n" + "=" * 80)

    # 第一遍:预览模式
    print("\n🔍 预览模式 - 检查需要更新的文件...\n")

    files_to_update = []
    for file_path in target_files:
        count, changes = update_imports_in_file(file_path, dry_run=True)
        if count > 0:
            files_to_update.append((file_path, count, changes))

    if not files_to_update:
        print("✅ 没有文件需要更新!")
        return

    print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")

    total_changes = 0
    for file_path, count, changes in files_to_update:
        rel_path = file_path.relative_to(root_dir)
        print(f"\n📄 {rel_path} ({count} 处修改)")
        for change in changes[:10]:  # 最多显示前5对修改
            print(change)
        if len(changes) > 10:
            print(f"  ... 还有 {len(changes) - 10} 行")
        total_changes += count

    print("\n" + "=" * 80)
    print("\n📊 统计:")
    print(f"  - 需要更新的文件: {len(files_to_update)}")
    print(f"  - 总修改次数: {total_changes}")

    # 询问是否继续
    print("\n" + "=" * 80)
    response = input("\n是否执行更新?(yes/no): ").strip().lower()

    if response != "yes":
        print("❌ 已取消更新")
        return

    # 第二遍:实际更新
    print("\n✨ 开始更新文件...\n")

    success_count = 0
    for file_path, _, _ in files_to_update:
        count, _ = update_imports_in_file(file_path, dry_run=False)
        if count > 0:
            rel_path = file_path.relative_to(root_dir)
            print(f"✅ {rel_path} ({count} 处修改)")
            success_count += 1

    print("\n" + "=" * 80)
    print(f"\n🎉 完成!成功更新 {success_count} 个文件")


if __name__ == "__main__":
    main()
@@ -238,6 +238,14 @@ class BatchDatabaseWriter:
                 stmt = stmt.on_duplicate_key_update(
                     **{key: value for key, value in update_data.items() if key != "stream_id"}
                 )
+            elif global_config.database.database_type == "postgresql":
+                from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+                stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                stmt = stmt.on_conflict_do_update(
+                    index_elements=[ChatStreams.stream_id],
+                    set_=update_data
+                )
             else:
                 # 默认使用SQLite语法
                 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
@@ -264,6 +272,14 @@ class BatchDatabaseWriter:
                 stmt = stmt.on_duplicate_key_update(
                     **{key: value for key, value in update_data.items() if key != "stream_id"}
                 )
+            elif global_config.database.database_type == "postgresql":
+                from sqlalchemy.dialects.postgresql import insert as pg_insert
+
+                stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
+                stmt = stmt.on_conflict_do_update(
+                    index_elements=[ChatStreams.stream_id],
+                    set_=update_data
+                )
             else:
                 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 
@@ -4,6 +4,7 @@ import time
 from rich.traceback import install
 from sqlalchemy.dialects.mysql import insert as mysql_insert
+from sqlalchemy.dialects.postgresql import insert as pg_insert
 from sqlalchemy.dialects.sqlite import insert as sqlite_insert
 
 from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo
@@ -663,6 +664,13 @@ class ChatManager:
                 stmt = stmt.on_duplicate_key_update(
                     **{key: value for key, value in fields_to_save.items() if key != "stream_id"}
                 )
+            elif global_config.database.database_type == "postgresql":
+                stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
+                # PostgreSQL 需要使用 constraint 参数或正确的 index_elements
+                stmt = stmt.on_conflict_do_update(
+                    index_elements=[ChatStreams.stream_id],
+                    set_=fields_to_save
+                )
             else:
                 stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
                 stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)
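上面几处改动在不同的写入路径里重复了同一套「按方言选择 upsert 语法」的分支。下面是一个示意性的收敛写法(独立草稿,build_upsert_stmt 为假设的函数名,并非本提交中的代码):

# 示意草稿:把重复的「按方言选择 upsert」分支收敛为一个辅助函数
# (build_upsert_stmt 为假设的函数名,非本提交代码)
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert


def build_upsert_stmt(table, db_type: str, key_column: str, values: dict):
    """按 db_type 生成对应方言的 upsert 语句,冲突键以外的列全部更新"""
    update_data = {k: v for k, v in values.items() if k != key_column}
    if db_type == "mysql":
        return mysql_insert(table).values(**values).on_duplicate_key_update(**update_data)
    insert_fn = pg_insert if db_type == "postgresql" else sqlite_insert
    stmt = insert_fn(table).values(**values)
    return stmt.on_conflict_do_update(index_elements=[key_column], set_=update_data)

调用点只需 build_upsert_stmt(ChatStreams, db_type, "stream_id", fields_to_save) 即可,三处分支就不必各自维护。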
@@ -142,8 +142,11 @@ class MessageStorageBatcher:
             return None
 
         # 将ORM对象转换为字典(只包含列字段)
+        # 排除 id 字段,让数据库自动生成(对于 PostgreSQL SERIAL 类型尤其重要)
         message_dict = {}
         for column in Messages.__table__.columns:
+            if column.name == "id":
+                continue  # 跳过自增主键,让数据库自动生成
             message_dict[column.name] = getattr(message_obj, column.name)
 
         return message_dict
@@ -1143,7 +1143,6 @@ class Prompt:
         Returns:
             str: 构建好的跨群聊上下文字符串。
         """
-        logger.info(f"Building cross context with target_user_info: {target_user_info}")
         if not global_config.cross_context.enable:
             return ""
 
@@ -169,6 +169,11 @@ class ConnectionPoolManager:
         self, session_factory: async_sessionmaker[AsyncSession]
     ) -> ConnectionInfo | None:
         """获取可复用的连接"""
+        # 导入方言适配器获取 ping 查询
+        from src.common.database.core.dialect_adapter import DialectAdapter
+
+        ping_query = DialectAdapter.get_ping_query()
+
         async with self._lock:
             # 清理过期连接
             await self._cleanup_expired_connections_locked()
@@ -178,8 +183,8 @@ class ConnectionPoolManager:
                 if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
                     # 验证连接是否仍然有效
                     try:
-                        # 执行一个简单的查询来验证连接
-                        await connection_info.session.execute(text("SELECT 1"))
+                        # 执行 ping 查询来验证连接
+                        await connection_info.session.execute(text(ping_query))
                         return connection_info
                     except Exception as e:
                         logger.debug(f"连接验证失败,将移除: {e}")
@@ -5,8 +5,22 @@
 - 会话管理
 - 模型定义
 - 数据库迁移
+- 方言适配
+
+支持的数据库:
+- SQLite (默认)
+- MySQL
+- PostgreSQL
 """
 
+from .dialect_adapter import (
+    DatabaseDialect,
+    DialectAdapter,
+    DialectConfig,
+    get_dialect_adapter,
+    get_indexed_string_field,
+    get_text_field,
+)
 from .engine import close_engine, get_engine, get_engine_info
 from .migration import check_and_migrate_database, create_all_tables, drop_all_tables
 from .models import (
@@ -50,6 +64,10 @@ __all__ = [
     "BotPersonalityInterests",
     "CacheEntries",
     "ChatStreams",
+    # Dialect Adapter
+    "DatabaseDialect",
+    "DialectAdapter",
+    "DialectConfig",
     "Emoji",
     "Expression",
     "GraphEdges",
@@ -77,10 +95,13 @@ __all__ = [
     # Session
     "get_db_session",
     "get_db_session_direct",
+    "get_dialect_adapter",
     # Engine
     "get_engine",
     "get_engine_info",
+    "get_indexed_string_field",
     "get_session_factory",
     "get_string_field",
+    "get_text_field",
     "reset_session_factory",
 ]
230 src/common/database/core/dialect_adapter.py Normal file
@@ -0,0 +1,230 @@
"""数据库方言适配器

提供跨数据库兼容性支持,处理不同数据库之间的差异:
- SQLite: 轻量级本地数据库
- MySQL: 高性能关系型数据库
- PostgreSQL: 功能丰富的开源数据库

主要职责:
1. 提供数据库特定的类型映射
2. 处理方言特定的查询语法
3. 提供数据库特定的优化配置
"""

from dataclasses import dataclass, field
from enum import Enum
from typing import Any

from sqlalchemy import String, Text
from sqlalchemy.types import TypeEngine


class DatabaseDialect(Enum):
    """数据库方言枚举"""

    SQLITE = "sqlite"
    MYSQL = "mysql"
    POSTGRESQL = "postgresql"


@dataclass
class DialectConfig:
    """方言配置"""

    dialect: DatabaseDialect
    # 连接验证查询
    ping_query: str
    # 是否支持 RETURNING 子句
    supports_returning: bool
    # 是否支持原生 JSON 类型
    supports_native_json: bool
    # 是否支持数组类型
    supports_arrays: bool
    # 是否需要指定字符串长度用于索引
    requires_length_for_index: bool
    # 默认字符串长度(用于索引列)
    default_string_length: int
    # 事务隔离级别
    isolation_level: str
    # 额外的引擎参数
    engine_kwargs: dict[str, Any] = field(default_factory=dict)


# 预定义的方言配置
DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = {
    DatabaseDialect.SQLITE: DialectConfig(
        dialect=DatabaseDialect.SQLITE,
        ping_query="SELECT 1",
        supports_returning=True,  # SQLite 3.35+ 支持
        supports_native_json=False,
        supports_arrays=False,
        requires_length_for_index=False,
        default_string_length=255,
        isolation_level="SERIALIZABLE",
        engine_kwargs={
            "connect_args": {
                "check_same_thread": False,
                "timeout": 60,
            }
        },
    ),
    DatabaseDialect.MYSQL: DialectConfig(
        dialect=DatabaseDialect.MYSQL,
        ping_query="SELECT 1",
        supports_returning=False,  # MySQL 8.0.21+ 有限支持
        supports_native_json=True,  # MySQL 5.7+
        supports_arrays=False,
        requires_length_for_index=True,  # MySQL 索引需要指定长度
        default_string_length=255,
        isolation_level="READ COMMITTED",
        engine_kwargs={
            "pool_pre_ping": True,
            "pool_recycle": 3600,
        },
    ),
    DatabaseDialect.POSTGRESQL: DialectConfig(
        dialect=DatabaseDialect.POSTGRESQL,
        ping_query="SELECT 1",
        supports_returning=True,
        supports_native_json=True,
        supports_arrays=True,
        requires_length_for_index=False,
        default_string_length=255,
        isolation_level="READ COMMITTED",
        engine_kwargs={
            "pool_pre_ping": True,
            "pool_recycle": 3600,
        },
    ),
}


class DialectAdapter:
    """数据库方言适配器

    根据当前配置的数据库类型,提供相应的类型映射和查询支持
    """

    _current_dialect: DatabaseDialect | None = None
    _config: DialectConfig | None = None

    @classmethod
    def initialize(cls, db_type: str) -> None:
        """初始化适配器

        Args:
            db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql")
        """
        try:
            cls._current_dialect = DatabaseDialect(db_type.lower())
            cls._config = DIALECT_CONFIGS[cls._current_dialect]
        except ValueError as e:
            raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql") from e

    @classmethod
    def get_dialect(cls) -> DatabaseDialect:
        """获取当前数据库方言"""
        if cls._current_dialect is None:
            # 延迟初始化:从配置获取
            from src.config.config import global_config

            if global_config is None:
                raise RuntimeError("配置尚未初始化,无法获取数据库方言")
            cls.initialize(global_config.database.database_type)
        return cls._current_dialect  # type: ignore

    @classmethod
    def get_config(cls) -> DialectConfig:
        """获取当前方言配置"""
        if cls._config is None:
            cls.get_dialect()  # 触发初始化
        return cls._config  # type: ignore

    @classmethod
    def get_string_type(cls, max_length: int = 255, indexed: bool = False) -> TypeEngine:
        """获取适合当前数据库的字符串类型

        Args:
            max_length: 最大长度
            indexed: 是否用于索引

        Returns:
            SQLAlchemy 类型
        """
        config = cls.get_config()

        # MySQL 索引列需要指定长度
        if config.requires_length_for_index and indexed:
            return String(max_length)

        # SQLite 和 PostgreSQL 可以使用 Text
        if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL):
            return Text() if not indexed else String(max_length)

        # MySQL 使用 VARCHAR
        return String(max_length)

    @classmethod
    def get_ping_query(cls) -> str:
        """获取连接验证查询"""
        return cls.get_config().ping_query

    @classmethod
    def supports_returning(cls) -> bool:
        """是否支持 RETURNING 子句"""
        return cls.get_config().supports_returning

    @classmethod
    def supports_native_json(cls) -> bool:
        """是否支持原生 JSON 类型"""
        return cls.get_config().supports_native_json

    @classmethod
    def get_engine_kwargs(cls) -> dict[str, Any]:
        """获取引擎额外参数"""
        return cls.get_config().engine_kwargs.copy()

    @classmethod
    def is_sqlite(cls) -> bool:
        """是否为 SQLite"""
        return cls.get_dialect() == DatabaseDialect.SQLITE

    @classmethod
    def is_mysql(cls) -> bool:
        """是否为 MySQL"""
        return cls.get_dialect() == DatabaseDialect.MYSQL

    @classmethod
    def is_postgresql(cls) -> bool:
        """是否为 PostgreSQL"""
        return cls.get_dialect() == DatabaseDialect.POSTGRESQL


def get_dialect_adapter() -> type[DialectAdapter]:
    """获取方言适配器类"""
    return DialectAdapter


def get_indexed_string_field(max_length: int = 255) -> TypeEngine:
    """获取用于索引的字符串字段类型

    这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型

    Args:
        max_length: 最大长度(对于 MySQL 是必需的)

    Returns:
        SQLAlchemy 类型
    """
    return DialectAdapter.get_string_type(max_length, indexed=True)


def get_text_field() -> TypeEngine:
    """获取文本字段类型

    用于不需要索引的大文本字段

    Returns:
        SQLAlchemy Text 类型
    """
    return Text()
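以下是这个适配器的一个最小使用示意(假设的调用场景,非本提交代码):先用配置的数据库类型初始化,再在任何需要方言信息的地方查询它:

# 示意:初始化并查询方言适配器(假设的调用场景,非本提交代码)
from src.common.database.core.dialect_adapter import DialectAdapter, get_indexed_string_field

DialectAdapter.initialize("postgresql")

assert DialectAdapter.is_postgresql()
assert DialectAdapter.supports_returning()
print(DialectAdapter.get_ping_query())     # SELECT 1
print(DialectAdapter.get_engine_kwargs())  # {'pool_pre_ping': True, 'pool_recycle': 3600}
print(get_indexed_string_field(64))        # PostgreSQL 下为 String(64),即 VARCHAR(64)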
@@ -1,6 +1,11 @@
 """数据库引擎管理
 
 单一职责:创建和管理SQLAlchemy异步引擎
+
+支持的数据库类型:
+- SQLite: 轻量级本地数据库,使用 aiosqlite 驱动
+- MySQL: 高性能关系型数据库,使用 aiomysql 驱动
+- PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动
 """
 
 import asyncio
@@ -13,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
 from src.common.logger import get_logger
 
 from ..utils.exceptions import DatabaseInitializationError
+from .dialect_adapter import DialectAdapter
 
 logger = get_logger("database.engine")
 
@@ -52,11 +58,78 @@ async def get_engine() -> AsyncEngine:
         config = global_config.database
         db_type = config.database_type
 
+        # 初始化方言适配器
+        DialectAdapter.initialize(db_type)
+
+        logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
+
-        # 构建数据库URL和引擎参数
+        # 根据数据库类型构建URL和引擎参数
         if db_type == "mysql":
             # MySQL配置
+            url, engine_kwargs = _build_mysql_config(config)
+        elif db_type == "postgresql":
+            url, engine_kwargs = _build_postgresql_config(config)
+        else:
+            url, engine_kwargs = _build_sqlite_config(config)
+
+        # 创建异步引擎
+        _engine = create_async_engine(url, **engine_kwargs)
+
+        # 数据库特定优化
+        if db_type == "sqlite":
+            await _enable_sqlite_optimizations(_engine)
+        elif db_type == "postgresql":
+            await _enable_postgresql_optimizations(_engine)
+
+        logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
+        return _engine
+
+    except Exception as e:
+        logger.error(f"❌ 数据库引擎初始化失败: {e}")
+        raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
+
+
+def _build_sqlite_config(config) -> tuple[str, dict]:
+    """构建 SQLite 配置
+
+    Args:
+        config: 数据库配置对象
+
+    Returns:
+        (url, engine_kwargs) 元组
+    """
+    if not os.path.isabs(config.sqlite_path):
+        ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
+        db_path = os.path.join(ROOT_PATH, config.sqlite_path)
+    else:
+        db_path = config.sqlite_path
+
+    # 确保数据库目录存在
+    os.makedirs(os.path.dirname(db_path), exist_ok=True)
+
+    url = f"sqlite+aiosqlite:///{db_path}"
+
+    engine_kwargs = {
+        "echo": False,
+        "future": True,
+        "connect_args": {
+            "check_same_thread": False,
+            "timeout": 60,
+        },
+    }
+
+    logger.info(f"SQLite配置: {db_path}")
+    return url, engine_kwargs
+
+
+def _build_mysql_config(config) -> tuple[str, dict]:
+    """构建 MySQL 配置
+
+    Args:
+        config: 数据库配置对象
+
+    Returns:
+        (url, engine_kwargs) 元组
+    """
+    encoded_user = quote_plus(config.mysql_user)
+    encoded_password = quote_plus(config.mysql_password)
@@ -94,44 +167,60 @@ async def get_engine() -> AsyncEngine:
         logger.info(
             f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
         )
+    return url, engine_kwargs
 
-        else:
-            # SQLite配置
-            if not os.path.isabs(config.sqlite_path):
-                ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
-                db_path = os.path.join(ROOT_PATH, config.sqlite_path)
-            else:
-                db_path = config.sqlite_path
-
-            # 确保数据库目录存在
-            os.makedirs(os.path.dirname(db_path), exist_ok=True)
-
-            url = f"sqlite+aiosqlite:///{db_path}"
-
-            engine_kwargs = {
-                "echo": False,
-                "future": True,
-                "connect_args": {
-                    "check_same_thread": False,
-                    "timeout": 60,
-                },
-            }
-
-            logger.info(f"SQLite配置: {db_path}")
-
-        # 创建异步引擎
-        _engine = create_async_engine(url, **engine_kwargs)
-
-        # SQLite特定优化
-        if db_type == "sqlite":
-            await _enable_sqlite_optimizations(_engine)
-
-        logger.info(f"✅ {db_type.upper()} 数据库引擎初始化成功")
-        return _engine
-
-    except Exception as e:
-        logger.error(f"❌ 数据库引擎初始化失败: {e}")
-        raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
+
+def _build_postgresql_config(config) -> tuple[str, dict]:
+    """构建 PostgreSQL 配置
+
+    Args:
+        config: 数据库配置对象
+
+    Returns:
+        (url, engine_kwargs) 元组
+    """
+    encoded_user = quote_plus(config.postgresql_user)
+    encoded_password = quote_plus(config.postgresql_password)
+
+    # 构建基本 URL
+    url = (
+        f"postgresql+asyncpg://{encoded_user}:{encoded_password}"
+        f"@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
+    )
+
+    # SSL 配置
+    connect_args = {}
+    if config.postgresql_ssl_mode != "disable":
+        ssl_config = {"ssl": config.postgresql_ssl_mode}
+        if config.postgresql_ssl_ca:
+            ssl_config["ssl_ca"] = config.postgresql_ssl_ca
+        if config.postgresql_ssl_cert:
+            ssl_config["ssl_cert"] = config.postgresql_ssl_cert
+        if config.postgresql_ssl_key:
+            ssl_config["ssl_key"] = config.postgresql_ssl_key
+        connect_args.update(ssl_config)
+
+    # 设置 schema(如果不是 public)
+    if config.postgresql_schema and config.postgresql_schema != "public":
+        connect_args["server_settings"] = {"search_path": config.postgresql_schema}
+
+    engine_kwargs = {
+        "echo": False,
+        "future": True,
+        "pool_size": config.connection_pool_size,
+        "max_overflow": config.connection_pool_size * 2,
+        "pool_timeout": config.connection_timeout,
+        "pool_recycle": 3600,
+        "pool_pre_ping": True,
+    }
+
+    if connect_args:
+        engine_kwargs["connect_args"] = connect_args
+
+    logger.info(
+        f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
+    )
+    return url, engine_kwargs
 
 
 async def close_engine():
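连接 URL 中的用户名和密码都先经过 quote_plus 转义,否则密码里的 @、: 或 / 会破坏 URL 结构。一个独立的小例子:

# 独立小例子:quote_plus 转义密码中的特殊字符,避免破坏连接 URL(非本提交代码)
from urllib.parse import quote_plus

password = "p@ss:word/1"
print(quote_plus(password))  # p%40ss%3Aword%2F1
print(f"postgresql+asyncpg://postgres:{quote_plus(password)}@localhost:5432/maibot")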
@@ -181,6 +270,33 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
         logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
 
 
+async def _enable_postgresql_optimizations(engine: AsyncEngine):
+    """启用PostgreSQL性能优化
+
+    优化项:
+    - 设置合适的 work_mem
+    - 启用 JIT 编译(如果可用)
+    - 设置合适的 statement_timeout
+
+    Args:
+        engine: SQLAlchemy异步引擎
+    """
+    try:
+        async with engine.begin() as conn:
+            # 设置会话级别的参数
+            # work_mem: 排序和哈希操作的内存(64MB)
+            await conn.execute(text("SET work_mem = '64MB'"))
+            # 设置语句超时(5分钟)
+            await conn.execute(text("SET statement_timeout = '300000'"))
+            # 启用自动 EXPLAIN(可选,用于调试)
+            # await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))
+
+        logger.info("✅ PostgreSQL性能优化已启用")
+
+    except Exception as e:
+        logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")
+
+
 async def get_engine_info() -> dict:
     """获取引擎信息(用于监控和调试)
@@ -99,12 +99,17 @@ async def check_and_migrate_database(existing_engine=None):
     def add_columns_sync(conn):
         dialect = conn.dialect
-        compiler = dialect.ddl_compiler(dialect, None)
 
         for column_name in missing_columns:
             column = table.c[column_name]
-            column_type = compiler.get_column_specification(column)
-            sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
+
+            # 获取列类型的 SQL 表示
+            # 使用 compile 方法获取正确的类型字符串
+            column_type_sql = column.type.compile(dialect=dialect)
+
+            # 构建 ALTER TABLE 语句
+            sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type_sql}"
 
             if column.default:
                 # 手动处理不同方言的默认值
@@ -114,26 +119,18 @@ async def check_and_migrate_database(existing_engine=None):
             ):
                 # SQLite 将布尔值存储为 0 或 1
                 default_value = "1" if default_arg else "0"
-            elif hasattr(compiler, "render_literal_value"):
-                try:
-                    # 尝试使用 render_literal_value
-                    default_value = compiler.render_literal_value(
-                        default_arg, column.type
-                    )
-                except AttributeError:
-                    # 如果失败,则回退到简单的字符串转换
-                    default_value = (
-                        f"'{default_arg}'"
-                        if isinstance(default_arg, str)
-                        else str(default_arg)
-                    )
+            elif dialect.name == "mysql" and isinstance(default_arg, bool):
+                # MySQL 也使用 1/0 表示布尔值
+                default_value = "1" if default_arg else "0"
+            elif isinstance(default_arg, bool):
+                # PostgreSQL 使用 TRUE/FALSE
+                default_value = "TRUE" if default_arg else "FALSE"
+            elif isinstance(default_arg, str):
+                default_value = f"'{default_arg}'"
+            elif default_arg is None:
+                default_value = "NULL"
             else:
-                # 对于没有 render_literal_value 的旧版或特定方言
-                default_value = (
-                    f"'{default_arg}'"
-                    if isinstance(default_arg, str)
-                    else str(default_arg)
-                )
+                default_value = str(default_arg)
 
             sql += f" DEFAULT {default_value}"
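这段迁移逻辑依赖 column.type.compile(dialect=dialect) 把 SQLAlchemy 类型渲染成目标方言的 DDL 片段。一个独立的小例子:

# 独立小例子:同一个 SQLAlchemy 类型在不同方言下渲染出的 DDL 片段(非本提交代码)
from sqlalchemy import Boolean, String
from sqlalchemy.dialects import mysql, postgresql, sqlite

for dialect in (sqlite.dialect(), mysql.dialect(), postgresql.dialect()):
    print(dialect.name, String(50).compile(dialect=dialect), Boolean().compile(dialect=dialect))
# 大致输出(随 SQLAlchemy 版本可能略有差异):
# sqlite VARCHAR(50) BOOLEAN
# mysql VARCHAR(50) BOOL
# postgresql VARCHAR(50) BOOLEAN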
@@ -3,6 +3,11 @@
 本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。
 引擎和会话管理已移至core/engine.py和core/session.py。
 
+支持的数据库类型:
+- SQLite: 使用 Text 类型
+- MySQL: 使用 VARCHAR(max_length) 用于索引字段
+- PostgreSQL: 使用 Text 类型(PostgreSQL 的 Text 类型性能与 VARCHAR 相当)
+
 所有模型使用统一的类型注解风格:
     field_name: Mapped[PyType] = mapped_column(Type, ...)
 
@@ -20,16 +25,34 @@ from sqlalchemy.orm import Mapped, mapped_column
 Base = declarative_base()
 
 
-# MySQL兼容的字段类型辅助函数
+# 数据库兼容的字段类型辅助函数
 def get_string_field(max_length=255, **kwargs):
     """
-    根据数据库类型返回合适的字符串字段
-    MySQL需要指定长度的VARCHAR用于索引,SQLite可以使用Text
+    根据数据库类型返回合适的字符串字段类型
+
+    对于需要索引的字段:
+    - MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度
+    - PostgreSQL: 可以使用 Text,但为了兼容性使用 VARCHAR
+    - SQLite: 可以使用 Text,无长度限制
+
+    Args:
+        max_length: 最大长度(对于 MySQL 是必需的)
+        **kwargs: 传递给 String/Text 的额外参数
+
+    Returns:
+        SQLAlchemy 类型
     """
     from src.config.config import global_config
 
-    if global_config.database.database_type == "mysql":
+    db_type = global_config.database.database_type
+
+    # MySQL 索引需要指定长度的 VARCHAR
+    if db_type == "mysql":
         return String(max_length, **kwargs)
+    # PostgreSQL 可以使用 Text,但为了跨数据库迁移兼容性,使用 VARCHAR
+    elif db_type == "postgresql":
+        return String(max_length, **kwargs)
+    # SQLite 使用 Text(无长度限制)
     else:
         return Text(**kwargs)
@@ -477,7 +500,7 @@ class BanUser(Base):
     __tablename__ = "ban_users"
 
     id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
-    platform: Mapped[str] = mapped_column(Text, nullable=False)
+    platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False)  # 使用有限长度,以便创建索引
     user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
     violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
     reason: Mapped[str] = mapped_column(Text, nullable=False)
@@ -1,6 +1,11 @@
 """数据库会话管理
 
 单一职责:提供数据库会话工厂和上下文管理器
+
+支持的数据库类型:
+- SQLite: 设置 PRAGMA 参数优化并发
+- MySQL: 无特殊会话设置
+- PostgreSQL: 可选设置 schema 搜索路径
 """
 
 import asyncio
@@ -53,12 +58,43 @@ async def get_session_factory() -> async_sessionmaker:
     return _session_factory
 
 
+async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
+    """应用数据库特定的会话设置
+
+    Args:
+        session: 数据库会话
+        db_type: 数据库类型
+    """
+    try:
+        if db_type == "sqlite":
+            # SQLite 特定的 PRAGMA 设置
+            await session.execute(text("PRAGMA busy_timeout = 60000"))
+            await session.execute(text("PRAGMA foreign_keys = ON"))
+        elif db_type == "postgresql":
+            # PostgreSQL 特定设置(如果需要)
+            # 可以设置 schema 搜索路径等
+            from src.config.config import global_config
+
+            schema = global_config.database.postgresql_schema
+            if schema and schema != "public":
+                await session.execute(text(f"SET search_path TO {schema}"))
+        # MySQL 通常不需要会话级别的特殊设置
+    except Exception:
+        # 复用连接时设置可能已存在,忽略错误
+        pass
+
+
 @asynccontextmanager
 async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
     """获取数据库会话上下文管理器
 
     这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
 
+    支持的数据库:
+    - SQLite: 自动设置 busy_timeout 和外键约束
+    - MySQL: 直接使用,无特殊设置
+    - PostgreSQL: 支持自定义 schema
+
     使用示例:
         async with get_db_session() as session:
             result = await session.execute(select(User))
@@ -75,16 +111,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
 
     # 使用连接池管理器(透明复用连接)
     async with pool_manager.get_session(session_factory) as session:
-        # 为SQLite设置特定的PRAGMA
+        # 获取数据库类型并应用特定设置
         from src.config.config import global_config
 
-        if global_config.database.database_type == "sqlite":
-            try:
-                await session.execute(text("PRAGMA busy_timeout = 60000"))
-                await session.execute(text("PRAGMA foreign_keys = ON"))
-            except Exception:
-                # 复用连接时PRAGMA可能已设置,忽略错误
-                pass
+        await _apply_session_settings(session, global_config.database.database_type)
 
         yield session
 
@@ -103,6 +133,11 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
 
     async with session_factory() as session:
         try:
+            # 应用数据库特定设置
+            from src.config.config import global_config
+
+            await _apply_session_settings(session, global_config.database.database_type)
+
             yield session
         except Exception:
             await session.rollback()
@@ -373,8 +373,14 @@ class AdaptiveBatchScheduler:
         """批量执行插入操作"""
         async with get_db_session() as session:
             try:
-                # 收集数据
-                all_data = [op.data for op in operations if op.data]
+                # 收集数据,并过滤掉 id=None 的情况(让数据库自动生成)
+                all_data = []
+                for op in operations:
+                    if op.data:
+                        # 过滤掉 id 为 None 的键,让数据库自动生成主键
+                        filtered_data = {k: v for k, v in op.data.items() if not (k == "id" and v is None)}
+                        all_data.append(filtered_data)
 
                 if not all_data:
                     return
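过滤出的 all_data 随后通常以 executemany 的方式一次性插入。下面是这一模式的独立示意(flush_batch 为假设的函数名,model 为任意 ORM 模型,非本提交代码):

# 示意草稿:对过滤后的字典列表做 executemany 批量插入(非本提交代码)
from sqlalchemy import insert


async def flush_batch(session, model, all_data: list[dict]) -> None:
    if not all_data:
        return
    # 一条 INSERT 携带多组参数;字典里不含 id 键时,主键交给数据库生成
    await session.execute(insert(model), all_data)
    await session.commit()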
@@ -16,8 +16,10 @@ from src.config.config_base import ValidatedConfigBase
 class DatabaseConfig(ValidatedConfigBase):
     """数据库配置类"""
 
-    database_type: Literal["sqlite", "mysql"] = Field(default="sqlite", description="数据库类型")
+    database_type: Literal["sqlite", "mysql", "postgresql"] = Field(default="sqlite", description="数据库类型")
     sqlite_path: str = Field(default="data/MaiBot.db", description="SQLite数据库文件路径")
 
     # MySQL 配置
     mysql_host: str = Field(default="localhost", description="MySQL服务器地址")
     mysql_port: int = Field(default=3306, ge=1, le=65535, description="MySQL服务器端口")
     mysql_database: str = Field(default="maibot", description="MySQL数据库名")
@@ -33,6 +35,22 @@ class DatabaseConfig(ValidatedConfigBase):
     mysql_ssl_key: str = Field(default="", description="SSL密钥路径")
     mysql_autocommit: bool = Field(default=True, description="自动提交事务")
     mysql_sql_mode: str = Field(default="TRADITIONAL", description="SQL模式")
 
+    # PostgreSQL 配置
+    postgresql_host: str = Field(default="localhost", description="PostgreSQL服务器地址")
+    postgresql_port: int = Field(default=5432, ge=1, le=65535, description="PostgreSQL服务器端口")
+    postgresql_database: str = Field(default="maibot", description="PostgreSQL数据库名")
+    postgresql_user: str = Field(default="postgres", description="PostgreSQL用户名")
+    postgresql_password: str = Field(default="", description="PostgreSQL密码")
+    postgresql_schema: str = Field(default="public", description="PostgreSQL模式名")
+    postgresql_ssl_mode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"] = Field(
+        default="prefer", description="PostgreSQL SSL模式"
+    )
+    postgresql_ssl_ca: str = Field(default="", description="PostgreSQL SSL CA证书路径")
+    postgresql_ssl_cert: str = Field(default="", description="PostgreSQL SSL客户端证书路径")
+    postgresql_ssl_key: str = Field(default="", description="PostgreSQL SSL密钥路径")
+
     # 通用连接池配置
     connection_pool_size: int = Field(default=10, ge=1, description="连接池大小")
     connection_timeout: int = Field(default=10, ge=1, description="连接超时时间")
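database_type 等字段用 Literal 做了取值约束,配置写错会在加载阶段直接失败,而不是等到建立连接时才暴露。一个独立的小例子(假设 ValidatedConfigBase 是 pydantic 模型基类,非本提交代码):

# 独立小例子:Literal 约束下,非法的 database_type 在配置加载阶段即报错
from pydantic import ValidationError

try:
    DatabaseConfig(database_type="oracle")
except ValidationError as e:
    print(e.errors()[0]["loc"])  # ('database_type',)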
@@ -1,5 +1,5 @@
 [inner]
-version = "7.8.3"
+version = "7.9.0"
 
 #----以下是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
 #如果你想要修改配置文件,请递增version的值
@@ -12,7 +12,7 @@ version = "7.8.3"
 #----以上是给开发人员阅读的,如果你只是部署了MoFox-Bot,不需要阅读----
 
 [database] # 数据库配置
-database_type = "sqlite" # 数据库类型,支持 "sqlite" 或 "mysql"
+database_type = "sqlite" # 数据库类型,支持 "sqlite"、"mysql" 或 "postgresql"
 
 # SQLite 配置(当 database_type = "sqlite" 时使用)
 sqlite_path = "data/MaiBot.db" # SQLite数据库文件路径
@@ -36,8 +36,22 @@ mysql_ssl_key = "" # SSL客户端密钥路径
 mysql_autocommit = true # 自动提交事务
 mysql_sql_mode = "TRADITIONAL" # SQL模式
 
-# 连接池配置
-connection_pool_size = 10 # 连接池大小(仅MySQL有效)
+# PostgreSQL 配置(当 database_type = "postgresql" 时使用)
+postgresql_host = "localhost" # PostgreSQL服务器地址
+postgresql_port = 5432 # PostgreSQL服务器端口
+postgresql_database = "maibot" # PostgreSQL数据库名
+postgresql_user = "postgres" # PostgreSQL用户名
+postgresql_password = "" # PostgreSQL密码
+postgresql_schema = "public" # PostgreSQL模式名(schema)
+
+# PostgreSQL SSL 配置
+postgresql_ssl_mode = "prefer" # SSL模式: disable, allow, prefer, require, verify-ca, verify-full
+postgresql_ssl_ca = "" # SSL CA证书路径
+postgresql_ssl_cert = "" # SSL客户端证书路径
+postgresql_ssl_key = "" # SSL客户端密钥路径
+
+# 连接池配置(MySQL 和 PostgreSQL 有效)
+connection_pool_size = 10 # 连接池大小
 connection_timeout = 10 # 连接超时时间(秒)
 
 # 批量动作记录存储配置