feat: update bot configuration and add database migration script
- Bump the version in bot_config_template.toml to 7.9.0
- Extend the database configuration options to support PostgreSQL
- Introduce a new script that migrates data between SQLite, MySQL, and PostgreSQL
- Implement a dialect adapter that handles database-specific behavior and configuration
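The dialect adapter itself is not shown in this diff. As a rough illustration of the idea only — a minimal sketch assuming SQLAlchemy, where the class name, the per-dialect options, and their values are all assumptions rather than the actual implementation:

# Hypothetical sketch -- names and option values are illustrative, not the real adapter.
from sqlalchemy import create_engine
from sqlalchemy.engine import make_url


class DialectAdapter:
    """Map a database URL to dialect-specific engine options (illustrative only)."""

    # Per-dialect engine keyword arguments; typical defaults, assumed here.
    DIALECT_OPTIONS = {
        "sqlite": {"connect_args": {"check_same_thread": False}},
        "mysql": {"pool_recycle": 3600, "pool_pre_ping": True},
        "postgresql": {"pool_size": 5, "pool_pre_ping": True},
    }

    def __init__(self, database_url: str):
        self.url = make_url(database_url)
        # e.g. "postgresql+asyncpg" -> "postgresql"
        self.dialect = self.url.get_backend_name()

    def create_engine(self):
        options = self.DIALECT_OPTIONS.get(self.dialect, {})
        return create_engine(self.url, **options)


# Usage: engine = DialectAdapter("postgresql://user:pass@localhost/botdb").create_engine()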
@@ -1,117 +0,0 @@
"""
Diagnostic script that checks the state of the expression database
"""
import asyncio
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from sqlalchemy import func, select

from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression


async def check_database():
    """Check the state of the expression database"""

    print("=" * 60)
    print("表达方式数据库诊断报告")
    print("=" * 60)

    async with get_db_session() as session:
        # 1. Count the total number of rows
        total_count = await session.execute(select(func.count()).select_from(Expression))
        total = total_count.scalar()
        print(f"\n📊 总表达方式数量: {total}")

        if total == 0:
            print("\n⚠️ 数据库为空!")
            print("\n可能的原因:")
            print("1. 还没有进行过表达学习")
            print("2. 配置中禁用了表达学习")
            print("3. 学习过程中发生了错误")
            print("\n建议:")
            print("- 检查 bot_config.toml 中的 [expression] 配置")
            print("- 查看日志中是否有表达学习相关的错误")
            print("- 确认聊天流的 learn_expression 配置为 true")
            return

        # 2. Count rows grouped by chat_id
        print("\n📝 按聊天流统计:")
        chat_counts = await session.execute(
            select(Expression.chat_id, func.count())
            .group_by(Expression.chat_id)
        )
        for chat_id, count in chat_counts:
            print(f"  - {chat_id}: {count} 个表达方式")

        # 3. Count rows grouped by type
        print("\n📝 按类型统计:")
        type_counts = await session.execute(
            select(Expression.type, func.count())
            .group_by(Expression.type)
        )
        for expr_type, count in type_counts:
            print(f"  - {expr_type}: {count} 个")

        # 4. Check the situation and style columns for NULL values.
        # Use .is_(None) -- a plain "is None" comparison would evaluate to a
        # Python constant instead of generating an IS NULL clause.
        print("\n🔍 字段完整性检查:")
        null_situation = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.situation.is_(None))
        )
        null_style = await session.execute(
            select(func.count())
            .select_from(Expression)
            .where(Expression.style.is_(None))
        )

        null_sit_count = null_situation.scalar()
        null_sty_count = null_style.scalar()

        print(f"  - situation 为空: {null_sit_count} 个")
        print(f"  - style 为空: {null_sty_count} 个")

        if null_sit_count > 0 or null_sty_count > 0:
            print("  ⚠️ 发现空值!这会导致匹配失败")

        # 5. Show some sample rows
        print("\n📋 样例数据 (前10条):")
        samples = await session.execute(
            select(Expression)
            .limit(10)
        )

        for i, expr in enumerate(samples.scalars(), 1):
            print(f"\n  [{i}] Chat: {expr.chat_id}")
            print(f"      Type: {expr.type}")
            print(f"      Situation: {expr.situation}")
            print(f"      Style: {expr.style}")
            print(f"      Count: {expr.count}")

        # 6. Inspect the distinct values of the style column
        print("\n📋 Style 字段样例 (前20个):")
        unique_styles = await session.execute(
            select(Expression.style)
            .distinct()
            .limit(20)
        )

        styles = list(unique_styles.scalars())
        for style in styles:
            print(f"  - {style}")

        print(f"\n  (共 {len(styles)} 个不同的 style)")

    print("\n" + "=" * 60)
    print("诊断完成")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(check_database())
@@ -1,65 +0,0 @@
"""
Inspect the characteristics of the style column in the database
"""
import asyncio
import sys
from pathlib import Path

project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from sqlalchemy import select

from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression


async def analyze_style_fields():
    """Analyze the contents of the style column"""

    print("=" * 60)
    print("Style 字段内容分析")
    print("=" * 60)

    async with get_db_session() as session:
        # Fetch a sample of expressions
        result = await session.execute(select(Expression).limit(30))
        expressions = result.scalars().all()

        print(f"\n总共检查 {len(expressions)} 条记录\n")

        # Keep only the entries whose type is "style"
        style_examples = [
            {
                "situation": expr.situation,
                "style": expr.style,
                "length": len(expr.style) if expr.style else 0
            }
            for expr in expressions if expr.type == "style"
        ]

        print("📋 Style 类型样例 (前15条):")
        print("=" * 60)
        for i, ex in enumerate(style_examples[:15], 1):
            print(f"\n[{i}]")
            print(f"  Situation: {ex['situation']}")
            print(f"  Style: {ex['style']}")
            print(f"  长度: {ex['length']} 字符")

            # Heuristic: decide whether the value is a concrete expression or a
            # style description (guard against a None style before substring checks)
            if ex["length"] <= 20 and ex["style"] and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
                style_type = "✓ 风格描述"
            elif ex["length"] <= 10:
                style_type = "? 可能是具体表达(较短)"
            else:
                style_type = "✗ 具体表达内容"

            print(f"  类型判断: {style_type}")

    print("\n" + "=" * 60)
    print("分析完成")
    print("=" * 60)


if __name__ == "__main__":
    asyncio.run(analyze_style_fields())
@@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""Trim core/models.py so that only the model definitions remain"""

import os

# Path to the file being trimmed
models_file = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "src",
    "common",
    "database",
    "core",
    "models.py"
)

print(f"正在清理文件: {models_file}")

# Read the file
with open(models_file, encoding="utf-8") as f:
    lines = f.readlines()

# Find where the last model class ends (the close of MonthlyPlan's __table_args__);
# in the current file that is line 593 (inclusive), so everything up to it is kept.
keep_lines = []
found_end = False

for i, line in enumerate(lines, 1):
    keep_lines.append(line)

    # Check whether we reached the end of MonthlyPlan's __table_args__
    if i > 580 and line.strip() == ")":
        # Also confirm that a MonthlyPlan index appears in the preceding lines
        if "idx_monthlyplan" in "".join(lines[max(0, i - 5):i]):
            print(f"找到模型定义结束位置: 第 {i} 行")
            found_end = True
            break

if not found_end:
    print("❌ 未找到模型定义结束标记")
    exit(1)

# Write the trimmed content back
with open(models_file, "w", encoding="utf-8") as f:
    f.writelines(keep_lines)

print("✅ 文件清理完成")
print(f"保留行数: {len(keep_lines)}")
print(f"原始行数: {len(lines)}")
print(f"删除行数: {len(lines) - len(keep_lines)}")
@@ -1,59 +0,0 @@
"""
Debug fetching the MCP tool list

Directly tests whether the MCP client can retrieve tools
"""

import asyncio

from fastmcp.client import Client, StreamableHttpTransport


async def test_direct_connection():
    """Connect to the MCP server directly and fetch the tool list"""
    print("=" * 60)
    print("直接测试 MCP 服务器连接")
    print("=" * 60)

    url = "http://localhost:8000/mcp"
    print(f"\n连接到: {url}")

    try:
        # Create the transport layer
        transport = StreamableHttpTransport(url)
        print("✓ 传输层创建成功")

        # Create the client
        async with Client(transport) as client:
            print("✓ 客户端连接成功")

            # Fetch the tool list
            print("\n正在获取工具列表...")
            tools_result = await client.list_tools()

            print(f"\n获取结果类型: {type(tools_result)}")
            print(f"结果内容: {tools_result}")

            # Check whether the result exposes a tools attribute
            if hasattr(tools_result, "tools"):
                tools = tools_result.tools
                print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具")

                for i, tool in enumerate(tools, 1):
                    print(f"\n工具 {i}:")
                    print(f"  名称: {tool.name}")
                    print(f"  描述: {tool.description}")
                    if hasattr(tool, "inputSchema"):
                        print(f"  参数 Schema: {tool.inputSchema}")
            else:
                print("\n✗ 结果中没有 tools 属性")
                print(f"可用属性: {dir(tools_result)}")

    except Exception as e:
        print(f"\n✗ 连接失败: {e}")
        import traceback
        traceback.print_exc()


if __name__ == "__main__":
    asyncio.run(test_direct_connection())
@@ -1,88 +0,0 @@
"""
Diagnostic script that checks the state of the StyleLearner model
"""
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.chat.express.style_learner import style_learner_manager
from src.common.logger import get_logger

logger = get_logger("debug_style_learner")


def check_style_learner_status(chat_id: str):
    """Check the StyleLearner state for the given chat_id"""

    print("=" * 60)
    print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
    print("=" * 60)

    # Fetch the learner
    learner = style_learner_manager.get_learner(chat_id)

    # 1. Basic information
    print("\n📊 基本信息:")
    print(f"  Chat ID: {learner.chat_id}")
    print(f"  风格数量: {len(learner.style_to_id)}")
    print(f"  下一个ID: {learner.next_style_id}")
    print(f"  最大风格数: {learner.max_styles}")

    # 2. Learning statistics
    print("\n📈 学习统计:")
    print(f"  总样本数: {learner.learning_stats['total_samples']}")
    print(f"  最后更新: {learner.learning_stats.get('last_update', 'N/A')}")

    # 3. Learned styles (first 20)
    print("\n📋 已学习的风格 (前20个):")
    all_styles = learner.get_all_styles()
    if not all_styles:
        print("  ⚠️ 没有任何风格!模型尚未训练")
    else:
        for i, style in enumerate(all_styles[:20], 1):
            style_id = learner.style_to_id.get(style)
            situation = learner.id_to_situation.get(style_id, "N/A")
            print(f"  [{i}] {style}")
            print(f"      (ID: {style_id}, Situation: {situation})")

    # 4. Exercise the prediction path
    print("\n🔮 测试预测功能:")
    if not all_styles:
        print("  ⚠️ 无法测试,模型没有训练数据")
    else:
        test_situations = [
            "表示惊讶",
            "讨论游戏",
            "表达赞同"
        ]

        for test_sit in test_situations:
            print(f"\n  测试输入: '{test_sit}'")
            best_style, scores = learner.predict_style(test_sit, top_k=3)

            if best_style:
                print(f"  ✓ 最佳匹配: {best_style}")
                print("  Top 3:")
                for style, score in list(scores.items())[:3]:
                    print(f"    - {style}: {score:.4f}")
            else:
                print("  ✗ 预测失败")

    print("\n" + "=" * 60)
    print("诊断完成")
    print("=" * 60)


if __name__ == "__main__":
    # chat_ids taken from the diagnostic report above
    test_chat_ids = [
        "52fb94af9f500a01e023ea780e43606e",  # has 78 expressions
        "46c8714c8a9b7ee169941fe99fcde07d",  # has 22 expressions
    ]

    for chat_id in test_chat_ids:
        check_style_learner_status(chat_id)
        print("\n")
@@ -1,403 +0,0 @@
"""
Memory deduplication tool

Features:
1. Scans all memory edges marked with a "similar" relationship
2. Deduplicates similar memories (keeps the more important one, deletes the other)
3. Supports a dry-run mode (preview without executing)
4. Produces a detailed deduplication report

Usage:
    # Preview mode (no actual deletion)
    python scripts/deduplicate_memories.py --dry-run

    # Run the deduplication
    python scripts/deduplicate_memories.py

    # Set the similarity threshold
    python scripts/deduplicate_memories.py --threshold 0.9

    # Set the data directory
    python scripts/deduplicate_memories.py --data-dir data/memory_graph
"""
import argparse
import asyncio
import sys
from datetime import datetime
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).parent.parent))

from src.common.logger import get_logger
from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager

logger = get_logger(__name__)


class MemoryDeduplicator:
    """Memory deduplicator"""

    def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
        self.data_dir = data_dir
        self.dry_run = dry_run
        self.threshold = threshold
        self.manager = None

        # Statistics
        self.stats = {
            "total_memories": 0,
            "similar_pairs": 0,
            "duplicates_found": 0,
            "duplicates_removed": 0,
            "errors": 0,
        }

    async def initialize(self):
        """Initialize the memory manager"""
        logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...")
        self.manager = await initialize_memory_manager(data_dir=self.data_dir)
        if not self.manager:
            raise RuntimeError("记忆管理器初始化失败")

        self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
        logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")

    async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
        """
        Find all similar memory pairs (computed via vector similarity)

        Returns:
            [(memory_id_1, memory_id_2, similarity), ...]
        """
        logger.info("正在扫描相似记忆对...")
        similar_pairs = []
        seen_pairs = set()  # avoid duplicate pairs

        # Fetch all memories
        all_memories = self.manager.graph_store.get_all_memories()
        total_memories = len(all_memories)

        logger.info(f"开始计算 {total_memories} 条记忆的相似度...")

        # Compare every pair of memories
        for i, memory_i in enumerate(all_memories):
            # Yield control (and log progress) every 10 memories
            if i % 10 == 0:
                await asyncio.sleep(0)
                if i > 0:
                    logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)")

            # Get the vector for memory i (from the first node that has one)
            vector_i = None
            for node in memory_i.nodes:
                if node.embedding is not None:
                    vector_i = node.embedding
                    break

            if vector_i is None:
                continue

            # Compare against all later memories
            for j in range(i + 1, total_memories):
                memory_j = all_memories[j]

                # Get the vector for memory j
                vector_j = None
                for node in memory_j.nodes:
                    if node.embedding is not None:
                        vector_j = node.embedding
                        break

                if vector_j is None:
                    continue

                # Compute the cosine similarity
                similarity = self._cosine_similarity(vector_i, vector_j)

                # Only keep pairs that meet the threshold
                if similarity >= self.threshold:
                    pair_key = tuple(sorted([memory_i.id, memory_j.id]))
                    if pair_key not in seen_pairs:
                        seen_pairs.add(pair_key)
                        similar_pairs.append((memory_i.id, memory_j.id, similarity))

        self.stats["similar_pairs"] = len(similar_pairs)
        logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold})")

        return similar_pairs

    def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
        """Compute the cosine similarity of two vectors"""
        try:
            vec1_norm = np.linalg.norm(vec1)
            vec2_norm = np.linalg.norm(vec2)

            if vec1_norm == 0 or vec2_norm == 0:
                return 0.0

            similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
            return float(similarity)
        except Exception as e:
            logger.error(f"计算余弦相似度失败: {e}")
            return 0.0

    def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
        """
        Decide which memory to keep and which to delete

        Priority:
        1. Higher importance
        2. Higher activation
        3. Earlier creation time

        Returns:
            (keep_id, remove_id)
        """
        mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
        mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)

        if not mem1 or not mem2:
            logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}")
            return None, None

        # Compare importance
        if mem1.importance > mem2.importance:
            return mem_id_1, mem_id_2
        elif mem1.importance < mem2.importance:
            return mem_id_2, mem_id_1

        # Same importance -- compare activation
        if mem1.activation > mem2.activation:
            return mem_id_1, mem_id_2
        elif mem1.activation < mem2.activation:
            return mem_id_2, mem_id_1

        # Same activation as well -- keep the one created earlier
        if mem1.created_at < mem2.created_at:
            return mem_id_1, mem_id_2
        else:
            return mem_id_2, mem_id_1

    async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
        """
        Deduplicate one pair of similar memories

        Returns:
            whether the pair was deduplicated successfully
        """
        keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)

        if not keep_id or not remove_id:
            self.stats["errors"] += 1
            return False

        keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
        remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)

        logger.info("")
        logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
        logger.info(f"  保留: {keep_id}")
        logger.info(f"    - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
        logger.info(f"    - 重要性: {keep_mem.importance:.2f}")
        logger.info(f"    - 激活度: {keep_mem.activation:.2f}")
        logger.info(f"    - 创建时间: {keep_mem.created_at}")
        logger.info(f"  删除: {remove_id}")
        logger.info(f"    - 主题: {remove_mem.metadata.get('topic', 'N/A')}")
        logger.info(f"    - 重要性: {remove_mem.importance:.2f}")
        logger.info(f"    - 激活度: {remove_mem.activation:.2f}")
        logger.info(f"    - 创建时间: {remove_mem.created_at}")

        if self.dry_run:
            logger.info("  [预览模式] 不执行实际删除")
            self.stats["duplicates_found"] += 1
            return True

        try:
            # Reinforce the kept memory
            keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
            keep_mem.activation = min(1.0, keep_mem.activation + 0.05)

            # Accumulate access counts
            if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
                keep_mem.access_count += remove_mem.access_count

            # Delete the duplicate memory
            await self.manager.delete_memory(remove_id)

            self.stats["duplicates_removed"] += 1
            logger.info("  ✅ 删除成功")

            # Yield control
            await asyncio.sleep(0)

            return True

        except Exception as e:
            logger.error(f"  ❌ 删除失败: {e}")
            self.stats["errors"] += 1
            return False

    async def run(self):
        """Run the deduplication"""
        start_time = datetime.now()

        print("=" * 70)
        print("记忆去重工具")
        print("=" * 70)
        print(f"数据目录: {self.data_dir}")
        print(f"相似度阈值: {self.threshold}")
        print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}")
        print("=" * 70)
        print()

        # Initialize
        await self.initialize()

        # Find similar pairs
        similar_pairs = await self.find_similar_pairs()

        if not similar_pairs:
            logger.info("未找到需要去重的相似记忆对")
            print()
            print("=" * 70)
            print("未找到需要去重的记忆")
            print("=" * 70)
            return

        # Deduplicate
        logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...")
        print()

        processed_pairs = set()  # avoid handling the same pair twice

        for mem_id_1, mem_id_2, similarity in similar_pairs:
            # Skip pairs that were already handled (one side may have been deleted)
            pair_key = tuple(sorted([mem_id_1, mem_id_2]))
            if pair_key in processed_pairs:
                continue

            # Confirm that both memories still exist
            if not self.manager.graph_store.get_memory_by_id(mem_id_1):
                logger.debug(f"记忆 {mem_id_1} 已不存在,跳过")
                continue
            if not self.manager.graph_store.get_memory_by_id(mem_id_2):
                logger.debug(f"记忆 {mem_id_2} 已不存在,跳过")
                continue

            # Deduplicate this pair
            success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)

            if success:
                processed_pairs.add(pair_key)

        # Persist the data (unless this is a dry run)
        if not self.dry_run:
            logger.info("正在保存数据...")
            await self.manager.persistence.save_graph_store(self.manager.graph_store)
            logger.info("✅ 数据已保存")

        # Report
        elapsed = (datetime.now() - start_time).total_seconds()

        print()
        print("=" * 70)
        print("去重报告")
        print("=" * 70)
        print(f"总记忆数: {self.stats['total_memories']}")
        print(f"相似记忆对: {self.stats['similar_pairs']}")
        print(f"{'发现重复' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
        print(f"错误数: {self.stats['errors']}")
        print(f"耗时: {elapsed:.2f}秒")

        if self.dry_run:
            print()
            print("⚠️ 这是预览模式,未实际删除任何记忆")
            print("💡 要执行实际删除,请运行: python scripts/deduplicate_memories.py")
        else:
            print()
            print("✅ 去重完成!")
            final_count = len(self.manager.graph_store.get_all_memories())
            print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)")

        print("=" * 70)

    async def cleanup(self):
        """Release resources"""
        if self.manager:
            await shutdown_memory_manager()


async def main():
    """Entry point"""
    parser = argparse.ArgumentParser(
        description="记忆去重工具 - 对标记为相似的记忆进行一键去重",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  # 预览模式(推荐先运行)
  python scripts/deduplicate_memories.py --dry-run

  # 执行去重
  python scripts/deduplicate_memories.py

  # 指定相似度阈值(只处理相似度>=0.9的记忆对)
  python scripts/deduplicate_memories.py --threshold 0.9

  # 指定数据目录
  python scripts/deduplicate_memories.py --data-dir data/memory_graph

  # 组合使用
  python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
"""
    )

    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="预览模式,不实际删除记忆(推荐先运行此模式)"
    )

    parser.add_argument(
        "--threshold",
        type=float,
        default=0.85,
        help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85)"
    )

    parser.add_argument(
        "--data-dir",
        type=str,
        default="data/memory_graph",
        help="记忆数据目录(默认: data/memory_graph)"
    )

    args = parser.parse_args()

    # Build the deduplicator
    deduplicator = MemoryDeduplicator(
        data_dir=args.data_dir,
        dry_run=args.dry_run,
        threshold=args.threshold
    )

    try:
        # Run the deduplication
        await deduplicator.run()
    except KeyboardInterrupt:
        print("\n\n⚠️ 用户中断操作")
    except Exception as e:
        logger.error(f"执行失败: {e}")
        print(f"\n❌ 执行失败: {e}")
        return 1
    finally:
        # Release resources
        await deduplicator.cleanup()

    return 0


if __name__ == "__main__":
    sys.exit(asyncio.run(main()))
@@ -1,195 +0,0 @@
import os
import sys
import time

# Add the project root to the Python path before importing project modules
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)

from src.common.database.database_model import ChatStreams, Expression


def get_chat_name(chat_id: str) -> str:
    """Get chat name from chat_id by querying the ChatStreams table directly"""
    try:
        # Query the ChatStreams table directly
        chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
        if chat_stream is None:
            return f"未知聊天 ({chat_id})"

        # If group information exists, show the group name
        if chat_stream.group_name:
            return f"{chat_stream.group_name} ({chat_id})"
        # For private chats, show the user's nickname
        elif chat_stream.user_nickname:
            return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
        else:
            return f"未知聊天 ({chat_id})"
    except Exception:
        return f"查询失败 ({chat_id})"


def calculate_time_distribution(expressions) -> dict[str, int]:
    """Calculate the distribution of last-active times in days"""
    now = time.time()
    distribution = {
        "0-1天": 0,
        "1-3天": 0,
        "3-7天": 0,
        "7-14天": 0,
        "14-30天": 0,
        "30-60天": 0,
        "60-90天": 0,
        "90+天": 0,
    }
    for expr in expressions:
        diff_days = (now - expr.last_active_time) / (24 * 3600)
        if diff_days < 1:
            distribution["0-1天"] += 1
        elif diff_days < 3:
            distribution["1-3天"] += 1
        elif diff_days < 7:
            distribution["3-7天"] += 1
        elif diff_days < 14:
            distribution["7-14天"] += 1
        elif diff_days < 30:
            distribution["14-30天"] += 1
        elif diff_days < 60:
            distribution["30-60天"] += 1
        elif diff_days < 90:
            distribution["60-90天"] += 1
        else:
            distribution["90+天"] += 1
    return distribution


def calculate_count_distribution(expressions) -> dict[str, int]:
    """Calculate the distribution of count values"""
    distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
    for expr in expressions:
        cnt = expr.count
        if cnt < 1:
            distribution["0-1"] += 1
        elif cnt < 2:
            distribution["1-2"] += 1
        elif cnt < 3:
            distribution["2-3"] += 1
        elif cnt < 4:
            distribution["3-4"] += 1
        elif cnt < 5:
            distribution["4-5"] += 1
        elif cnt < 10:
            distribution["5-10"] += 1
        else:
            distribution["10+"] += 1
    return distribution


def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
    """Get the top N most used expressions for a specific chat_id"""
    return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)


def show_overall_statistics(expressions, total: int) -> None:
    """Show overall statistics"""
    time_dist = calculate_time_distribution(expressions)
    count_dist = calculate_count_distribution(expressions)

    print("\n=== 总体统计 ===")
    print(f"总表达式数量: {total}")

    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        print(f"{period}: {count} ({count / total * 100:.2f}%)")

    print("\ncount分布:")
    for range_, count in count_dist.items():
        print(f"{range_}: {count} ({count / total * 100:.2f}%)")


def show_chat_statistics(chat_id: str, chat_name: str) -> None:
    """Show statistics for a specific chat"""
    chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
    chat_total = len(chat_exprs)

    print(f"\n=== {chat_name} ===")
    print(f"表达式数量: {chat_total}")

    if chat_total == 0:
        print("该聊天没有表达式数据")
        return

    # Time distribution for this chat
    time_dist = calculate_time_distribution(chat_exprs)
    print("\n上次激活时间分布:")
    for period, count in time_dist.items():
        if count > 0:
            print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")

    # Count distribution for this chat
    count_dist = calculate_count_distribution(chat_exprs)
    print("\ncount分布:")
    for range_, count in count_dist.items():
        if count > 0:
            print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")

    # Top expressions
    print("\nTop 10使用最多的表达式:")
    top_exprs = get_top_expressions_by_chat(chat_id, 10)
    for i, expr in enumerate(top_exprs, 1):
        print(f"{i}. [{expr.type}] Count: {expr.count}")
        print(f"   Situation: {expr.situation}")
        print(f"   Style: {expr.style}")
        print()


def interactive_menu() -> None:
    """Interactive menu for expression statistics"""
    # Get all expressions
    expressions = list(Expression.select())
    if not expressions:
        print("数据库中没有找到表达式")
        return

    total = len(expressions)

    # Get unique chat_ids and their names
    chat_ids = list({expr.chat_id for expr in expressions})
    chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
    chat_info.sort(key=lambda x: x[1])  # Sort by chat name

    while True:
        print("\n" + "=" * 50)
        print("表达式统计分析")
        print("=" * 50)
        print("0. 显示总体统计")

        for i, (chat_id, chat_name) in enumerate(chat_info, 1):
            chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
            print(f"{i}. {chat_name} ({chat_count}个表达式)")

        print("q. 退出")

        choice = input("\n请选择要查看的统计 (输入序号): ").strip()

        if choice.lower() == "q":
            print("再见!")
            break

        try:
            choice_num = int(choice)
            if choice_num == 0:
                show_overall_statistics(expressions, total)
            elif 1 <= choice_num <= len(chat_info):
                chat_id, chat_name = chat_info[choice_num - 1]
                show_chat_statistics(chat_id, chat_name)
            else:
                print("无效的选择,请重新输入")
        except ValueError:
            print("请输入有效的数字")

        input("\n按回车键继续...")


if __name__ == "__main__":
    interactive_menu()
@@ -1,66 +0,0 @@
#!/usr/bin/env python3
"""Extract the model definitions from models.py"""

import re

# Read the original file
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
    content = f.read()

# Locate the start and end of the get_string_field helper
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end]

# Find where the first class definition starts
first_class_pos = content.find("class ChatStreams(Base):")

# Collect all class definitions up to the first non-class def.
# Simple strategy: match every class that starts with "class " and inherits Base.
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))

if matches:
    # Use the end position of the last match
    models_content = content[first_class_pos:first_class_pos + matches[-1].end()]
else:
    # Fallback: take everything from the first class to 85% of the file
    models_end = int(len(content) * 0.85)
    models_content = content[first_class_pos:models_end]

# Build the new file content
header = '''"""SQLAlchemy数据库模型定义

本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。
引擎和会话管理已移至core/engine.py和core/session.py。

所有模型使用统一的类型注解风格:
    field_name: Mapped[PyType] = mapped_column(Type, ...)

这样IDE/Pylance能正确推断实例属性类型。
"""

import datetime
import time

from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column

# 创建基类
Base = declarative_base()


'''

new_content = header + get_string_field + "\n\n" + models_content

# Write the new file
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
    f.write(new_content)

print("✅ Models file rewritten successfully")
print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f"Number of model classes: {model_count}")
@@ -1,267 +0,0 @@
"""
Generate embedding vectors for existing nodes

Batch-generates and indexes embeddings for nodes in the graph store
that are missing them

Use cases:
1. Historical memory nodes have no embedding vectors
2. The embedding generator was previously unconfigured and vectors must be backfilled
3. The vector index is corrupted and needs rebuilding

Usage:
    python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]

Arguments:
    --node-types: node types to generate embeddings for, default TOPIC,OBJECT
    --batch-size: batch size, default 50
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the path
sys.path.insert(0, str(Path(__file__).parent.parent))


async def generate_missing_embeddings(
    target_node_types: list[str] | None = None,
    batch_size: int = 50,
):
    """
    Generate embeddings for nodes that are missing them

    Args:
        target_node_types: node types to process (e.g. ["主题", "客体"])
        batch_size: batch size
    """
    from src.common.logger import get_logger
    from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
    from src.memory_graph.models import NodeType

    logger = get_logger("generate_missing_embeddings")

    if target_node_types is None:
        target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]

    print(f"\n{'='*80}")
    print("🔧 为节点生成嵌入向量")
    print(f"{'='*80}\n")
    print(f"目标节点类型: {', '.join(target_node_types)}")
    print(f"批处理大小: {batch_size}\n")

    # 1. Initialize the memory manager
    print("🔧 正在初始化记忆管理器...")
    await initialize_memory_manager()
    manager = get_memory_manager()

    if manager is None:
        print("❌ 记忆管理器初始化失败")
        return

    print("✅ 记忆管理器已初始化\n")

    # 2. Collect the IDs of nodes that are already indexed
    print("🔍 检查现有向量索引...")
    existing_node_ids = set()
    try:
        vector_count = manager.vector_store.collection.count()
        if vector_count > 0:
            # Page through all indexed IDs
            batch_size_check = 1000
            for offset in range(0, vector_count, batch_size_check):
                limit = min(batch_size_check, vector_count - offset)
                result = manager.vector_store.collection.get(
                    limit=limit,
                    offset=offset,
                )
                if result and "ids" in result:
                    existing_node_ids.update(result["ids"])

        print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
    except Exception as e:
        logger.warning(f"获取已索引节点ID失败: {e}")
        print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")

    # 3. Collect the nodes that need embeddings
    print("🔍 扫描需要生成嵌入的节点...")
    all_memories = manager.graph_store.get_all_memories()

    nodes_to_process = []
    total_target_nodes = 0
    type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}

    for memory in all_memories:
        for node in memory.nodes:
            if node.node_type.value in target_node_types:
                total_target_nodes += 1
                type_stats[node.node_type.value]["total"] += 1

                # Skip nodes that are already in the vector index
                if node.id in existing_node_ids:
                    type_stats[node.node_type.value]["already_indexed"] += 1
                    continue

                if not node.has_embedding():
                    nodes_to_process.append({
                        "node": node,
                        "memory_id": memory.id,
                    })
                    type_stats[node.node_type.value]["need_emb"] += 1

    print("\n📊 扫描结果:")
    for node_type in target_node_types:
        stats = type_stats[node_type]
        already_ok = stats["already_indexed"]
        coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
        print(f"  - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, "
              f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)")

    print(f"\n  总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")

    if len(nodes_to_process) == 0:
        print("✅ 所有节点已有嵌入向量,无需生成")
        return

    # 4. Generate embeddings in batches
    print("🚀 开始生成嵌入向量...\n")

    total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
    success_count = 0
    failed_count = 0
    indexed_count = 0

    for i in range(0, len(nodes_to_process), batch_size):
        batch = nodes_to_process[i : i + batch_size]
        batch_num = i // batch_size + 1

        print(f"📦 批次 {batch_num}/{total_batches} ({len(batch)} 个节点)...")

        try:
            # Extract the text content
            texts = [item["node"].content for item in batch]

            # Generate embeddings for the whole batch
            embeddings = await manager.embedding_generator.generate_batch(texts)

            # Attach the embeddings to the nodes and collect them for indexing
            batch_nodes_for_index = []

            for item, embedding in zip(batch, embeddings):
                node = item["node"]

                if embedding is not None:
                    # Store the embedding vector on the node
                    node.embedding = embedding
                    batch_nodes_for_index.append(node)
                    success_count += 1
                else:
                    failed_count += 1
                    logger.warning(f"  ⚠️ 节点 {node.id[:8]}... '{node.content[:30]}' 嵌入生成失败")

            # Index the batch into the vector database
            if batch_nodes_for_index:
                try:
                    await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
                    indexed_count += len(batch_nodes_for_index)
                    print(f"  ✅ 成功: {len(batch_nodes_for_index)}/{len(batch)} 个节点已生成并索引")
                except Exception as e:
                    # If the batch fails, fall back to adding nodes one by one (skipping duplicates)
                    logger.warning(f"  批量索引失败,尝试逐个添加: {e}")
                    individual_success = 0
                    for node in batch_nodes_for_index:
                        try:
                            await manager.vector_store.add_node(node)
                            individual_success += 1
                            indexed_count += 1
                        except Exception as e2:
                            if "Expected IDs to be unique" in str(e2):
                                logger.debug(f"  跳过已存在节点: {node.id}")
                            else:
                                logger.error(f"  节点 {node.id} 索引失败: {e2}")
                    print(f"  ⚠️ 逐个索引: {individual_success}/{len(batch_nodes_for_index)} 个成功")

        except Exception as e:
            failed_count += len(batch)
            logger.error(f"批次 {batch_num} 处理失败: {e}")
            print(f"  ❌ 批次处理失败: {e}")

        # Show progress
        total_processed = min(i + batch_size, len(nodes_to_process))
        progress = total_processed / len(nodes_to_process) * 100
        print(f"  📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")

    # 5. Persist the graph data (updating the nodes' embedding fields)
    print("💾 保存图数据...")
    try:
        await manager.persistence.save_graph_store(manager.graph_store)
        print("✅ 图数据已保存\n")
    except Exception as e:
        logger.error(f"保存图数据失败: {e}")
        print(f"❌ 保存失败: {e}\n")

    # 6. Verify the result
    print("🔍 验证向量索引...")
    final_vector_count = manager.vector_store.collection.count()
    stats = manager.graph_store.get_statistics()
    total_nodes = stats["total_nodes"]

    print(f"\n{'='*80}")
    print("📊 生成完成")
    print(f"{'='*80}")
    print(f"处理节点数: {len(nodes_to_process)}")
    print(f"成功生成: {success_count}")
    print(f"失败数量: {failed_count}")
    print(f"成功索引: {indexed_count}")
    print(f"向量索引节点数: {final_vector_count}")
    print(f"图存储节点数: {total_nodes}")
    print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")

    # 7. Smoke-test the search path
    print("🧪 测试搜索功能...")
    test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]

    for query in test_queries:
        results = await manager.search_memories(query=query, top_k=3)
        if results:
            print(f"\n✅ 查询 '{query}' 找到 {len(results)} 条记忆:")
            for i, memory in enumerate(results[:2], 1):
                subject_node = memory.get_subject_node()
                # Find the topic node by scanning all nodes for the TOPIC type
                topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
                subject = subject_node.content if subject_node else "?"
                topic = topic_nodes[0].content if topic_nodes else "?"
                print(f"  {i}. {subject} - {topic} (重要性: {memory.importance:.2f})")
        else:
            print(f"\n⚠️ 查询 '{query}' 返回 0 条结果")


async def main():
    import argparse

    parser = argparse.ArgumentParser(description="为节点生成嵌入向量")
    parser.add_argument(
        "--node-types",
        type=str,
        default="主题,客体",
        help="需要生成嵌入的节点类型,逗号分隔(默认:主题,客体)",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=50,
        help="批处理大小(默认:50)",
    )

    args = parser.parse_args()

    target_types = [t.strip() for t in args.node_types.split(",")]
    await generate_missing_embeddings(
        target_node_types=target_types,
        batch_size=args.batch_size,
    )


if __name__ == "__main__":
    asyncio.run(main())
scripts/migrate_database.py — new file, 1051 lines (file diff suppressed because it is too large)
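The migration script itself is suppressed above. Purely as a hedged illustration of the technique the commit message describes — copying rows table by table between SQLAlchemy-backed databases — a minimal sketch might look like the following; `base`, the URLs, and the batch size are assumptions, not the contents of the actual script:

# Hypothetical sketch of a table-by-table migration over SQLAlchemy models.
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session


def migrate(base, source_url: str, target_url: str, batch_size: int = 500):
    """Copy every mapped table from source_url to target_url (illustrative only)."""
    source = create_engine(source_url)
    target = create_engine(target_url)
    base.metadata.create_all(target)  # create the schema on the target database

    with Session(source) as src, Session(target) as dst:
        for mapper in base.registry.mappers:
            model = mapper.class_
            batch = []
            for row in src.execute(select(model)).scalars():
                # Copy column values only; relationships are carried by the foreign keys.
                data = {c.key: getattr(row, c.key) for c in mapper.columns}
                batch.append(model(**data))
                if len(batch) >= batch_size:
                    dst.add_all(batch)
                    dst.commit()
                    batch = []
            if batch:
                dst.add_all(batch)
                dst.commit()

# Usage: migrate(Base, "sqlite:///data/bot.db", "postgresql://user:pass@localhost/botdb")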
@@ -1,140 +0,0 @@
#!/usr/bin/env python
"""
Rebuild the JSON metadata index from existing ChromaDB data
"""

import asyncio
import os
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import orjson

from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.chat.memory_system.memory_system import MemorySystem
from src.common.logger import get_logger

logger = get_logger(__name__)


async def rebuild_metadata_index():
    """Rebuild the metadata index from ChromaDB"""
    print("=" * 80)
    print("重建JSON元数据索引")
    print("=" * 80)

    # Initialize the memory system
    print("\n🔧 初始化记忆系统...")
    ms = MemorySystem()
    await ms.initialize()
    print("✅ 记忆系统已初始化")

    if not hasattr(ms.unified_storage, "metadata_index"):
        print("❌ 元数据索引管理器未初始化")
        return

    # Fetch all memories
    print("\n📥 从ChromaDB获取所有记忆...")
    from src.common.vector_db import vector_db_service

    try:
        # Fetch every memory in the collection
        collection_name = ms.unified_storage.config.memory_collection
        result = vector_db_service.get(
            collection_name=collection_name, include=["documents", "metadatas", "embeddings"]
        )

        if not result or not result.get("ids"):
            print("❌ ChromaDB中没有找到记忆数据")
            return

        ids = result["ids"]
        metadatas = result.get("metadatas", [])

        print(f"✅ 找到 {len(ids)} 条记忆")

        # Rebuild the metadata index
        print("\n🔨 开始重建元数据索引...")
        entries = []
        success_count = 0

        for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1):
            try:
                # Rebuild an index entry from the ChromaDB metadata
                entry = MemoryMetadataIndexEntry(
                    memory_id=memory_id,
                    user_id=metadata.get("user_id", "unknown"),
                    memory_type=metadata.get("memory_type", "general"),
                    subjects=orjson.loads(metadata.get("subjects", "[]")),
                    objects=[metadata.get("object")] if metadata.get("object") else [],
                    keywords=orjson.loads(metadata.get("keywords", "[]")),
                    tags=orjson.loads(metadata.get("tags", "[]")),
                    importance=2,  # default: NORMAL
                    confidence=2,  # default: MEDIUM
                    created_at=metadata.get("created_at", 0.0),
                    access_count=metadata.get("access_count", 0),
                    chat_id=metadata.get("chat_id"),
                    content_preview=None,
                )

                # Map the importance and confidence enum names back to numbers
                if "importance" in metadata:
                    imp_str = metadata["importance"]
                    if imp_str == "LOW":
                        entry.importance = 1
                    elif imp_str == "NORMAL":
                        entry.importance = 2
                    elif imp_str == "HIGH":
                        entry.importance = 3
                    elif imp_str == "CRITICAL":
                        entry.importance = 4

                if "confidence" in metadata:
                    conf_str = metadata["confidence"]
                    if conf_str == "LOW":
                        entry.confidence = 1
                    elif conf_str == "MEDIUM":
                        entry.confidence = 2
                    elif conf_str == "HIGH":
                        entry.confidence = 3
                    elif conf_str == "VERIFIED":
                        entry.confidence = 4

                entries.append(entry)
                success_count += 1

                if i % 100 == 0:
                    print(f"  处理进度: {i}/{len(ids)} ({success_count} 成功)")

            except Exception as e:
                logger.warning(f"处理记忆 {memory_id} 失败: {e}")
                continue

        print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据")

        # Batch-update and persist the index
        print("\n💾 保存元数据索引...")
        ms.unified_storage.metadata_index.batch_add_or_update(entries)
        ms.unified_storage.metadata_index.save()

        # Show statistics
        stats = ms.unified_storage.metadata_index.get_stats()
        print("\n📊 重建后的索引统计:")
        print(f"  - 总记忆数: {stats['total_memories']}")
        print(f"  - 主语数量: {stats['subjects_count']}")
        print(f"  - 关键词数量: {stats['keywords_count']}")
        print(f"  - 标签数量: {stats['tags_count']}")
        print("  - 类型分布:")
        for mtype, count in stats["types"].items():
            print(f"    - {mtype}: {count}")

        print("\n✅ 元数据索引重建完成!")

    except Exception as e:
        logger.error(f"重建索引失败: {e}")
        print(f"❌ 重建索引失败: {e}")


if __name__ == "__main__":
    asyncio.run(rebuild_metadata_index())
@@ -1,190 +0,0 @@
"""
MCP integration test script

Tests MCP client connection, tool-list retrieval, and tool invocation
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.common.logger import get_logger
from src.plugin_system.core.component_registry import ComponentRegistry
from src.plugin_system.core.mcp_client_manager import MCPClientManager

logger = get_logger("test_mcp_integration")


async def test_mcp_client_manager():
    """Test the basic MCPClientManager functionality"""
    print("\n" + "=" * 60)
    print("测试 1: MCPClientManager 连接和工具列表")
    print("=" * 60)

    try:
        # Initialize the MCP client manager
        manager = MCPClientManager()
        await manager.initialize()

        print("\n✓ MCP 客户端管理器初始化成功")
        print(f"已连接服务器数量: {len(manager.clients)}")

        # Fetch all tools
        tools = await manager.get_all_tools()
        print(f"\n获取到 {len(tools)} 个 MCP 工具:")

        for tool in tools:
            print(f"\n  工具: {tool}")
            # Note: here tool is the tool name as a string;
            # tool details would have to be fetched elsewhere

        return manager, tools

    except Exception as e:
        print(f"\n✗ 测试失败: {e}")
        logger.exception("MCPClientManager 测试失败")
        return None, []


async def test_tool_call(manager: MCPClientManager, tools):
    """Test tool invocation"""
    print("\n" + "=" * 60)
    print("测试 2: MCP 工具调用")
    print("=" * 60)

    if not tools:
        print("\n⚠ 没有可用的工具进行测试")
        return

    try:
        # The tool listing itself was already exercised in the first test
        print("\n✓ 工具列表获取成功")
        print(f"可用工具数量: {len(tools)}")

    except Exception as e:
        print(f"\n✗ 工具调用测试失败: {e}")
        logger.exception("工具调用测试失败")


async def test_component_registry_integration():
    """Test the ComponentRegistry integration"""
    print("\n" + "=" * 60)
    print("测试 3: ComponentRegistry MCP 工具集成")
    print("=" * 60)

    try:
        registry = ComponentRegistry()

        # Load the MCP tools
        await registry.load_mcp_tools()

        # Fetch the MCP tools
        mcp_tools = registry.get_mcp_tools()
        print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具")

        for tool in mcp_tools:
            print(f"\n  工具: {tool.name}")
            print(f"  描述: {tool.description}")
            print(f"  参数数量: {len(tool.parameters)}")

            # Exercise the is_mcp_tool method
            is_mcp = registry.is_mcp_tool(tool.name)
            print(f"  is_mcp_tool 检测: {'✓' if is_mcp else '✗'}")

        return mcp_tools

    except Exception as e:
        print(f"\n✗ ComponentRegistry 集成测试失败: {e}")
        logger.exception("ComponentRegistry 集成测试失败")
        return []


async def test_tool_execution(mcp_tools):
    """Test executing a tool through the adapter"""
    print("\n" + "=" * 60)
    print("测试 4: MCPToolAdapter 工具执行")
    print("=" * 60)

    if not mcp_tools:
        print("\n⚠ 没有可用的 MCP 工具进行测试")
        return

    try:
        # Use the first tool for the test
        test_tool = mcp_tools[0]
        print(f"\n测试工具: {test_tool.name}")

        # Build test arguments
        test_args = {}
        for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters:
            if is_required:
                # Provide a default value based on the parameter type
                from src.llm_models.payload_content.tool_option import ToolParamType

                if param_type == ToolParamType.STRING:
                    test_args[param_name] = "test_value"
                elif param_type == ToolParamType.INTEGER:
                    test_args[param_name] = 1
                elif param_type == ToolParamType.FLOAT:
                    test_args[param_name] = 1.0
                elif param_type == ToolParamType.BOOLEAN:
                    test_args[param_name] = True

        print(f"测试参数: {test_args}")

        # Execute the tool
        result = await test_tool.execute(test_args)

        if result:
            print("\n✓ 工具执行成功")
            print(f"结果类型: {result.get('type')}")
            print(f"结果内容: {result.get('content', '')[:200]}...")  # show only the first 200 characters
        else:
            print("\n✗ 工具执行失败,返回 None")

    except Exception as e:
        print(f"\n✗ 工具执行测试失败: {e}")
        logger.exception("工具执行测试失败")


async def main():
    """Main test flow"""
    print("\n" + "=" * 60)
    print("MCP 集成测试")
    print("=" * 60)

    try:
        # Test 1: basic MCPClientManager functionality
        manager, tools = await test_mcp_client_manager()

        if manager:
            # Test 2: tool invocation
            await test_tool_call(manager, tools)

            # Test 3: ComponentRegistry integration
            mcp_tools = await test_component_registry_integration()

            # Test 4: tool execution
            await test_tool_execution(mcp_tools)

            # Close the connection
            await manager.close()
            print("\n✓ MCP 客户端连接已关闭")

        print("\n" + "=" * 60)
        print("测试完成")
        print("=" * 60 + "\n")

    except KeyboardInterrupt:
        print("\n\n测试被用户中断")
    except Exception as e:
        print(f"\n测试过程中发生错误: {e}")
        logger.exception("测试失败")


if __name__ == "__main__":
    asyncio.run(main())
@@ -1,292 +0,0 @@
"""
Three-tier memory system test script
Verifies that each component of the system works correctly
"""

import asyncio
import sys
from pathlib import Path

# Add the project root to the path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))


async def test_perceptual_memory():
    """Test the perceptual memory tier"""
    print("\n" + "=" * 60)
    print("测试1: 感知记忆层")
    print("=" * 60)

    from src.memory_graph.three_tier.perceptual_manager import get_perceptual_manager

    manager = get_perceptual_manager()
    await manager.initialize()

    # Add test messages
    test_messages = [
        ("user1", "今天天气真好", 1700000000.0),
        ("user2", "是啊,适合出去玩", 1700000001.0),
        ("user1", "我们去公园吧", 1700000002.0),
        ("user2", "好主意!", 1700000003.0),
        ("user1", "带上野餐垫", 1700000004.0),
    ]

    for sender, content, timestamp in test_messages:
        message = {
            "message_id": f"msg_{timestamp}",
            "sender": sender,
            "content": content,
            "timestamp": timestamp,
            "platform": "test",
            "stream_id": "test_stream",
        }
        await manager.add_message(message)

    print(f"✅ 成功添加 {len(test_messages)} 条消息")

    # Test top-k recall
    results = await manager.recall_blocks("公园野餐", top_k=2)
    print(f"✅ TopK召回返回 {len(results)} 个块")

    if results:
        print(f"  第一个块包含 {len(results[0].messages)} 条消息")

    # Fetch statistics
    stats = manager.get_statistics()  # not an async method
    print(f"✅ 统计信息: {stats}")

    return True


async def test_short_term_memory():
    """Test the short-term memory tier"""
    print("\n" + "=" * 60)
    print("测试2: 短期记忆层")
    print("=" * 60)

    from src.memory_graph.three_tier.models import MemoryBlock
    from src.memory_graph.three_tier.short_term_manager import get_short_term_manager

    manager = get_short_term_manager()
    await manager.initialize()

    # Build a test block
    test_block = MemoryBlock(
        id="test_block_1",
        messages=[
            {
                "message_id": "msg1",
                "sender": "user1",
                "content": "我明天要参加一个重要的面试",
                "timestamp": 1700000000.0,
                "platform": "test",
            }
        ],
        combined_text="我明天要参加一个重要的面试",
        recall_count=3,
    )

    # Promote the perceptual block into short-term memory
    try:
        await manager.add_from_block(test_block)
        print("✅ 成功将感知块转换为短期记忆")
    except Exception as e:
        print(f"⚠️ 转换失败(可能需要LLM): {e}")
        return False

    # Test search
    results = await manager.search_memories("面试", top_k=3)
    print(f"✅ 搜索返回 {len(results)} 条记忆")

    # Fetch statistics
    stats = manager.get_statistics()
    print(f"✅ 统计信息: {stats}")

    return True


async def test_long_term_memory():
    """Test the long-term memory tier"""
    print("\n" + "=" * 60)
    print("测试3: 长期记忆层")
    print("=" * 60)

    from src.memory_graph.three_tier.long_term_manager import get_long_term_manager

    manager = get_long_term_manager()
    await manager.initialize()

    print("✅ 长期记忆管理器初始化成功")
    print("  (需要现有记忆图系统支持)")

    # Fetch statistics
    stats = manager.get_statistics()
    print(f"✅ 统计信息: {stats}")

    return True


async def test_unified_manager():
    """Test the unified manager"""
    print("\n" + "=" * 60)
    print("测试4: 统一管理器")
    print("=" * 60)

    from src.memory_graph.three_tier.unified_manager import UnifiedMemoryManager

    manager = UnifiedMemoryManager()
    await manager.initialize()

    # Add a test message
    message = {
        "message_id": "unified_test_1",
        "sender": "user1",
        "content": "这是一条测试消息",
        "timestamp": 1700000000.0,
        "platform": "test",
        "stream_id": "test_stream",
    }
    await manager.add_message(message)

    print("✅ 通过统一接口添加消息成功")

    # Test search
    results = await manager.search_memories("测试")
    print("✅ 统一搜索返回结果:")
    print(f"  感知块: {len(results.get('perceptual_blocks', []))}")
    print(f"  短期记忆: {len(results.get('short_term_memories', []))}")
    print(f"  长期记忆: {len(results.get('long_term_memories', []))}")

    # Fetch statistics
    stats = manager.get_statistics()  # not an async method
    print("✅ 综合统计:")
    print(f"  感知层: {stats.get('perceptual', {})}")
    print(f"  短期层: {stats.get('short_term', {})}")
    print(f"  长期层: {stats.get('long_term', {})}")

    return True


async def test_configuration():
    """Test configuration loading"""
    print("\n" + "=" * 60)
    print("测试5: 配置系统")
    print("=" * 60)

    from src.config.config import global_config

    if not hasattr(global_config, "three_tier_memory"):
        print("❌ 配置类中未找到 three_tier_memory 字段")
        return False

    config = global_config.three_tier_memory

    if config is None:
        print("⚠️ 三层记忆配置为 None(可能未在 bot_config.toml 中配置)")
        print("  请在 bot_config.toml 中添加 [three_tier_memory] 配置")
        return False

    print("✅ 配置加载成功")
    print(f"  启用状态: {config.enable}")
    print(f"  数据目录: {config.data_dir}")
    print(f"  感知层最大块数: {config.perceptual_max_blocks}")
    print(f"  短期层最大记忆数: {config.short_term_max_memories}")
    print(f"  激活阈值: {config.activation_threshold}")

    return True


async def test_integration():
    """Test system integration"""
    print("\n" + "=" * 60)
    print("测试6: 系统集成")
    print("=" * 60)

    # The configuration has to be enabled first
    from src.config.config import global_config

    if not global_config.three_tier_memory or not global_config.three_tier_memory.enable:
        print("⚠️ 配置未启用,跳过集成测试")
        return False

    # Test the singleton pattern
    from src.memory_graph.three_tier.manager_singleton import (
        get_unified_memory_manager,
        initialize_unified_memory_manager,
    )

    # Initialize
    await initialize_unified_memory_manager()
    manager = get_unified_memory_manager()

    if manager is None:
        print("❌ 统一管理器初始化失败")
        return False

    print("✅ 单例模式正常工作")

    # Fetch the manager a second time
    manager2 = get_unified_memory_manager()
    if manager is not manager2:
        print("❌ 单例模式失败(返回不同实例)")
        return False

    print("✅ 单例一致性验证通过")

    return True


async def run_all_tests():
    """Run all tests"""
    print("\n" + "🔬" * 30)
    print("三层记忆系统集成测试")
    print("🔬" * 30)

    tests = [
        ("配置系统", test_configuration),
        ("感知记忆层", test_perceptual_memory),
        ("短期记忆层", test_short_term_memory),
        ("长期记忆层", test_long_term_memory),
        ("统一管理器", test_unified_manager),
        ("系统集成", test_integration),
    ]

    results = []

    for name, test_func in tests:
        try:
            result = await test_func()
            results.append((name, result))
        except Exception as e:
            print(f"\n❌ 测试 {name} 失败: {e}")
            import traceback

            traceback.print_exc()
            results.append((name, False))

    # Print the summary
    print("\n" + "=" * 60)
    print("测试总结")
    print("=" * 60)

    passed = sum(1 for _, result in results if result)
    total = len(results)

    for name, result in results:
        status = "✅ 通过" if result else "❌ 失败"
        print(f"{status} - {name}")

    print(f"\n总计: {passed}/{total} 测试通过")

    if passed == total:
        print("\n🎉 所有测试通过!三层记忆系统工作正常。")
    else:
        print("\n⚠️ 部分测试失败,请查看上方详细信息。")

    return passed == total


if __name__ == "__main__":
    success = asyncio.run(run_all_tests())
    sys.exit(0 if success else 1)
@@ -1,185 +0,0 @@
"""Batch-update database import statements

Rewrites the old database import paths to the refactored ones:
- sqlalchemy_models -> core, core.models
- sqlalchemy_database_api -> compatibility
- database.database -> core
"""

import re
from pathlib import Path

# Import mapping rules. The rules are applied in insertion order, so the
# specific patterns must come before the catch-all ones -- otherwise the
# catch-all rewrite fires first and the specific rules never match.
IMPORT_MAPPINGS = {
    # get_db_session imported from sqlalchemy_database_api
    r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
        r"from src.common.database.core import get_db_session",

    # get_db_session imported from sqlalchemy_models
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
        if "get_db_session" in m.group(0) else m.group(0),

    # get_engine imports
    r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
        lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",

    # Base imports
    r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
        lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",

    # initialize_database imports
    r"from src\.common\.database\.sqlalchemy_models import initialize_database":
        r"from src.common.database.core import check_and_migrate_database as initialize_database",

    # Catch-all model imports
    r"from src\.common\.database\.sqlalchemy_models import (.+)":
        r"from src.common.database.core.models import \1",

    # Catch-all API imports
    r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
        r"from src.common.database.compatibility import \1",

    # database.py imports
    r"from src\.common\.database\.database import stop_database":
        r"from src.common.database.core import close_engine as stop_database",

    r"from src\.common\.database\.database import initialize_sql_database":
        r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
}

# Files to exclude
EXCLUDE_PATTERNS = [
    "**/database_refactoring_plan.md",  # documentation
    "**/old/**",                        # legacy directories
    "**/sqlalchemy_*.py",               # the old database files themselves
    "**/database.py",                   # the old database module
    "**/db_*.py",                       # old db helper files
]


def should_exclude(file_path: Path) -> bool:
    """Check whether the file should be skipped"""
    for pattern in EXCLUDE_PATTERNS:
        if file_path.match(pattern):
            return True
    return False


def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
    """Update the import statements in a single file

    Args:
        file_path: path of the file to update
        dry_run: preview only, without writing changes

    Returns:
        (number of changes, list of change descriptions)
    """
    try:
        content = file_path.read_text(encoding="utf-8")
        original_content = content
        changes = []

        # Apply each mapping rule
        for pattern, replacement in IMPORT_MAPPINGS.items():
            matches = list(re.finditer(pattern, content))
            for match in matches:
                old_line = match.group(0)

                # Handle callable replacements
                if callable(replacement):
                    new_line_result = replacement(match)
                    new_line = new_line_result if isinstance(new_line_result, str) else old_line
                else:
                    new_line = re.sub(pattern, replacement, old_line)

                if old_line != new_line and isinstance(new_line, str):
                    content = content.replace(old_line, new_line, 1)
                    changes.append(f"  - {old_line}")
                    changes.append(f"  + {new_line}")

        # If anything changed and this is not a dry run, write the file back
        if content != original_content:
            if not dry_run:
                file_path.write_text(content, encoding="utf-8")
            return len(changes) // 2, changes

        return 0, []

    except Exception as e:
        print(f"❌ 处理文件 {file_path} 时出错: {e}")
        return 0, []


def main():
    """Entry point"""
    print("🔍 搜索需要更新导入的文件...")

    # Project root directory
    root_dir = Path(__file__).parent.parent

    # Collect all Python files
    all_python_files = list(root_dir.rglob("*.py"))

    # Filter out excluded files
    target_files = [f for f in all_python_files if not should_exclude(f)]

    print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
    print("\n" + "=" * 80)

    # First pass: preview mode
    print("\n🔍 预览模式 - 检查需要更新的文件...\n")

    files_to_update = []
    for file_path in target_files:
        count, changes = update_imports_in_file(file_path, dry_run=True)
        if count > 0:
            files_to_update.append((file_path, count, changes))

    if not files_to_update:
        print("✅ 没有文件需要更新!")
        return

    print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")

    total_changes = 0
    for file_path, count, changes in files_to_update:
        rel_path = file_path.relative_to(root_dir)
        print(f"\n📄 {rel_path} ({count} 处修改)")
        for change in changes[:10]:  # show at most the first 5 change pairs
            print(change)
        if len(changes) > 10:
            print(f"  ... 还有 {len(changes) - 10} 行")
        total_changes += count

    print("\n" + "=" * 80)
    print("\n📊 统计:")
    print(f"  - 需要更新的文件: {len(files_to_update)}")
    print(f"  - 总修改次数: {total_changes}")

    # Ask for confirmation
    print("\n" + "=" * 80)
    response = input("\n是否执行更新?(yes/no): ").strip().lower()

    if response != "yes":
        print("❌ 已取消更新")
        return

    # Second pass: apply the updates
    print("\n✨ 开始更新文件...\n")

    success_count = 0
    for file_path, _, _ in files_to_update:
        count, _ = update_imports_in_file(file_path, dry_run=False)
        if count > 0:
            rel_path = file_path.relative_to(root_dir)
            print(f"✅ {rel_path} ({count} 处修改)")
            success_count += 1

    print("\n" + "=" * 80)
    print(f"\n🎉 完成!成功更新 {success_count} 个文件")


if __name__ == "__main__":
    main()