Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox-Core into dev
@@ -31,6 +31,7 @@ if str(PROJECT_ROOT) not in sys.path:

# Switch the working directory to the project root
import os

os.chdir(PROJECT_ROOT)

# Log directory

@@ -25,8 +25,6 @@ sys.path.insert(0, str(project_root))

from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.config.config import global_config


# ==================== Configuration ====================

@@ -82,7 +80,7 @@ EVALUATION_PROMPT = """You are a very strict memory-value evaluation expert. You

**Examples to keep**:
- "User Zhang San says he is a programmer and works in Hangzhou" ✅
- "Li Si says he likes playing basketball and goes every Wednesday" ✅
- "Xiao Ming says his girlfriend is called Xiao Hong and they have been together for 2 years" ✅
- "User A's birthday is March 15" ✅

@@ -111,7 +109,7 @@ EVALUATION_PROMPT = """You are a very strict memory-value evaluation expert. You
    }},
    {{
        "memory_id": "another ID",
        "action": "keep",
        "reason": "reason for keeping"
    }}
]
@@ -134,7 +132,7 @@ class MemoryCleaner:
    def __init__(self, dry_run: bool = True, batch_size: int = 10, concurrency: int = 5):
        """
        Initialize the cleaner

        Args:
            dry_run: whether to run in dry-run mode (no actual data changes)
            batch_size: number of memories to process per batch
@@ -146,10 +144,10 @@ class MemoryCleaner:
        self.data_dir = project_root / "data" / "memory_graph"
        self.memory_file = self.data_dir / "memory_graph.json"
        self.backup_dir = self.data_dir / "backups"

        # Concurrency control
        self.semaphore: asyncio.Semaphore | None = None
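        # (The semaphore itself is presumably created later, in initialize(),
        # to cap how many evaluation batches hit the LLM at once.)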

        # Statistics
        self.stats = {
            "total": 0,
@@ -160,7 +158,7 @@ class MemoryCleaner:
            "deleted_nodes": 0,
            "deleted_edges": 0,
        }

        # Log file
        self.log_file = self.data_dir / f"cleanup_log_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"
        self.cleanup_log = []
@@ -168,23 +166,23 @@ class MemoryCleaner:
    def load_memories(self) -> dict:
        """Load the memory data"""
        print(f"📂 Loading memory file: {self.memory_file}")

        if not self.memory_file.exists():
            raise FileNotFoundError(f"Memory file does not exist: {self.memory_file}")

-        with open(self.memory_file, "r", encoding="utf-8") as f:
+        with open(self.memory_file, encoding="utf-8") as f:
            data = json.load(f)

        return data

    def extract_memory_text(self, memory_dict: dict) -> str:
        """Extract readable text from a memory dict"""
        parts = []

        # Basic information
        memory_id = memory_dict.get("id", "unknown")
        parts.append(f"ID: {memory_id}")

        # Node contents
        nodes = memory_dict.get("nodes", [])
        for node in nodes:
@@ -192,14 +190,14 @@ class MemoryCleaner:
            content = node.get("content", "")
            if content:
                parts.append(f"[{node_type}] {content}")

        # Edge relations
        edges = memory_dict.get("edges", [])
        for edge in edges:
            relation = edge.get("relation", "")
            if relation:
                parts.append(f"Relation: {relation}")

        # Metadata
        metadata = memory_dict.get("metadata", {})
        if metadata:
@@ -207,24 +205,24 @@ class MemoryCleaner:
                parts.append(f"Context: {metadata['context']}")
            if "emotion" in metadata:
                parts.append(f"Emotion: {metadata['emotion']}")

        # Importance and status
        importance = memory_dict.get("importance", 0)
        status = memory_dict.get("status", "unknown")
        created_at = memory_dict.get("created_at", "unknown")

        parts.append(f"Importance: {importance}, Status: {status}, Created: {created_at}")

        return "\n".join(parts)

    async def evaluate_batch(self, memories: list[dict], batch_id: int = 0) -> tuple[int, list[dict]]:
        """
        Evaluate one batch of memories with the LLM (with concurrency control)

        Args:
            memories: list of memory dicts
            batch_id: batch number

        Returns:
            (batch ID, list of evaluation results)
        """
@@ -234,27 +232,27 @@ class MemoryCleaner:
        for i, mem in enumerate(memories):
            text = self.extract_memory_text(mem)
            memory_texts.append(f"=== Memory {i+1} ===\n{text}")

        combined_text = "\n\n".join(memory_texts)
        prompt = EVALUATION_PROMPT.format(memories=combined_text)

        try:
            # Call the model through LLMRequest
            if model_config is None:
                raise RuntimeError("model_config is not initialized; make sure the config has been loaded")
            task_config = model_config.model_task_config.utils
            llm = LLMRequest(task_config, request_type="memory_cleanup")
-            response_text, (reasoning, model_name, _) = await llm.generate_response_async(
+            response_text, (_reasoning, model_name, _) = await llm.generate_response_async(
                prompt=prompt,
                temperature=0.2,
                max_tokens=4000,
            )

            print(f" ✅ Batch {batch_id} done (model: {model_name})")

            # Parse the JSON response
            response_text = response_text.strip()

            # Try to extract the JSON
            if "```json" in response_text:
                json_start = response_text.find("```json") + 7
@@ -264,17 +262,17 @@ class MemoryCleaner:
                json_start = response_text.find("```") + 3
                json_end = response_text.find("```", json_start)
                response_text = response_text[json_start:json_end].strip()
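            # At this point response_text should hold the bare JSON payload,
            # whether or not the model wrapped it in a code fence.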

            result = json.loads(response_text)
            evaluations = result.get("evaluations", [])

            # Attach the real memory_id to each evaluation result
            for j, eval_result in enumerate(evaluations):
                if j < len(memories):
                    eval_result["memory_id"] = memories[j].get("id", f"unknown_{batch_id}_{j}")

            return (batch_id, evaluations)

        except json.JSONDecodeError as e:
            print(f" ❌ Batch {batch_id} JSON parsing failed: {e}")
            return (batch_id, [])
@@ -291,36 +289,36 @@ class MemoryCleaner:
        """Create a backup of the data"""
        self.backup_dir.mkdir(parents=True, exist_ok=True)
        backup_file = self.backup_dir / f"memory_graph_backup_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json"

        print(f"💾 Creating backup: {backup_file}")
        with open(backup_file, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)

        return backup_file

    def apply_changes(self, data: dict, evaluations: list[dict]) -> dict:
        """
        Apply the evaluation results to the data

        Args:
            data: the original data
            evaluations: list of evaluation results

        Returns:
            the modified data
        """
        # Build an index of the evaluation results
-        eval_map = {e["memory_id"]: e for e in evaluations if "memory_id" in e}
+        {e["memory_id"]: e for e in evaluations if "memory_id" in e}

        # Memory IDs to delete
        to_delete = set()
        # Memories to update
        to_update = {}

        for eval_result in evaluations:
            memory_id = eval_result.get("memory_id")
            action = eval_result.get("action")

            if action == "delete":
                to_delete.add(memory_id)
                self.stats["deleted"] += 1
@@ -342,18 +340,18 @@ class MemoryCleaner:
                })
            else:
                self.stats["kept"] += 1

        if self.dry_run:
            print("🔍 [DRY RUN] Not actually modifying data")
            return data

        # Actually modify the data
        # 1. Delete memories
        memories = data.get("memories", {})
        for mem_id in to_delete:
            if mem_id in memories:
                del memories[mem_id]

        # 2. Update memory contents
        for mem_id, new_content in to_update.items():
            if mem_id in memories:
@@ -363,42 +361,42 @@ class MemoryCleaner:
                    if node.get("node_type") in ["主题", "topic", "TOPIC"]:
                        node["content"] = new_content
                        break

        # 3. Clean up orphaned nodes and edges
        data = self.cleanup_orphaned_nodes_and_edges(data)

        return data

    def cleanup_orphaned_nodes_and_edges(self, data: dict) -> dict:
        """
        Clean up orphaned nodes and edges

        Orphaned node: every memory in its metadata.memory_ids has been deleted
        Orphaned edge: its source or target node has been deleted
        """
        print("\n🔗 Cleaning up orphaned nodes and edges...")

        # Collect all currently valid memory IDs
        valid_memory_ids = set(data.get("memories", {}).keys())
        print(f" Valid memories: {len(valid_memory_ids)}")

        # Clean up nodes
        nodes = data.get("nodes", [])
        original_node_count = len(nodes)

        valid_nodes = []
        valid_node_ids = set()

        for node in nodes:
            node_id = node.get("id")
            metadata = node.get("metadata", {})
            memory_ids = metadata.get("memory_ids", [])

            # Check whether the node's associated memories still exist
            if memory_ids:
                # Filter out deleted memory IDs
                remaining_memory_ids = [mid for mid in memory_ids if mid in valid_memory_ids]

                if remaining_memory_ids:
                    # Update memory_ids in the metadata
                    metadata["memory_ids"] = remaining_memory_ids
@@ -410,32 +408,32 @@ class MemoryCleaner:
                # Conservative handling: keep these nodes
                valid_nodes.append(node)
                valid_node_ids.add(node_id)

        deleted_nodes = original_node_count - len(valid_nodes)
        data["nodes"] = valid_nodes
        print(f" ✅ Nodes: {original_node_count} → {len(valid_nodes)} (deleted {deleted_nodes})")

        # Clean up edges
        edges = data.get("edges", [])
        original_edge_count = len(edges)

        valid_edges = []
        for edge in edges:
            source = edge.get("source")
            target = edge.get("target")

            # Keep only edges whose two endpoints both still exist
            if source in valid_node_ids and target in valid_node_ids:
                valid_edges.append(edge)

        deleted_edges = original_edge_count - len(valid_edges)
        data["edges"] = valid_edges
        print(f" ✅ Edges: {original_edge_count} → {len(valid_edges)} (deleted {deleted_edges})")

        # Update statistics
        self.stats["deleted_nodes"] = deleted_nodes
        self.stats["deleted_edges"] = deleted_edges

        return data

    def save_data(self, data: dict):
@@ -443,7 +441,7 @@ class MemoryCleaner:
        if self.dry_run:
            print("🔍 [DRY RUN] Skipping save")
            return

        print(f"💾 Saving data to: {self.memory_file}")
        with open(self.memory_file, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
@@ -468,88 +466,88 @@ class MemoryCleaner:
        print(f"Batch size: {self.batch_size}")
        print(f"Concurrency: {self.concurrency}")
        print("=" * 60)

        # Initialize
        await self.initialize()

        # Load the data
        data = self.load_memories()

        # Get all memories
        memories = data.get("memories", {})
        memory_list = list(memories.values())
        self.stats["total"] = len(memory_list)

        print(f"📊 Total memories: {self.stats['total']}")

        if not memory_list:
            print("⚠️ No memories to process")
            return

        # Create a backup
        if not self.dry_run:
            self.create_backup(data)

        # Split into batches
        batches = []
        for i in range(0, len(memory_list), self.batch_size):
            batch = memory_list[i:i + self.batch_size]
            batches.append(batch)

        total_batches = len(batches)
        print(f"📦 {total_batches} batches in total, starting concurrent processing...\n")

        # Process all batches concurrently
        start_time = datetime.now()
        tasks = [
            self.evaluate_batch(batch, batch_id=idx)
            for idx, batch in enumerate(batches)
        ]

        # Run them concurrently with asyncio.gather
        results = await asyncio.gather(*tasks, return_exceptions=True)
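        # return_exceptions=True makes a failed batch come back as an Exception
        # object in `results` instead of cancelling the remaining batches.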

        end_time = datetime.now()
        elapsed = (end_time - start_time).total_seconds()

        # Collect all evaluation results
        all_evaluations = []
        success_count = 0
        error_count = 0

        for result in results:
            if isinstance(result, Exception):
                print(f" ❌ Batch exception: {result}")
                error_count += 1
            elif isinstance(result, tuple):
-                batch_id, evaluations = result
+                _batch_id, evaluations = result
                if evaluations:
                    all_evaluations.extend(evaluations)
                    success_count += 1
                else:
                    error_count += 1

        print(f"\n⏱️ Concurrent processing finished in {elapsed:.1f} s")
        print(f" Successful batches: {success_count}/{total_batches}, failed: {error_count}")

        # Tally the evaluation results
        delete_count = sum(1 for e in all_evaluations if e.get("action") == "delete")
        keep_count = sum(1 for e in all_evaluations if e.get("action") == "keep")
        summarize_count = sum(1 for e in all_evaluations if e.get("action") == "summarize")

        print(f" 📊 Evaluation results: keep {keep_count}, delete {delete_count}, summarize {summarize_count}")

        # Apply the changes
        print("\n" + "=" * 60)
        print("📊 Applying changes...")
        data = self.apply_changes(data, all_evaluations)

        # Save the data
        self.save_data(data)

        # Save the log
        self.save_log()

        # Print statistics
        print("\n" + "=" * 60)
        print("📊 Cleanup statistics")
@@ -563,7 +561,7 @@ class MemoryCleaner:
        print(f"Errors: {self.stats['errors']}")
        print(f"Throughput: {self.stats['total'] / elapsed:.1f} memories/s")
        print("=" * 60)

        if self.dry_run:
            print("\n⚠️ This was a dry run; no data was actually modified")
            print("To really apply the changes, remove the --dry-run flag")
@@ -575,25 +573,25 @@ class MemoryCleaner:
        print("=" * 60)
        print(f"Mode: {'simulated run (DRY RUN)' if self.dry_run else 'live run'}")
        print("=" * 60)

        # Load the data
        data = self.load_memories()

        # Stats for the original data
        memories = data.get("memories", {})
        nodes = data.get("nodes", [])
        edges = data.get("edges", [])

        print(f"📊 Current state: {len(memories)} memories, {len(nodes)} nodes, {len(edges)} edges")

        if not self.dry_run:
            self.create_backup(data)

        # Clean up orphaned nodes and edges
        if self.dry_run:
            # Dry run: count, but do not modify
            valid_memory_ids = set(memories.keys())

            # Count the nodes that would be deleted
            nodes_to_keep = 0
            for node in nodes:
@@ -605,9 +603,9 @@ class MemoryCleaner:
                        nodes_to_keep += 1
                else:
                    nodes_to_keep += 1

            nodes_to_delete = len(nodes) - nodes_to_keep

            # Count the edges that would be deleted (first work out which nodes survive)
            valid_node_ids = set()
            for node in nodes:
@@ -619,11 +617,11 @@ class MemoryCleaner:
                        valid_node_ids.add(node.get("id"))
                else:
                    valid_node_ids.add(node.get("id"))

            edges_to_keep = sum(1 for e in edges if e.get("source") in valid_node_ids and e.get("target") in valid_node_ids)
            edges_to_delete = len(edges) - edges_to_keep

-            print(f"\n🔍 [DRY RUN] Planned cleanup:")
+            print("\n🔍 [DRY RUN] Planned cleanup:")
            print(f" Nodes: {len(nodes)} → {nodes_to_keep} (delete {nodes_to_delete})")
            print(f" Edges: {len(edges)} → {edges_to_keep} (delete {edges_to_delete})")
            print("\n⚠️ This was a dry run; no data was actually modified")
@@ -631,8 +629,8 @@ class MemoryCleaner:
        else:
            data = self.cleanup_orphaned_nodes_and_edges(data)
            self.save_data(data)

-            print(f"\n✅ Cleanup finished!")
+            print("\n✅ Cleanup finished!")
            print(f" Deleted nodes: {self.stats['deleted_nodes']}")
            print(f" Deleted edges: {self.stats['deleted_edges']}")

@@ -661,15 +659,15 @@ async def main():
        action="store_true",
        help="Only clean up orphaned nodes and edges; do not re-evaluate memories"
    )

    args = parser.parse_args()

    cleaner = MemoryCleaner(
        dry_run=args.dry_run,
        batch_size=args.batch_size,
        concurrency=args.concurrency,
    )

    if args.cleanup_only:
        await cleaner.run_cleanup_only()
    else:

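# A typical invocation of the cleaner, inferred from the argparse flags above
# (the script filename is assumed; it is not shown in this diff):
#   python scripts/cleanup_memories.py --dry-run --batch-size 10 --concurrency 5
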
@@ -8,7 +8,7 @@
    python scripts/migrate_database.py --help
    python scripts/migrate_database.py --source sqlite --target postgresql
    python scripts/migrate_database.py --source postgresql --target sqlite --batch-size 5000

    # Interactive wizard mode (recommended)
    python scripts/migrate_database.py

@@ -55,19 +55,21 @@ try:
except ImportError:
    tomllib = None

-from typing import Any, Iterable, Callable
+from collections.abc import Iterable
+from datetime import datetime as dt
+from typing import Any

from sqlalchemy import (
-    create_engine,
    MetaData,
    Table,
+    create_engine,
    inspect,
    text,
)
from sqlalchemy import (
    types as sqltypes,
)
-from sqlalchemy.engine import Engine, Connection
+from sqlalchemy.engine import Connection, Engine
from sqlalchemy.exc import SQLAlchemyError

# ====== Set up the environment early so Chinese output is friendlier on Windows ======
@@ -320,7 +322,7 @@ def convert_value_for_target(
    """
    # Get the class name of the target type
    target_type_name = target_col_type.__class__.__name__.upper()
-    source_type_name = source_col_type.__class__.__name__.upper()
+    source_col_type.__class__.__name__.upper()

    # Handle None values
    if val is None:
@@ -500,7 +502,7 @@ def migrate_table_data(
    target_cols_by_name = {c.key: c for c in target_table.columns}

    # Identify the primary-key columns (usually id); keep the original IDs during migration to avoid duplicate data
-    primary_key_cols = {c.key for c in source_table.primary_key.columns}
+    {c.key for c in source_table.primary_key.columns}

    # Stream the query to avoid loading too much data at once
    # Use a raw text() SQL query to avoid errors caused by SQLAlchemy's automatic type conversion (e.g. DateTime)
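    # (In SQLAlchemy, streaming like this is typically done with
    # conn.execution_options(stream_results=True) plus chunked fetchmany();
    # the exact call used here sits outside this hunk.)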
@@ -776,7 +778,7 @@ class DatabaseMigrator:
        for table_name in self.metadata.tables:
            dependencies[table_name] = set()

-        for table_name, table in self.metadata.tables.items():
+        for table_name in self.metadata.tables.keys():
            fks = inspector.get_foreign_keys(table_name)
            for fk in fks:
                # The table being referenced
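                # (These edges presumably drive a topological ordering later, so
                # referenced tables are migrated before the tables that point at them.)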
@@ -919,7 +921,7 @@ class DatabaseMigrator:
                self.stats["errors"].append(f"Table {source_table.name} migration failed: {e}")

        self.stats["end_time"] = time.time()

        # After the migration, automatically fix PostgreSQL-specific issues
        if self.target_type == "postgresql" and self.target_engine:
            fix_postgresql_boolean_columns(self.target_engine)
@@ -927,7 +929,6 @@ class DatabaseMigrator:

    def print_summary(self):
        """Print a migration summary"""
-        import time

        duration = None
        if self.stats["start_time"] is not None and self.stats["end_time"] is not None:
@@ -1262,104 +1263,104 @@ def interactive_setup() -> dict:

def fix_postgresql_sequences(engine: Engine):
    """Fix PostgreSQL sequence values

    After the data is migrated, PostgreSQL sequences (used for auto-increment
    primary keys) may not have been advanced to the correct values, which causes
    primary-key conflicts when new rows are inserted. This function automatically
    detects and resets all sequences.

    Args:
        engine: PostgreSQL database engine
    """
    if engine.dialect.name != "postgresql":
        logger.info("Not a PostgreSQL database; skipping sequence fix")
        return

    logger.info("Fixing PostgreSQL sequences...")

    with engine.connect() as conn:
        # Find all tables that have a sequence-backed column
-        result = conn.execute(text('''
+        result = conn.execute(text("""
            SELECT
                t.table_name,
                c.column_name,
                pg_get_serial_sequence(t.table_name, c.column_name) as sequence_name
            FROM information_schema.tables t
            JOIN information_schema.columns c
                ON t.table_name = c.table_name AND t.table_schema = c.table_schema
            WHERE t.table_schema = 'public'
                AND t.table_type = 'BASE TABLE'
                AND c.column_default LIKE 'nextval%'
            ORDER BY t.table_name
-        '''))
+        """))

        sequences = result.fetchall()
        logger.info("Found %d tables with sequences", len(sequences))

        fixed_count = 0
        for table_name, column_name, seq_name in sequences:
            if seq_name:
                try:
                    # Get the current maximum value of that column in the table
-                    max_result = conn.execute(text(f'SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}'))
+                    max_result = conn.execute(text(f"SELECT COALESCE(MAX({column_name}), 0) FROM {table_name}"))
                    max_val = max_result.scalar()

                    # Set the sequence's next value
                    next_val = max_val + 1
                    conn.execute(text(f"SELECT setval('{seq_name}', {next_val}, false)"))
                    conn.commit()
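                    # setval(..., false) leaves the sequence in the "not yet
                    # called" state, so the next nextval() returns next_val itself.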

                    logger.info(" ✅ %s.%s: max value=%d, sequence set to %d", table_name, column_name, max_val, next_val)
                    fixed_count += 1
                except Exception as e:
                    logger.warning(" ❌ %s.%s: fix failed - %s", table_name, column_name, e)

    logger.info("Sequence fix finished! Fixed %d sequences", fixed_count)


def fix_postgresql_boolean_columns(engine: Engine):
    """Fix PostgreSQL boolean column types

    After migrating from SQLite, boolean columns may end up as INTEGER. This
    function converts them to BOOLEAN.

    Args:
        engine: PostgreSQL database engine
    """
    if engine.dialect.name != "postgresql":
        logger.info("Not a PostgreSQL database; skipping boolean column fix")
        return

    # Columns known to need conversion to BOOLEAN
    BOOLEAN_COLUMNS = {
-        'messages': ['is_mentioned', 'is_emoji', 'is_picid', 'is_command',
-                     'is_notify', 'is_public_notice', 'should_reply', 'should_act'],
-        'action_records': ['action_done', 'action_build_into_prompt'],
+        "messages": ["is_mentioned", "is_emoji", "is_picid", "is_command",
+                     "is_notify", "is_public_notice", "should_reply", "should_act"],
+        "action_records": ["action_done", "action_build_into_prompt"],
    }

    logger.info("Checking and fixing PostgreSQL boolean columns...")

    with engine.connect() as conn:
        fixed_count = 0
        for table_name, columns in BOOLEAN_COLUMNS.items():
            for col_name in columns:
                try:
                    # Check the current type
-                    result = conn.execute(text(f'''
+                    result = conn.execute(text(f"""
                        SELECT data_type FROM information_schema.columns
                        WHERE table_name = '{table_name}' AND column_name = '{col_name}'
-                    '''))
+                    """))
                    row = result.fetchone()
-                    if row and row[0] != 'boolean':
+                    if row and row[0] != "boolean":
                        # Needs fixing
-                        conn.execute(text(f'''
+                        conn.execute(text(f"""
                            ALTER TABLE {table_name}
                            ALTER COLUMN {col_name} TYPE BOOLEAN
                            USING CASE WHEN {col_name} = 0 THEN FALSE ELSE TRUE END
-                        '''))
+                        """))
                        conn.commit()
                        logger.info(" ✅ %s.%s: %s -> BOOLEAN", table_name, col_name, row[0])
                        fixed_count += 1
                except Exception as e:
                    logger.warning(" ⚠️ %s.%s: check/fix failed - %s", table_name, col_name, e)

        if fixed_count > 0:
            logger.info("Boolean column fix finished! Fixed %d columns", fixed_count)
        else:

@@ -134,7 +134,7 @@ async def test_tool_calling():
    print("Test 4: tool calling")
    print("=" * 60)

-    from src.llm_models.payload_content.tool_option import ToolOption, ToolOptionBuilder, ToolParamType
+    from src.llm_models.payload_content.tool_option import ToolOptionBuilder, ToolParamType

    provider = APIProvider(
        name="bedrock_test",
@@ -171,7 +171,7 @@ async def test_tool_calling():
    )

    if response.tool_calls:
-        print(f"✅ The model called tools:")
+        print("✅ The model called tools:")
        for call in response.tool_calls:
            print(f" - Tool name: {call.func_name}")
            print(f" - Args: {call.args}")