feat:更新记忆系统

This commit is contained in:
SengokuCola
2025-08-13 23:17:28 +08:00
parent 3962fc601f
commit fed0c0fd04
10 changed files with 732 additions and 406 deletions

View File

View File

@@ -0,0 +1,312 @@
import json
import os
import asyncio
from src.common.database.database_model import GraphNodes
from src.common.logger import get_logger
logger = get_logger("migrate")
async def migrate_memory_items_to_string():
"""
将数据库中记忆节点的memory_items从list格式迁移到string格式
并根据原始list的项目数量设置weight值
"""
logger.info("开始迁移记忆节点格式...")
migration_stats = {
"total_nodes": 0,
"converted_nodes": 0,
"already_string_nodes": 0,
"empty_nodes": 0,
"error_nodes": 0,
"weight_updated_nodes": 0,
"truncated_nodes": 0
}
try:
# 获取所有图节点
all_nodes = GraphNodes.select()
migration_stats["total_nodes"] = all_nodes.count()
logger.info(f"找到 {migration_stats['total_nodes']} 个记忆节点")
for node in all_nodes:
try:
concept = node.concept
memory_items_raw = node.memory_items.strip() if node.memory_items else ""
original_weight = node.weight if hasattr(node, 'weight') and node.weight is not None else 1.0
# 如果为空,跳过
if not memory_items_raw:
migration_stats["empty_nodes"] += 1
logger.debug(f"跳过空节点: {concept}")
continue
try:
# 尝试解析JSON
parsed_data = json.loads(memory_items_raw)
if isinstance(parsed_data, list):
# 如果是list格式需要转换
if parsed_data:
# 转换为字符串格式
new_memory_items = " | ".join(str(item) for item in parsed_data)
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 内容过长,从 {original_length} 字符截断到 100 字符")
new_weight = float(len(parsed_data)) # weight = list项目数量
# 更新数据库
node.memory_items = new_memory_items
node.weight = new_weight
node.save()
migration_stats["converted_nodes"] += 1
migration_stats["weight_updated_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.info(f"转换节点 '{concept}': {len(parsed_data)} 项 -> 字符串{length_info}, weight: {original_weight} -> {new_weight}")
else:
# 空list设置为空字符串
node.memory_items = ""
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
logger.debug(f"转换空list节点: {concept}")
elif isinstance(parsed_data, str):
# 已经是字符串格式检查长度和weight
current_content = parsed_data
original_length = len(current_content)
content_truncated = False
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
content_truncated = True
migration_stats["truncated_nodes"] += 1
node.memory_items = current_content
logger.debug(f"节点 '{concept}' 字符串内容过长,从 {original_length} 字符截断到 100 字符")
# 检查weight是否需要更新
update_needed = False
if original_weight == 1.0:
# 如果weight还是默认值可以根据内容复杂度估算
content_parts = current_content.split(" | ") if " | " in current_content else [current_content]
estimated_weight = max(1.0, float(len(content_parts)))
if estimated_weight != original_weight:
node.weight = estimated_weight
update_needed = True
logger.debug(f"更新字符串节点权重 '{concept}': {original_weight} -> {estimated_weight}")
# 如果内容被截断或权重需要更新,保存到数据库
if content_truncated or update_needed:
node.save()
if update_needed:
migration_stats["weight_updated_nodes"] += 1
if content_truncated:
migration_stats["converted_nodes"] += 1 # 算作转换节点
else:
migration_stats["already_string_nodes"] += 1
else:
migration_stats["already_string_nodes"] += 1
else:
# 其他JSON类型转换为字符串
new_memory_items = str(parsed_data) if parsed_data else ""
original_length = len(new_memory_items)
# 检查长度并截断
if len(new_memory_items) > 100:
new_memory_items = new_memory_items[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 其他类型内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = new_memory_items
node.weight = 1.0
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"转换其他类型节点: {concept}{length_info}")
except json.JSONDecodeError:
# 不是JSON格式假设已经是纯字符串
# 检查是否是带引号的字符串
if memory_items_raw.startswith('"') and memory_items_raw.endswith('"'):
# 去掉引号
clean_content = memory_items_raw[1:-1]
original_length = len(clean_content)
# 检查长度并截断
if len(clean_content) > 100:
clean_content = clean_content[:100]
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 去引号内容过长,从 {original_length} 字符截断到 100 字符")
node.memory_items = clean_content
node.save()
migration_stats["converted_nodes"] += 1
length_info = f" (截断: {original_length}→100)" if original_length > 100 else ""
logger.debug(f"去除引号节点: {concept}{length_info}")
else:
# 已经是纯字符串格式,检查长度
current_content = memory_items_raw
original_length = len(current_content)
# 检查长度并截断
if len(current_content) > 100:
current_content = current_content[:100]
node.memory_items = current_content
node.save()
migration_stats["converted_nodes"] += 1 # 算作转换节点
migration_stats["truncated_nodes"] += 1
logger.debug(f"节点 '{concept}' 纯字符串内容过长,从 {original_length} 字符截断到 100 字符")
else:
migration_stats["already_string_nodes"] += 1
logger.debug(f"已是字符串格式节点: {concept}")
except Exception as e:
migration_stats["error_nodes"] += 1
logger.error(f"处理节点 {concept} 时发生错误: {e}")
continue
except Exception as e:
logger.error(f"迁移过程中发生严重错误: {e}")
raise
# 输出迁移统计
logger.info("=== 记忆节点迁移完成 ===")
logger.info(f"总节点数: {migration_stats['total_nodes']}")
logger.info(f"已转换节点: {migration_stats['converted_nodes']}")
logger.info(f"已是字符串格式: {migration_stats['already_string_nodes']}")
logger.info(f"空节点: {migration_stats['empty_nodes']}")
logger.info(f"错误节点: {migration_stats['error_nodes']}")
logger.info(f"权重更新节点: {migration_stats['weight_updated_nodes']}")
logger.info(f"内容截断节点: {migration_stats['truncated_nodes']}")
success_rate = (migration_stats['converted_nodes'] + migration_stats['already_string_nodes']) / migration_stats['total_nodes'] * 100 if migration_stats['total_nodes'] > 0 else 0
logger.info(f"迁移成功率: {success_rate:.1f}%")
return migration_stats
async def set_all_person_known():
"""
将person_info库中所有记录的is_known字段设置为True
在设置之前先清理掉user_id或platform为空的记录
"""
logger.info("开始设置所有person_info记录为已认识...")
try:
from src.common.database.database_model import PersonInfo
# 获取所有PersonInfo记录
all_persons = PersonInfo.select()
total_count = all_persons.count()
logger.info(f"找到 {total_count} 个人员记录")
if total_count == 0:
logger.info("没有找到任何人员记录")
return {"total": 0, "deleted": 0, "updated": 0, "known_count": 0}
# 删除user_id或platform为空的记录
deleted_count = 0
invalid_records = PersonInfo.select().where(
(PersonInfo.user_id.is_null()) |
(PersonInfo.user_id == '') |
(PersonInfo.platform.is_null()) |
(PersonInfo.platform == '')
)
# 记录要删除的记录信息
for record in invalid_records:
user_id_info = f"'{record.user_id}'" if record.user_id else "NULL"
platform_info = f"'{record.platform}'" if record.platform else "NULL"
person_name_info = f"'{record.person_name}'" if record.person_name else "无名称"
logger.debug(f"删除无效记录: person_id={record.person_id}, user_id={user_id_info}, platform={platform_info}, person_name={person_name_info}")
# 执行删除操作
deleted_count = PersonInfo.delete().where(
(PersonInfo.user_id.is_null()) |
(PersonInfo.user_id == '') |
(PersonInfo.platform.is_null()) |
(PersonInfo.platform == '')
).execute()
if deleted_count > 0:
logger.info(f"删除了 {deleted_count} 个user_id或platform为空的记录")
else:
logger.info("没有发现user_id或platform为空的记录")
# 重新获取剩余记录数量
remaining_count = PersonInfo.select().count()
logger.info(f"清理后剩余 {remaining_count} 个有效记录")
if remaining_count == 0:
logger.info("清理后没有剩余记录")
return {"total": total_count, "deleted": deleted_count, "updated": 0, "known_count": 0}
# 批量更新剩余记录的is_known字段为True
updated_count = PersonInfo.update(is_known=True).execute()
logger.info(f"成功更新 {updated_count} 个人员记录的is_known字段为True")
# 验证更新结果
known_count = PersonInfo.select().where(PersonInfo.is_known).count()
result = {
"total": total_count,
"deleted": deleted_count,
"updated": updated_count,
"known_count": known_count
}
logger.info("=== person_info更新完成 ===")
logger.info(f"原始记录数: {result['total']}")
logger.info(f"删除记录数: {result['deleted']}")
logger.info(f"更新记录数: {result['updated']}")
logger.info(f"已认识记录数: {result['known_count']}")
return result
except Exception as e:
logger.error(f"更新person_info过程中发生错误: {e}")
raise
async def check_and_run_migrations():
# 获取根目录
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
data_dir = os.path.join(project_root, "data")
temp_dir = os.path.join(data_dir, "temp")
done_file = os.path.join(temp_dir, "done.mem")
# 检查done.mem是否存在
if not os.path.exists(done_file):
# 如果temp目录不存在则创建
if not os.path.exists(temp_dir):
os.makedirs(temp_dir, exist_ok=True)
# 执行迁移函数
# 依次执行两个异步函数
await asyncio.sleep(3)
await migrate_memory_items_to_string()
await set_all_person_known()
# 创建done.mem文件
with open(done_file, "w", encoding="utf-8") as f:
f.write("done")