Files
Mofox-Core/src/migrate_helper/migrate.py
2025-08-13 23:17:28 +08:00

312 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import asyncio
from src.common.database.database_model import GraphNodes
from src.common.logger import get_logger
logger = get_logger("migrate")
# Maximum number of characters kept in a migrated memory_items string.
_MEMORY_ITEMS_MAX_LENGTH = 100


def _truncate_memory_text(content, concept, stats, label=""):
    """Clip *content* to the length limit, recording any truncation.

    Args:
        content: Text about to be stored in ``memory_items``.
        concept: Node concept, used only in the debug message.
        stats: Migration counters; ``truncated_nodes`` is incremented when
            the text is clipped.
        label: Short description of the source format, inserted into the
            debug message right before "内容过长".

    Returns:
        A ``(text, original_length, truncated)`` tuple.
    """
    original_length = len(content)
    if original_length <= _MEMORY_ITEMS_MAX_LENGTH:
        return content, original_length, False
    stats["truncated_nodes"] += 1
    logger.debug(
        f"节点 '{concept}' {label}内容过长,从 {original_length} 字符截断到 {_MEMORY_ITEMS_MAX_LENGTH} 字符"
    )
    return content[:_MEMORY_ITEMS_MAX_LENGTH], original_length, True


def _migrate_json_value(node, concept, parsed_data, original_weight, stats):
    """Store a successfully JSON-decoded ``memory_items`` value as a plain string."""
    if isinstance(parsed_data, list):
        if not parsed_data:
            # An empty list becomes an empty string with the default weight.
            node.memory_items = ""
            node.weight = 1.0
            node.save()
            stats["converted_nodes"] += 1
            logger.debug(f"转换空list节点: {concept}")
            return

        joined = " | ".join(str(item) for item in parsed_data)
        new_memory_items, original_length, truncated = _truncate_memory_text(joined, concept, stats)
        # The weight is the number of items the list held before joining.
        new_weight = float(len(parsed_data))
        node.memory_items = new_memory_items
        node.weight = new_weight
        node.save()
        stats["converted_nodes"] += 1
        stats["weight_updated_nodes"] += 1
        length_info = f" (截断: {original_length}→{_MEMORY_ITEMS_MAX_LENGTH})" if truncated else ""
        logger.info(
            f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}"
        )
        return

    if isinstance(parsed_data, str):
        current_content, _, _ = _truncate_memory_text(parsed_data, concept, stats, "字符串")
        # The stored value was a quoted JSON string literal, so the decoded
        # text always differs from it. Write the decoded text back in every
        # case; otherwise the quotes and escapes would stay in the database
        # whenever no truncation happened.
        node.memory_items = current_content

        weight_updated = False
        if original_weight == 1.0:
            # Weight is still at its default: estimate it from the number of
            # " | "-separated parts in the content.
            estimated_weight = max(1.0, float(len(current_content.split(" | "))))
            if estimated_weight != original_weight:
                node.weight = estimated_weight
                weight_updated = True
                logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")

        node.save()
        if weight_updated:
            stats["weight_updated_nodes"] += 1
        stats["converted_nodes"] += 1
        return

    # Any other JSON type (number, bool, null, object). Falsy values are
    # stored as an empty string, everything else as its str() form.
    text = str(parsed_data) if parsed_data else ""
    new_memory_items, original_length, truncated = _truncate_memory_text(text, concept, stats, "其他类型")
    node.memory_items = new_memory_items
    node.weight = 1.0
    node.save()
    stats["converted_nodes"] += 1
    length_info = f" (截断: {original_length}→{_MEMORY_ITEMS_MAX_LENGTH})" if truncated else ""
    logger.debug(f"转换其他类型节点: {concept}{length_info}")


def _migrate_plain_text(node, concept, memory_items_raw, stats):
    """Handle a ``memory_items`` value that is not valid JSON."""
    # Require at least two characters so a lone '"' is not mistaken for a
    # quoted string and wiped to "".
    is_quoted = (
        len(memory_items_raw) >= 2
        and memory_items_raw.startswith('"')
        and memory_items_raw.endswith('"')
    )
    if is_quoted:
        clean_content, original_length, truncated = _truncate_memory_text(
            memory_items_raw[1:-1], concept, stats, "去引号"
        )
        node.memory_items = clean_content
        node.save()
        stats["converted_nodes"] += 1
        length_info = f" (截断: {original_length}→{_MEMORY_ITEMS_MAX_LENGTH})" if truncated else ""
        logger.debug(f"去除引号节点: {concept}{length_info}")
        return

    current_content, _, truncated = _truncate_memory_text(memory_items_raw, concept, stats, "纯字符串")
    if truncated:
        node.memory_items = current_content
        node.save()
        stats["converted_nodes"] += 1  # a truncation counts as a conversion
    else:
        stats["already_string_nodes"] += 1
        logger.debug(f"已是字符串格式节点: {concept}")


def _migrate_memory_node(node, concept, stats):
    """Migrate a single node, dispatching on the shape of its ``memory_items``."""
    memory_items_raw = node.memory_items.strip() if node.memory_items else ""
    original_weight = node.weight if hasattr(node, "weight") and node.weight is not None else 1.0

    if not memory_items_raw:
        stats["empty_nodes"] += 1
        logger.debug(f"跳过空节点: {concept}")
        return

    # Only the decode itself is guarded, so a failure while saving is never
    # mistaken for "not JSON".
    try:
        parsed_data = json.loads(memory_items_raw)
    except json.JSONDecodeError:
        _migrate_plain_text(node, concept, memory_items_raw, stats)
        return

    _migrate_json_value(node, concept, parsed_data, original_weight, stats)


async def migrate_memory_items_to_string():
    """Migrate every node's ``memory_items`` from list format to string format.

    For each row returned by ``GraphNodes.select()``:

    * a JSON list is joined with ``" | "`` and the node's ``weight`` is set to
      the number of list items;
    * a JSON string is stored as its decoded text, and a weight still equal to
      1.0 is re-estimated from the number of ``" | "``-separated parts;
    * any other JSON value is stored as ``str(value)`` (empty string when the
      value is falsy) with weight 1.0;
    * text that is not valid JSON has one pair of surrounding double quotes
      removed if present, and is otherwise left alone unless it is too long.

    Stored text is clipped to ``_MEMORY_ITEMS_MAX_LENGTH`` characters. A
    failure on one node is logged and counted in ``error_nodes`` without
    stopping the loop. The function is declared ``async`` but contains no
    ``await``; all model calls are made synchronously.

    Returns:
        dict: counters keyed by ``total_nodes``, ``converted_nodes``,
        ``already_string_nodes``, ``empty_nodes``, ``error_nodes``,
        ``weight_updated_nodes`` and ``truncated_nodes``.

    Raises:
        Exception: anything raised outside the per-node handling (for example
        by the initial select/count) is logged and re-raised.
    """
    logger.info("开始迁移记忆节点格式...")

    migration_stats = {
        "total_nodes": 0,
        "converted_nodes": 0,
        "already_string_nodes": 0,
        "empty_nodes": 0,
        "error_nodes": 0,
        "weight_updated_nodes": 0,
        "truncated_nodes": 0,
    }

    try:
        # Fetch all graph nodes.
        all_nodes = GraphNodes.select()
        migration_stats["total_nodes"] = all_nodes.count()
        logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")

        for node in all_nodes:
            # Bound before the try so the error handler below never reports
            # an unbound name or the previous iteration's concept.
            concept = None
            try:
                concept = node.concept
                _migrate_memory_node(node, concept, migration_stats)
            except Exception as e:
                migration_stats["error_nodes"] += 1
                logger.error(f"处理节点 {concept} 时发生错误: {e}")
                continue
    except Exception as e:
        logger.error(f"迁移过程中发生严重错误: {e}")
        raise

    # Report the migration statistics.
    logger.info("=== 记忆节点迁移完成 ===")
    logger.info(f"总节点数: {migration_stats['total_nodes']}")
    logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
    logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
    logger.info(f"空节点: {migration_stats['empty_nodes']}")
    logger.info(f"错误节点: {migration_stats['error_nodes']}")
    logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
    logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")

    # Converted and already-string nodes count as successes; empty and
    # errored nodes do not.
    total_nodes = migration_stats["total_nodes"]
    if total_nodes > 0:
        succeeded = migration_stats["converted_nodes"] + migration_stats["already_string_nodes"]
        success_rate = succeeded / total_nodes * 100
    else:
        success_rate = 0
    logger.info(f"迁移成功率: {success_rate:.1f}%")

    return migration_stats
async def set_all_person_known():
    """Mark every ``PersonInfo`` row as known after purging invalid rows.

    Rows whose ``user_id`` or ``platform`` is NULL or an empty string are
    logged and deleted first; ``is_known`` is then set to True on all rows
    that remain.

    Returns:
        dict: ``total`` (row count before cleanup), ``deleted`` (rows
        removed), ``updated`` (value returned by the bulk update) and
        ``known_count`` (rows with ``is_known`` set afterwards).

    Raises:
        Exception: any error is logged and re-raised.
    """
    logger.info("开始设置所有person_info记录为已认识...")

    try:
        from src.common.database.database_model import PersonInfo

        # Count every PersonInfo row before touching anything.
        total_count = PersonInfo.select().count()
        logger.info(f"找到 {total_count} 个人员记录")
        if total_count == 0:
            logger.info("没有找到任何人员记录")
            return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}

        # A row is invalid when user_id or platform is NULL or empty.
        invalid_filter = (
            (PersonInfo.user_id.is_null())
            | (PersonInfo.user_id == '')
            | (PersonInfo.platform.is_null())
            | (PersonInfo.platform == '')
        )

        # Log each invalid row before it is removed.
        for row in PersonInfo.select().where(invalid_filter):
            user_id_info = f"'{row.user_id}'" if row.user_id else "NULL"
            platform_info = f"'{row.platform}'" if row.platform else "NULL"
            person_name_info = f"'{row.person_name}'" if row.person_name else "无名称"
            logger.debug(f"删除无效记录: person_id={row.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}")

        # Remove the invalid rows in one statement.
        deleted_count = PersonInfo.delete().where(invalid_filter).execute()
        if deleted_count > 0:
            logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
        else:
            logger.info("没有发现user_id或platform为空的记录")

        # Re-count what is left after the cleanup.
        remaining_count = PersonInfo.select().count()
        logger.info(f"清理后剩余 {remaining_count} 个有效记录")
        if remaining_count == 0:
            logger.info("清理后没有剩余记录")
            return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}

        # Bulk-set is_known on all remaining rows.
        updated_count = PersonInfo.update(is_known=True).execute()
        logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")

        # Verify the outcome by counting the rows now flagged as known.
        known_count = PersonInfo.select().where(PersonInfo.is_known).count()

        logger.info("=== person_info更新完成 ===")
        logger.info(f"原始记录数: {total_count}")
        logger.info(f"删除记录数: {deleted_count}")
        logger.info(f"更新记录数: {updated_count}")
        logger.info(f"已认识记录数: {known_count}")

        return {
            "total": total_count,
            "deleted": deleted_count,
            "updated": updated_count,
            "known_count": known_count,
        }
    except Exception as e:
        logger.error(f"更新person_info过程中发生错误: {e}")
        raise
async def check_and_run_migrations():
    """Run the one-off migrations unless the ``done.mem`` marker already exists.

    The marker lives at ``<project root>/data/temp/done.mem``, where the
    project root is two directories above this file's directory. When the
    marker is missing, the temp directory is created if needed, the function
    waits three seconds, awaits both migrations in order, and then writes
    the marker.
    """
    # Resolve the project root relative to this file.
    here = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(here, "..", ".."))
    temp_dir = os.path.join(project_root, "data", "temp")
    done_file = os.path.join(temp_dir, "done.mem")

    # Marker present: nothing to do.
    if os.path.exists(done_file):
        return

    # Make sure the temp directory exists before the marker is written.
    if not os.path.exists(temp_dir):
        os.makedirs(temp_dir, exist_ok=True)

    # Run the two migration coroutines one after the other.
    await asyncio.sleep(3)
    await migrate_memory_items_to_string()
    await set_all_person_known()

    # Record completion; this line is only reached if neither migration raised.
    with open(done_file, "w", encoding="utf-8") as marker:
        marker.write("done")