Merge branch 'MoFox-Studio:dev' into dev

Authored by 喵吃鸟, committed via GitHub on 2025-11-27 18:46:41 +08:00
30 changed files with 1658 additions and 2226 deletions

View File

@@ -79,7 +79,9 @@ dependencies = [
"rjieba>=0.1.13",
"fastmcp>=2.13.0",
"mofox-wire",
"jinja2>=3.1.0"
"jinja2>=3.1.0",
"psycopg2-binary",
"PyMySQL"
]
[[tool.uv.index]]

View File

@@ -1,6 +1,10 @@
aiosqlite
aiofiles
aiomysql
asyncpg
psycopg[binary]
psycopg2-binary
PyMySQL
APScheduler
aiohttp
aiohttp-cors
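
Note: the new drivers above correspond to SQLAlchemy connection URL schemes. The lines below are only an illustration of those URL formats; hosts, credentials and database names are placeholders, not values from this repository.

# Async drivers (used by the async engine):
#   sqlite+aiosqlite:///data/bot.db
#   mysql+aiomysql://user:password@127.0.0.1:3306/botdb?charset=utf8mb4
#   postgresql+asyncpg://user:password@127.0.0.1:5432/botdb
# psycopg2-binary and PyMySQL are synchronous drivers; if used, SQLAlchemy addresses them as:
#   postgresql+psycopg2://user:password@127.0.0.1:5432/botdb
#   mysql+pymysql://user:password@127.0.0.1:3306/botdb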

View File

@@ -1,117 +0,0 @@
"""
检查表达方式数据库状态的诊断脚本
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import func, select
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
async def check_database():
"""检查表达方式数据库状态"""
print("=" * 60)
print("表达方式数据库诊断报告")
print("=" * 60)
async with get_db_session() as session:
# 1. 统计总数
total_count = await session.execute(select(func.count()).select_from(Expression))
total = total_count.scalar()
print(f"\n📊 总表达方式数量: {total}")
if total == 0:
print("\n⚠️ 数据库为空!")
print("\n可能的原因:")
print("1. 还没有进行过表达学习")
print("2. 配置中禁用了表达学习")
print("3. 学习过程中发生了错误")
print("\n建议:")
print("- 检查 bot_config.toml 中的 [expression] 配置")
print("- 查看日志中是否有表达学习相关的错误")
print("- 确认聊天流的 learn_expression 配置为 true")
return
# 2. 按 chat_id 统计
print("\n📝 按聊天流统计:")
chat_counts = await session.execute(
select(Expression.chat_id, func.count())
.group_by(Expression.chat_id)
)
for chat_id, count in chat_counts:
print(f" - {chat_id}: {count} 个表达方式")
# 3. 按 type 统计
print("\n📝 按类型统计:")
type_counts = await session.execute(
select(Expression.type, func.count())
.group_by(Expression.type)
)
for expr_type, count in type_counts:
print(f" - {expr_type}: {count}")
# 4. 检查 situation 和 style 字段是否有空值
print("\n🔍 字段完整性检查:")
null_situation = await session.execute(
select(func.count())
.select_from(Expression)
.where(Expression.situation.is_(None))
)
null_style = await session.execute(
select(func.count())
.select_from(Expression)
.where(Expression.style.is_(None))
)
null_sit_count = null_situation.scalar()
null_sty_count = null_style.scalar()
print(f" - situation 为空: {null_sit_count}")
print(f" - style 为空: {null_sty_count}")
if null_sit_count > 0 or null_sty_count > 0:
print(" ⚠️ 发现空值!这会导致匹配失败")
# 5. 显示一些样例数据
print("\n📋 样例数据 (前10条):")
samples = await session.execute(
select(Expression)
.limit(10)
)
for i, expr in enumerate(samples.scalars(), 1):
print(f"\n [{i}] Chat: {expr.chat_id}")
print(f" Type: {expr.type}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print(f" Count: {expr.count}")
# 6. 检查 style 字段的唯一值
print("\n📋 Style 字段样例 (前20个):")
unique_styles = await session.execute(
select(Expression.style)
.distinct()
.limit(20)
)
styles = list(unique_styles.scalars())
for style in styles:
print(f" - {style}")
print(f"\n (共 {len(styles)} 个不同的 style)")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
if __name__ == "__main__":
asyncio.run(check_database())

View File

@@ -1,65 +0,0 @@
"""
检查数据库中 style 字段的内容特征
"""
import asyncio
import sys
from pathlib import Path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from sqlalchemy import select
from src.common.database.compatibility import get_db_session
from src.common.database.core.models import Expression
async def analyze_style_fields():
"""分析 style 字段的内容"""
print("=" * 60)
print("Style 字段内容分析")
print("=" * 60)
async with get_db_session() as session:
# 获取所有表达方式
result = await session.execute(select(Expression).limit(30))
expressions = result.scalars().all()
print(f"\n总共检查 {len(expressions)} 条记录\n")
# 按类型分类
style_examples = [
{
"situation": expr.situation,
"style": expr.style,
"length": len(expr.style) if expr.style else 0
}
for expr in expressions if expr.type == "style"
]
print("📋 Style 类型样例 (前15条):")
print("="*60)
for i, ex in enumerate(style_examples[:15], 1):
print(f"\n[{i}]")
print(f" Situation: {ex['situation']}")
print(f" Style: {ex['style']}")
print(f" 长度: {ex['length']} 字符")
# 判断是具体表达还是风格描述
if ex["length"] <= 20 and any(word in ex["style"] for word in ["简洁", "短句", "陈述", "疑问", "感叹", "省略", "完整"]):
style_type = "✓ 风格描述"
elif ex["length"] <= 10:
style_type = "? 可能是具体表达(较短)"
else:
style_type = "✗ 具体表达内容"
print(f" 类型判断: {style_type}")
print("\n" + "="*60)
print("分析完成")
print("="*60)
if __name__ == "__main__":
asyncio.run(analyze_style_fields())

View File

@@ -1,49 +0,0 @@
#!/usr/bin/env python3
"""清理 core/models.py只保留模型定义"""
import os
# 文件路径
models_file = os.path.join(
os.path.dirname(os.path.dirname(__file__)),
"src",
"common",
"database",
"core",
"models.py"
)
print(f"正在清理文件: {models_file}")
# 读取文件
with open(models_file, encoding="utf-8") as f:
lines = f.readlines()
# 找到最后一个模型类的结束位置MonthlyPlan的 __table_args__ 结束)
# 我们要保留到第593行(包含)
keep_lines = []
found_end = False
for i, line in enumerate(lines, 1):
keep_lines.append(line)
# 检查是否到达 MonthlyPlan 的 __table_args__ 结束
if i > 580 and line.strip() == ")":
# 再检查前一行是否有 Index 相关内容
if "idx_monthlyplan" in "".join(lines[max(0, i-5):i]):
print(f"找到模型定义结束位置: 第 {i}")
found_end = True
break
if not found_end:
print("❌ 未找到模型定义结束标记")
exit(1)
# 写回文件
with open(models_file, "w", encoding="utf-8") as f:
f.writelines(keep_lines)
print("✅ 文件清理完成")
print(f"保留行数: {len(keep_lines)}")
print(f"原始行数: {len(lines)}")
print(f"删除行数: {len(lines) - len(keep_lines)}")

View File

@@ -1,59 +0,0 @@
"""
调试 MCP 工具列表获取
直接测试 MCP 客户端是否能获取工具
"""
import asyncio
from fastmcp.client import Client, StreamableHttpTransport
async def test_direct_connection():
"""直接连接 MCP 服务器并获取工具列表"""
print("=" * 60)
print("直接测试 MCP 服务器连接")
print("=" * 60)
url = "http://localhost:8000/mcp"
print(f"\n连接到: {url}")
try:
# 创建传输层
transport = StreamableHttpTransport(url)
print("✓ 传输层创建成功")
# 创建客户端
async with Client(transport) as client:
print("✓ 客户端连接成功")
# 获取工具列表
print("\n正在获取工具列表...")
tools_result = await client.list_tools()
print(f"\n获取结果类型: {type(tools_result)}")
print(f"结果内容: {tools_result}")
# 检查是否有 tools 属性
if hasattr(tools_result, "tools"):
tools = tools_result.tools
print(f"\n✓ 找到 tools 属性,包含 {len(tools)} 个工具")
for i, tool in enumerate(tools, 1):
print(f"\n工具 {i}:")
print(f" 名称: {tool.name}")
print(f" 描述: {tool.description}")
if hasattr(tool, "inputSchema"):
print(f" 参数 Schema: {tool.inputSchema}")
else:
print("\n✗ 结果中没有 tools 属性")
print(f"可用属性: {dir(tools_result)}")
except Exception as e:
print(f"\n✗ 连接失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
asyncio.run(test_direct_connection())

View File

@@ -1,88 +0,0 @@
"""
检查 StyleLearner 模型状态的诊断脚本
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.chat.express.style_learner import style_learner_manager
from src.common.logger import get_logger
logger = get_logger("debug_style_learner")
def check_style_learner_status(chat_id: str):
"""检查指定 chat_id 的 StyleLearner 状态"""
print("=" * 60)
print(f"StyleLearner 状态诊断 - Chat ID: {chat_id}")
print("=" * 60)
# 获取 learner
learner = style_learner_manager.get_learner(chat_id)
# 1. 基本信息
print("\n📊 基本信息:")
print(f" Chat ID: {learner.chat_id}")
print(f" 风格数量: {len(learner.style_to_id)}")
print(f" 下一个ID: {learner.next_style_id}")
print(f" 最大风格数: {learner.max_styles}")
# 2. 学习统计
print("\n📈 学习统计:")
print(f" 总样本数: {learner.learning_stats['total_samples']}")
print(f" 最后更新: {learner.learning_stats.get('last_update', 'N/A')}")
# 3. 风格列表前20个
print("\n📋 已学习的风格 (前20个):")
all_styles = learner.get_all_styles()
if not all_styles:
print(" ⚠️ 没有任何风格!模型尚未训练")
else:
for i, style in enumerate(all_styles[:20], 1):
style_id = learner.style_to_id.get(style)
situation = learner.id_to_situation.get(style_id, "N/A")
print(f" [{i}] {style}")
print(f" (ID: {style_id}, Situation: {situation})")
# 4. 测试预测
print("\n🔮 测试预测功能:")
if not all_styles:
print(" ⚠️ 无法测试,模型没有训练数据")
else:
test_situations = [
"表示惊讶",
"讨论游戏",
"表达赞同"
]
for test_sit in test_situations:
print(f"\n 测试输入: '{test_sit}'")
best_style, scores = learner.predict_style(test_sit, top_k=3)
if best_style:
print(f" ✓ 最佳匹配: {best_style}")
print(" Top 3:")
for style, score in list(scores.items())[:3]:
print(f" - {style}: {score:.4f}")
else:
print(" ✗ 预测失败")
print("\n" + "=" * 60)
print("诊断完成")
print("=" * 60)
if __name__ == "__main__":
# 从诊断报告中看到的 chat_id
test_chat_ids = [
"52fb94af9f500a01e023ea780e43606e", # 有78个表达方式
"46c8714c8a9b7ee169941fe99fcde07d", # 有22个表达方式
]
for chat_id in test_chat_ids:
check_style_learner_status(chat_id)
print("\n")

View File

@@ -1,403 +0,0 @@
"""
记忆去重工具
功能:
1. 扫描所有标记为"相似"关系的记忆边
2. 对相似记忆进行去重(保留重要性高的,删除另一个)
3. 支持干运行模式(预览不执行)
4. 提供详细的去重报告
使用方法:
# 预览模式(不实际删除)
python scripts/deduplicate_memories.py --dry-run
# 执行去重
python scripts/deduplicate_memories.py
# 指定相似度阈值
python scripts/deduplicate_memories.py --threshold 0.9
# 指定数据目录
python scripts/deduplicate_memories.py --data-dir data/memory_graph
"""
import argparse
import asyncio
import sys
from datetime import datetime
from pathlib import Path
import numpy as np
sys.path.insert(0, str(Path(__file__).parent.parent))
from src.common.logger import get_logger
from src.memory_graph.manager_singleton import initialize_memory_manager, shutdown_memory_manager
logger = get_logger(__name__)
class MemoryDeduplicator:
"""记忆去重器"""
def __init__(self, data_dir: str = "data/memory_graph", dry_run: bool = False, threshold: float = 0.85):
self.data_dir = data_dir
self.dry_run = dry_run
self.threshold = threshold
self.manager = None
# 统计信息
self.stats = {
"total_memories": 0,
"similar_pairs": 0,
"duplicates_found": 0,
"duplicates_removed": 0,
"errors": 0,
}
async def initialize(self):
"""初始化记忆管理器"""
logger.info(f"正在初始化记忆管理器 (data_dir={self.data_dir})...")
self.manager = await initialize_memory_manager(data_dir=self.data_dir)
if not self.manager:
raise RuntimeError("记忆管理器初始化失败")
self.stats["total_memories"] = len(self.manager.graph_store.get_all_memories())
logger.info(f"✅ 记忆管理器初始化成功,共 {self.stats['total_memories']} 条记忆")
async def find_similar_pairs(self) -> list[tuple[str, str, float]]:
"""
查找所有相似的记忆对(通过向量相似度计算)
Returns:
[(memory_id_1, memory_id_2, similarity), ...]
"""
logger.info("正在扫描相似记忆对...")
similar_pairs = []
seen_pairs = set() # 避免重复
# 获取所有记忆
all_memories = self.manager.graph_store.get_all_memories()
total_memories = len(all_memories)
logger.info(f"开始计算 {total_memories} 条记忆的相似度...")
# 两两比较记忆的相似度
for i, memory_i in enumerate(all_memories):
# 每处理10条记忆让出控制权
if i % 10 == 0:
await asyncio.sleep(0)
if i > 0:
logger.info(f"进度: {i}/{total_memories} ({i*100//total_memories}%)")
# 获取记忆i的向量从主题节点
vector_i = None
for node in memory_i.nodes:
if node.embedding is not None:
vector_i = node.embedding
break
if vector_i is None:
continue
# 与后续记忆比较
for j in range(i + 1, total_memories):
memory_j = all_memories[j]
# 获取记忆j的向量
vector_j = None
for node in memory_j.nodes:
if node.embedding is not None:
vector_j = node.embedding
break
if vector_j is None:
continue
# 计算余弦相似度
similarity = self._cosine_similarity(vector_i, vector_j)
# 只保存满足阈值的相似对
if similarity >= self.threshold:
pair_key = tuple(sorted([memory_i.id, memory_j.id]))
if pair_key not in seen_pairs:
seen_pairs.add(pair_key)
similar_pairs.append((memory_i.id, memory_j.id, similarity))
self.stats["similar_pairs"] = len(similar_pairs)
logger.info(f"找到 {len(similar_pairs)} 对相似记忆(阈值>={self.threshold}")
return similar_pairs
def _cosine_similarity(self, vec1: np.ndarray, vec2: np.ndarray) -> float:
"""计算余弦相似度"""
try:
vec1_norm = np.linalg.norm(vec1)
vec2_norm = np.linalg.norm(vec2)
if vec1_norm == 0 or vec2_norm == 0:
return 0.0
similarity = np.dot(vec1, vec2) / (vec1_norm * vec2_norm)
return float(similarity)
except Exception as e:
logger.error(f"计算余弦相似度失败: {e}")
return 0.0
def decide_which_to_keep(self, mem_id_1: str, mem_id_2: str) -> tuple[str | None, str | None]:
"""
决定保留哪个记忆,删除哪个
优先级:
1. 重要性更高的
2. 激活度更高的
3. 创建时间更早的
Returns:
(keep_id, remove_id)
"""
mem1 = self.manager.graph_store.get_memory_by_id(mem_id_1)
mem2 = self.manager.graph_store.get_memory_by_id(mem_id_2)
if not mem1 or not mem2:
logger.warning(f"记忆不存在: {mem_id_1} or {mem_id_2}")
return None, None
# 比较重要性
if mem1.importance > mem2.importance:
return mem_id_1, mem_id_2
elif mem1.importance < mem2.importance:
return mem_id_2, mem_id_1
# 重要性相同,比较激活度
if mem1.activation > mem2.activation:
return mem_id_1, mem_id_2
elif mem1.activation < mem2.activation:
return mem_id_2, mem_id_1
# 激活度也相同,保留更早创建的
if mem1.created_at < mem2.created_at:
return mem_id_1, mem_id_2
else:
return mem_id_2, mem_id_1
async def deduplicate_pair(self, mem_id_1: str, mem_id_2: str, similarity: float) -> bool:
"""
去重一对相似记忆
Returns:
是否成功去重
"""
keep_id, remove_id = self.decide_which_to_keep(mem_id_1, mem_id_2)
if not keep_id or not remove_id:
self.stats["errors"] += 1
return False
keep_mem = self.manager.graph_store.get_memory_by_id(keep_id)
remove_mem = self.manager.graph_store.get_memory_by_id(remove_id)
logger.info("")
logger.info(f"{'[预览]' if self.dry_run else '[执行]'} 去重相似记忆对 (相似度={similarity:.3f}):")
logger.info(f" 保留: {keep_id}")
logger.info(f" - 主题: {keep_mem.metadata.get('topic', 'N/A')}")
logger.info(f" - 重要性: {keep_mem.importance:.2f}")
logger.info(f" - 激活度: {keep_mem.activation:.2f}")
logger.info(f" - 创建时间: {keep_mem.created_at}")
logger.info(f" 删除: {remove_id}")
logger.info(f" - 主题: {remove_mem.metadata.get('topic', 'N/A')}")
logger.info(f" - 重要性: {remove_mem.importance:.2f}")
logger.info(f" - 激活度: {remove_mem.activation:.2f}")
logger.info(f" - 创建时间: {remove_mem.created_at}")
if self.dry_run:
logger.info(" [预览模式] 不执行实际删除")
self.stats["duplicates_found"] += 1
return True
try:
# 增强保留记忆的属性
keep_mem.importance = min(1.0, keep_mem.importance + 0.05)
keep_mem.activation = min(1.0, keep_mem.activation + 0.05)
# 累加访问次数
if hasattr(keep_mem, "access_count") and hasattr(remove_mem, "access_count"):
keep_mem.access_count += remove_mem.access_count
# 删除相似记忆
await self.manager.delete_memory(remove_id)
self.stats["duplicates_removed"] += 1
logger.info(" ✅ 删除成功")
# 让出控制权
await asyncio.sleep(0)
return True
except Exception as e:
logger.error(f" ❌ 删除失败: {e}")
self.stats["errors"] += 1
return False
async def run(self):
"""执行去重"""
start_time = datetime.now()
print("="*70)
print("记忆去重工具")
print("="*70)
print(f"数据目录: {self.data_dir}")
print(f"相似度阈值: {self.threshold}")
print(f"模式: {'预览模式(不实际删除)' if self.dry_run else '执行模式(会实际删除)'}")
print("="*70)
print()
# 初始化
await self.initialize()
# 查找相似对
similar_pairs = await self.find_similar_pairs()
if not similar_pairs:
logger.info("未找到需要去重的相似记忆对")
print()
print("="*70)
print("未找到需要去重的记忆")
print("="*70)
return
# 去重处理
logger.info(f"开始{'预览' if self.dry_run else '执行'}去重...")
print()
processed_pairs = set() # 避免重复处理
for mem_id_1, mem_id_2, similarity in similar_pairs:
# 检查是否已处理(可能一个记忆已被删除)
pair_key = tuple(sorted([mem_id_1, mem_id_2]))
if pair_key in processed_pairs:
continue
# 检查记忆是否仍存在
if not self.manager.graph_store.get_memory_by_id(mem_id_1):
logger.debug(f"记忆 {mem_id_1} 已不存在,跳过")
continue
if not self.manager.graph_store.get_memory_by_id(mem_id_2):
logger.debug(f"记忆 {mem_id_2} 已不存在,跳过")
continue
# 执行去重
success = await self.deduplicate_pair(mem_id_1, mem_id_2, similarity)
if success:
processed_pairs.add(pair_key)
# 保存数据(如果不是干运行)
if not self.dry_run:
logger.info("正在保存数据...")
await self.manager.persistence.save_graph_store(self.manager.graph_store)
logger.info("✅ 数据已保存")
# 统计报告
elapsed = (datetime.now() - start_time).total_seconds()
print()
print("="*70)
print("去重报告")
print("="*70)
print(f"总记忆数: {self.stats['total_memories']}")
print(f"相似记忆对: {self.stats['similar_pairs']}")
print(f"发现重复: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
print(f"{'预览通过' if self.dry_run else '成功删除'}: {self.stats['duplicates_found'] if self.dry_run else self.stats['duplicates_removed']}")
print(f"错误数: {self.stats['errors']}")
print(f"耗时: {elapsed:.2f}")
if self.dry_run:
print()
print("⚠️ 这是预览模式,未实际删除任何记忆")
print("💡 要执行实际删除,请运行: python scripts/deduplicate_memories.py")
else:
print()
print("✅ 去重完成!")
final_count = len(self.manager.graph_store.get_all_memories())
print(f"📊 最终记忆数: {final_count} (减少 {self.stats['total_memories'] - final_count} 条)")
print("="*70)
async def cleanup(self):
"""清理资源"""
if self.manager:
await shutdown_memory_manager()
async def main():
"""主函数"""
parser = argparse.ArgumentParser(
description="记忆去重工具 - 对标记为相似的记忆进行一键去重",
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
示例:
# 预览模式(推荐先运行)
python scripts/deduplicate_memories.py --dry-run
# 执行去重
python scripts/deduplicate_memories.py
# 指定相似度阈值(只处理相似度>=0.9的记忆对)
python scripts/deduplicate_memories.py --threshold 0.9
# 指定数据目录
python scripts/deduplicate_memories.py --data-dir data/memory_graph
# 组合使用
python scripts/deduplicate_memories.py --dry-run --threshold 0.95 --data-dir data/test
"""
)
parser.add_argument(
"--dry-run",
action="store_true",
help="预览模式,不实际删除记忆(推荐先运行此模式)"
)
parser.add_argument(
"--threshold",
type=float,
default=0.85,
help="相似度阈值,只处理相似度>=此值的记忆对(默认: 0.85"
)
parser.add_argument(
"--data-dir",
type=str,
default="data/memory_graph",
help="记忆数据目录(默认: data/memory_graph"
)
args = parser.parse_args()
# 创建去重器
deduplicator = MemoryDeduplicator(
data_dir=args.data_dir,
dry_run=args.dry_run,
threshold=args.threshold
)
try:
# 执行去重
await deduplicator.run()
except KeyboardInterrupt:
print("\n\n⚠️ 用户中断操作")
except Exception as e:
logger.error(f"执行失败: {e}")
print(f"\n❌ 执行失败: {e}")
return 1
finally:
# 清理资源
await deduplicator.cleanup()
return 0
if __name__ == "__main__":
sys.exit(asyncio.run(main()))

View File

@@ -1,195 +0,0 @@
import os
import sys
import time
# Add project root to Python path before importing project modules
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, project_root)
from src.common.database.database_model import ChatStreams, Expression
def get_chat_name(chat_id: str) -> str:
"""Get chat name from chat_id by querying ChatStreams table directly"""
try:
# 直接从数据库查询ChatStreams表
chat_stream = ChatStreams.get_or_none(ChatStreams.stream_id == chat_id)
if chat_stream is None:
return f"未知聊天 ({chat_id})"
# 如果有群组信息,显示群组名称
if chat_stream.group_name:
return f"{chat_stream.group_name} ({chat_id})"
# 如果是私聊,显示用户昵称
elif chat_stream.user_nickname:
return f"{chat_stream.user_nickname}的私聊 ({chat_id})"
else:
return f"未知聊天 ({chat_id})"
except Exception:
return f"查询失败 ({chat_id})"
def calculate_time_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of last active time in days"""
now = time.time()
distribution = {
"0-1天": 0,
"1-3天": 0,
"3-7天": 0,
"7-14天": 0,
"14-30天": 0,
"30-60天": 0,
"60-90天": 0,
"90+天": 0,
}
for expr in expressions:
diff_days = (now - expr.last_active_time) / (24 * 3600)
if diff_days < 1:
distribution["0-1天"] += 1
elif diff_days < 3:
distribution["1-3天"] += 1
elif diff_days < 7:
distribution["3-7天"] += 1
elif diff_days < 14:
distribution["7-14天"] += 1
elif diff_days < 30:
distribution["14-30天"] += 1
elif diff_days < 60:
distribution["30-60天"] += 1
elif diff_days < 90:
distribution["60-90天"] += 1
else:
distribution["90+天"] += 1
return distribution
def calculate_count_distribution(expressions) -> dict[str, int]:
"""Calculate distribution of count values"""
distribution = {"0-1": 0, "1-2": 0, "2-3": 0, "3-4": 0, "4-5": 0, "5-10": 0, "10+": 0}
for expr in expressions:
cnt = expr.count
if cnt < 1:
distribution["0-1"] += 1
elif cnt < 2:
distribution["1-2"] += 1
elif cnt < 3:
distribution["2-3"] += 1
elif cnt < 4:
distribution["3-4"] += 1
elif cnt < 5:
distribution["4-5"] += 1
elif cnt < 10:
distribution["5-10"] += 1
else:
distribution["10+"] += 1
return distribution
def get_top_expressions_by_chat(chat_id: str, top_n: int = 5) -> list[Expression]:
"""Get top N most used expressions for a specific chat_id"""
return Expression.select().where(Expression.chat_id == chat_id).order_by(Expression.count.desc()).limit(top_n)
def show_overall_statistics(expressions, total: int) -> None:
"""Show overall statistics"""
time_dist = calculate_time_distribution(expressions)
count_dist = calculate_count_distribution(expressions)
print("\n=== 总体统计 ===")
print(f"总表达式数量: {total}")
print("\n上次激活时间分布:")
for period, count in time_dist.items():
print(f"{period}: {count} ({count / total * 100:.2f}%)")
print("\ncount分布:")
for range_, count in count_dist.items():
print(f"{range_}: {count} ({count / total * 100:.2f}%)")
def show_chat_statistics(chat_id: str, chat_name: str) -> None:
"""Show statistics for a specific chat"""
chat_exprs = list(Expression.select().where(Expression.chat_id == chat_id))
chat_total = len(chat_exprs)
print(f"\n=== {chat_name} ===")
print(f"表达式数量: {chat_total}")
if chat_total == 0:
print("该聊天没有表达式数据")
return
# Time distribution for this chat
time_dist = calculate_time_distribution(chat_exprs)
print("\n上次激活时间分布:")
for period, count in time_dist.items():
if count > 0:
print(f"{period}: {count} ({count / chat_total * 100:.2f}%)")
# Count distribution for this chat
count_dist = calculate_count_distribution(chat_exprs)
print("\ncount分布:")
for range_, count in count_dist.items():
if count > 0:
print(f"{range_}: {count} ({count / chat_total * 100:.2f}%)")
# Top expressions
print("\nTop 10使用最多的表达式:")
top_exprs = get_top_expressions_by_chat(chat_id, 10)
for i, expr in enumerate(top_exprs, 1):
print(f"{i}. [{expr.type}] Count: {expr.count}")
print(f" Situation: {expr.situation}")
print(f" Style: {expr.style}")
print()
def interactive_menu() -> None:
"""Interactive menu for expression statistics"""
# Get all expressions
expressions = list(Expression.select())
if not expressions:
print("数据库中没有找到表达式")
return
total = len(expressions)
# Get unique chat_ids and their names
chat_ids = list({expr.chat_id for expr in expressions})
chat_info = [(chat_id, get_chat_name(chat_id)) for chat_id in chat_ids]
chat_info.sort(key=lambda x: x[1]) # Sort by chat name
while True:
print("\n" + "=" * 50)
print("表达式统计分析")
print("=" * 50)
print("0. 显示总体统计")
for i, (chat_id, chat_name) in enumerate(chat_info, 1):
chat_count = sum(1 for expr in expressions if expr.chat_id == chat_id)
print(f"{i}. {chat_name} ({chat_count}个表达式)")
print("q. 退出")
choice = input("\n请选择要查看的统计 (输入序号): ").strip()
if choice.lower() == "q":
print("再见!")
break
try:
choice_num = int(choice)
if choice_num == 0:
show_overall_statistics(expressions, total)
elif 1 <= choice_num <= len(chat_info):
chat_id, chat_name = chat_info[choice_num - 1]
show_chat_statistics(chat_id, chat_name)
else:
print("无效的选择,请重新输入")
except ValueError:
print("请输入有效的数字")
input("\n按回车键继续...")
if __name__ == "__main__":
interactive_menu()

View File

@@ -1,66 +0,0 @@
#!/usr/bin/env python3
"""提取models.py中的模型定义"""
import re
# 读取原始文件
with open("src/common/database/sqlalchemy_models.py", encoding="utf-8") as f:
content = f.read()
# 找到get_string_field函数的开始和结束
get_string_field_start = content.find("# MySQL兼容的字段类型辅助函数")
get_string_field_end = content.find("\n\nclass ChatStreams(Base):")
get_string_field = content[get_string_field_start:get_string_field_end]
# 找到第一个class定义开始
first_class_pos = content.find("class ChatStreams(Base):")
# 找到所有class定义直到遇到非class的def
# 简单策略:找到所有以"class "开头且继承Base的类
classes_pattern = r"class \w+\(Base\):.*?(?=\nclass \w+\(Base\):|$)"
matches = list(re.finditer(classes_pattern, content[first_class_pos:], re.DOTALL))
if matches:
# 取最后一个匹配的结束位置
models_content = content[first_class_pos:first_class_pos + matches[-1].end()]
else:
# 备用方案从第一个class到文件的85%位置
models_end = int(len(content) * 0.85)
models_content = content[first_class_pos:models_end]
# 创建新文件内容
header = '''"""SQLAlchemy数据库模型定义
本文件只包含纯模型定义,使用SQLAlchemy 2.0的Mapped类型注解风格。
引擎和会话管理已移至core/engine.py和core/session.py。
所有模型使用统一的类型注解风格:
field_name: Mapped[PyType] = mapped_column(Type, ...)
这样IDE/Pylance能正确推断实例属性类型。
"""
import datetime
import time
from sqlalchemy import Boolean, DateTime, Float, Index, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Mapped, mapped_column
# 创建基类
Base = declarative_base()
'''
new_content = header + get_string_field + "\n\n" + models_content
# 写入新文件
with open("src/common/database/core/models.py", "w", encoding="utf-8") as f:
f.write(new_content)
print("✅ Models file rewritten successfully")
print(f"File size: {len(new_content)} characters")
pattern = r"^class \w+\(Base\):"
model_count = len(re.findall(pattern, models_content, re.MULTILINE))
print(f"Number of model classes: {model_count}")

View File

@@ -1,267 +0,0 @@
"""
为现有节点生成嵌入向量
批量为图存储中缺少嵌入向量的节点生成并索引嵌入向量
使用场景:
1. 历史记忆节点没有嵌入向量
2. 嵌入生成器之前未配置,现在需要补充生成
3. 向量索引损坏需要重建
使用方法:
python scripts/generate_missing_embeddings.py [--node-types TOPIC,OBJECT] [--batch-size 50]
参数说明:
--node-types: 需要生成嵌入的节点类型,默认为 TOPIC,OBJECT
--batch-size: 批量处理大小,默认为 50
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
async def generate_missing_embeddings(
target_node_types: list[str] | None = None,
batch_size: int = 50,
):
"""
为缺失嵌入向量的节点生成嵌入
Args:
target_node_types: 需要处理的节点类型列表(如 ["主题", "客体"])
batch_size: 批处理大小
"""
from src.common.logger import get_logger
from src.memory_graph.manager_singleton import get_memory_manager, initialize_memory_manager
from src.memory_graph.models import NodeType
logger = get_logger("generate_missing_embeddings")
if target_node_types is None:
target_node_types = [NodeType.TOPIC.value, NodeType.OBJECT.value]
print(f"\n{'='*80}")
print("🔧 为节点生成嵌入向量")
print(f"{'='*80}\n")
print(f"目标节点类型: {', '.join(target_node_types)}")
print(f"批处理大小: {batch_size}\n")
# 1. 初始化记忆管理器
print("🔧 正在初始化记忆管理器...")
await initialize_memory_manager()
manager = get_memory_manager()
if manager is None:
print("❌ 记忆管理器初始化失败")
return
print("✅ 记忆管理器已初始化\n")
# 2. 获取已索引的节点ID
print("🔍 检查现有向量索引...")
existing_node_ids = set()
try:
vector_count = manager.vector_store.collection.count()
if vector_count > 0:
# 分批获取所有已索引的ID
batch_size_check = 1000
for offset in range(0, vector_count, batch_size_check):
limit = min(batch_size_check, vector_count - offset)
result = manager.vector_store.collection.get(
limit=limit,
offset=offset,
)
if result and "ids" in result:
existing_node_ids.update(result["ids"])
print(f"✅ 发现 {len(existing_node_ids)} 个已索引节点\n")
except Exception as e:
logger.warning(f"获取已索引节点ID失败: {e}")
print("⚠️ 无法获取已索引节点,将尝试跳过重复项\n")
# 3. 收集需要生成嵌入的节点
print("🔍 扫描需要生成嵌入的节点...")
all_memories = manager.graph_store.get_all_memories()
nodes_to_process = []
total_target_nodes = 0
type_stats = {nt: {"total": 0, "need_emb": 0, "already_indexed": 0} for nt in target_node_types}
for memory in all_memories:
for node in memory.nodes:
if node.node_type.value in target_node_types:
total_target_nodes += 1
type_stats[node.node_type.value]["total"] += 1
# 检查是否已在向量索引中
if node.id in existing_node_ids:
type_stats[node.node_type.value]["already_indexed"] += 1
continue
if not node.has_embedding():
nodes_to_process.append({
"node": node,
"memory_id": memory.id,
})
type_stats[node.node_type.value]["need_emb"] += 1
print("\n📊 扫描结果:")
for node_type in target_node_types:
stats = type_stats[node_type]
already_ok = stats["already_indexed"]
coverage = (stats["total"] - stats["need_emb"]) / stats["total"] * 100 if stats["total"] > 0 else 0
print(f" - {node_type}: {stats['total']} 个节点, {stats['need_emb']} 个缺失嵌入, "
f"{already_ok} 个已索引 (覆盖率: {coverage:.1f}%)")
print(f"\n 总计: {total_target_nodes} 个目标节点, {len(nodes_to_process)} 个需要生成嵌入\n")
if len(nodes_to_process) == 0:
print("✅ 所有节点已有嵌入向量,无需生成")
return
# 3. 批量生成嵌入
print("🚀 开始生成嵌入向量...\n")
total_batches = (len(nodes_to_process) + batch_size - 1) // batch_size
success_count = 0
failed_count = 0
indexed_count = 0
for i in range(0, len(nodes_to_process), batch_size):
batch = nodes_to_process[i : i + batch_size]
batch_num = i // batch_size + 1
print(f"📦 批次 {batch_num}/{total_batches} ({len(batch)} 个节点)...")
try:
# 提取文本内容
texts = [item["node"].content for item in batch]
# 批量生成嵌入
embeddings = await manager.embedding_generator.generate_batch(texts)
# 为节点设置嵌入并索引
batch_nodes_for_index = []
for j, (item, embedding) in enumerate(zip(batch, embeddings)):
node = item["node"]
if embedding is not None:
# 设置嵌入向量
node.embedding = embedding
batch_nodes_for_index.append(node)
success_count += 1
else:
failed_count += 1
logger.warning(f" ⚠️ 节点 {node.id[:8]}... '{node.content[:30]}' 嵌入生成失败")
# 批量索引到向量数据库
if batch_nodes_for_index:
try:
await manager.vector_store.add_nodes_batch(batch_nodes_for_index)
indexed_count += len(batch_nodes_for_index)
print(f" ✅ 成功: {len(batch_nodes_for_index)}/{len(batch)} 个节点已生成并索引")
except Exception as e:
# 如果批量失败,尝试逐个添加(跳过重复)
logger.warning(f" 批量索引失败,尝试逐个添加: {e}")
individual_success = 0
for node in batch_nodes_for_index:
try:
await manager.vector_store.add_node(node)
individual_success += 1
indexed_count += 1
except Exception as e2:
if "Expected IDs to be unique" in str(e2):
logger.debug(f" 跳过已存在节点: {node.id}")
else:
logger.error(f" 节点 {node.id} 索引失败: {e2}")
print(f" ⚠️ 逐个索引: {individual_success}/{len(batch_nodes_for_index)} 个成功")
except Exception as e:
failed_count += len(batch)
logger.error(f"批次 {batch_num} 处理失败")
print(f" ❌ 批次处理失败: {e}")
# 显示进度
total_processed = min(i + batch_size, len(nodes_to_process))
progress = total_processed / len(nodes_to_process) * 100
print(f" 📊 总进度: {total_processed}/{len(nodes_to_process)} ({progress:.1f}%)\n")
# 4. 保存图数据(更新节点的 embedding 字段)
print("💾 保存图数据...")
try:
await manager.persistence.save_graph_store(manager.graph_store)
print("✅ 图数据已保存\n")
except Exception as e:
logger.error("保存图数据失败")
print(f"❌ 保存失败: {e}\n")
# 5. 验证结果
print("🔍 验证向量索引...")
final_vector_count = manager.vector_store.collection.count()
stats = manager.graph_store.get_statistics()
total_nodes = stats["total_nodes"]
print(f"\n{'='*80}")
print("📊 生成完成")
print(f"{'='*80}")
print(f"处理节点数: {len(nodes_to_process)}")
print(f"成功生成: {success_count}")
print(f"失败数量: {failed_count}")
print(f"成功索引: {indexed_count}")
print(f"向量索引节点数: {final_vector_count}")
print(f"图存储节点数: {total_nodes}")
print(f"索引覆盖率: {final_vector_count / total_nodes * 100:.1f}%\n")
# 6. 测试搜索
print("🧪 测试搜索功能...")
test_queries = ["小红帽蕾克", "拾风", "杰瑞喵"]
for query in test_queries:
results = await manager.search_memories(query=query, top_k=3)
if results:
print(f"\n✅ 查询 '{query}' 找到 {len(results)} 条记忆:")
for i, memory in enumerate(results[:2], 1):
subject_node = memory.get_subject_node()
# 获取主题节点(遍历所有节点找 TOPIC 类型)
from src.memory_graph.models import NodeType
topic_nodes = [n for n in memory.nodes if n.node_type == NodeType.TOPIC]
subject = subject_node.content if subject_node else "?"
topic = topic_nodes[0].content if topic_nodes else "?"
print(f" {i}. {subject} - {topic} (重要性: {memory.importance:.2f})")
else:
print(f"\n⚠️ 查询 '{query}' 返回 0 条结果")
async def main():
import argparse
parser = argparse.ArgumentParser(description="为节点生成嵌入向量")
parser.add_argument(
"--node-types",
type=str,
default="主题,客体",
help="需要生成嵌入的节点类型,逗号分隔(默认:主题,客体)",
)
parser.add_argument(
"--batch-size",
type=int,
default=50,
help="批处理大小默认50",
)
args = parser.parse_args()
target_types = [t.strip() for t in args.node_types.split(",")]
await generate_missing_embeddings(
target_node_types=target_types,
batch_size=args.batch_size,
)
if __name__ == "__main__":
asyncio.run(main())

scripts/migrate_database.py (new file, 1051 lines)

File diff suppressed because it is too large

View File

@@ -1,140 +0,0 @@
#!/usr/bin/env python
"""
从现有ChromaDB数据重建JSON元数据索引
"""
import asyncio
import os
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.chat.memory_system.memory_metadata_index import MemoryMetadataIndexEntry
from src.chat.memory_system.memory_system import MemorySystem
from src.common.logger import get_logger
logger = get_logger(__name__)
async def rebuild_metadata_index():
"""从ChromaDB重建元数据索引"""
print("=" * 80)
print("重建JSON元数据索引")
print("=" * 80)
# 初始化记忆系统
print("\n🔧 初始化记忆系统...")
ms = MemorySystem()
await ms.initialize()
print("✅ 记忆系统已初始化")
if not hasattr(ms.unified_storage, "metadata_index"):
print("❌ 元数据索引管理器未初始化")
return
# 获取所有记忆
print("\n📥 从ChromaDB获取所有记忆...")
from src.common.vector_db import vector_db_service
try:
# 获取集合中的所有记忆ID
collection_name = ms.unified_storage.config.memory_collection
result = vector_db_service.get(
collection_name=collection_name, include=["documents", "metadatas", "embeddings"]
)
if not result or not result.get("ids"):
print("❌ ChromaDB中没有找到记忆数据")
return
ids = result["ids"]
metadatas = result.get("metadatas", [])
print(f"✅ 找到 {len(ids)} 条记忆")
# 重建元数据索引
print("\n🔨 开始重建元数据索引...")
entries = []
success_count = 0
for i, (memory_id, metadata) in enumerate(zip(ids, metadatas, strict=False), 1):
try:
# 从ChromaDB元数据重建索引条目
import orjson
entry = MemoryMetadataIndexEntry(
memory_id=memory_id,
user_id=metadata.get("user_id", "unknown"),
memory_type=metadata.get("memory_type", "general"),
subjects=orjson.loads(metadata.get("subjects", "[]")),
objects=[metadata.get("object")] if metadata.get("object") else [],
keywords=orjson.loads(metadata.get("keywords", "[]")),
tags=orjson.loads(metadata.get("tags", "[]")),
importance=2, # 默认NORMAL
confidence=2, # 默认MEDIUM
created_at=metadata.get("created_at", 0.0),
access_count=metadata.get("access_count", 0),
chat_id=metadata.get("chat_id"),
content_preview=None,
)
# 尝试解析importance和confidence的枚举名称
if "importance" in metadata:
imp_str = metadata["importance"]
if imp_str == "LOW":
entry.importance = 1
elif imp_str == "NORMAL":
entry.importance = 2
elif imp_str == "HIGH":
entry.importance = 3
elif imp_str == "CRITICAL":
entry.importance = 4
if "confidence" in metadata:
conf_str = metadata["confidence"]
if conf_str == "LOW":
entry.confidence = 1
elif conf_str == "MEDIUM":
entry.confidence = 2
elif conf_str == "HIGH":
entry.confidence = 3
elif conf_str == "VERIFIED":
entry.confidence = 4
entries.append(entry)
success_count += 1
if i % 100 == 0:
print(f" 处理进度: {i}/{len(ids)} ({success_count} 成功)")
except Exception as e:
logger.warning(f"处理记忆 {memory_id} 失败: {e}")
continue
print(f"\n✅ 成功解析 {success_count}/{len(ids)} 条记忆元数据")
# 批量更新索引
print("\n💾 保存元数据索引...")
ms.unified_storage.metadata_index.batch_add_or_update(entries)
ms.unified_storage.metadata_index.save()
# 显示统计信息
stats = ms.unified_storage.metadata_index.get_stats()
print("\n📊 重建后的索引统计:")
print(f" - 总记忆数: {stats['total_memories']}")
print(f" - 主语数量: {stats['subjects_count']}")
print(f" - 关键词数量: {stats['keywords_count']}")
print(f" - 标签数量: {stats['tags_count']}")
print(" - 类型分布:")
for mtype, count in stats["types"].items():
print(f" - {mtype}: {count}")
print("\n✅ 元数据索引重建完成!")
except Exception as e:
logger.error(f"重建索引失败: {e}")
print(f"❌ 重建索引失败: {e}")
if __name__ == "__main__":
asyncio.run(rebuild_metadata_index())

View File

@@ -1,190 +0,0 @@
"""
MCP 集成测试脚本
测试 MCP 客户端连接、工具列表获取和工具调用功能
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.common.logger import get_logger
from src.plugin_system.core.component_registry import ComponentRegistry
from src.plugin_system.core.mcp_client_manager import MCPClientManager
logger = get_logger("test_mcp_integration")
async def test_mcp_client_manager():
"""测试 MCPClientManager 基本功能"""
print("\n" + "="*60)
print("测试 1: MCPClientManager 连接和工具列表")
print("="*60)
try:
# 初始化 MCP 客户端管理器
manager = MCPClientManager()
await manager.initialize()
print("\n✓ MCP 客户端管理器初始化成功")
print(f"已连接服务器数量: {len(manager.clients)}")
# 获取所有工具
tools = await manager.get_all_tools()
print(f"\n获取到 {len(tools)} 个 MCP 工具:")
for tool in tools:
print(f"\n 工具: {tool}")
# 注意: 这里 tool 是字符串形式的工具名称
# 如果需要工具详情,需要从其他地方获取
return manager, tools
except Exception as e:
print(f"\n✗ 测试失败: {e}")
logger.exception("MCPClientManager 测试失败")
return None, []
async def test_tool_call(manager: MCPClientManager, tools):
"""测试工具调用"""
print("\n" + "="*60)
print("测试 2: MCP 工具调用")
print("="*60)
if not tools:
print("\n⚠ 没有可用的工具进行测试")
return
try:
# 工具列表测试已在第一个测试中完成
print("\n✓ 工具列表获取成功")
print(f"可用工具数量: {len(tools)}")
except Exception as e:
print(f"\n✗ 工具调用测试失败: {e}")
logger.exception("工具调用测试失败")
async def test_component_registry_integration():
"""测试 ComponentRegistry 集成"""
print("\n" + "="*60)
print("测试 3: ComponentRegistry MCP 工具集成")
print("="*60)
try:
registry = ComponentRegistry()
# 加载 MCP 工具
await registry.load_mcp_tools()
# 获取 MCP 工具
mcp_tools = registry.get_mcp_tools()
print(f"\n✓ ComponentRegistry 加载了 {len(mcp_tools)} 个 MCP 工具")
for tool in mcp_tools:
print(f"\n 工具: {tool.name}")
print(f" 描述: {tool.description}")
print(f" 参数数量: {len(tool.parameters)}")
# 测试 is_mcp_tool 方法
is_mcp = registry.is_mcp_tool(tool.name)
print(f" is_mcp_tool 检测: {'' if is_mcp else ''}")
return mcp_tools
except Exception as e:
print(f"\n✗ ComponentRegistry 集成测试失败: {e}")
logger.exception("ComponentRegistry 集成测试失败")
return []
async def test_tool_execution(mcp_tools):
"""测试通过适配器执行工具"""
print("\n" + "="*60)
print("测试 4: MCPToolAdapter 工具执行")
print("="*60)
if not mcp_tools:
print("\n⚠ 没有可用的 MCP 工具进行测试")
return
try:
# 选择第一个工具测试
test_tool = mcp_tools[0]
print(f"\n测试工具: {test_tool.name}")
# 构建测试参数
test_args = {}
for param_name, param_type, param_desc, is_required, enum_values in test_tool.parameters:
if is_required:
# 根据类型提供默认值
from src.llm_models.payload_content.tool_option import ToolParamType
if param_type == ToolParamType.STRING:
test_args[param_name] = "test_value"
elif param_type == ToolParamType.INTEGER:
test_args[param_name] = 1
elif param_type == ToolParamType.FLOAT:
test_args[param_name] = 1.0
elif param_type == ToolParamType.BOOLEAN:
test_args[param_name] = True
print(f"测试参数: {test_args}")
# 执行工具
result = await test_tool.execute(test_args)
if result:
print("\n✓ 工具执行成功")
print(f"结果类型: {result.get('type')}")
print(f"结果内容: {result.get('content', '')[:200]}...") # 只显示前200字符
else:
print("\n✗ 工具执行失败,返回 None")
except Exception as e:
print(f"\n✗ 工具执行测试失败: {e}")
logger.exception("工具执行测试失败")
async def main():
"""主测试流程"""
print("\n" + "="*60)
print("MCP 集成测试")
print("="*60)
try:
# 测试 1: MCPClientManager 基本功能
manager, tools = await test_mcp_client_manager()
if manager:
# 测试 2: 工具调用
await test_tool_call(manager, tools)
# 测试 3: ComponentRegistry 集成
mcp_tools = await test_component_registry_integration()
# 测试 4: 工具执行
await test_tool_execution(mcp_tools)
# 关闭连接
await manager.close()
print("\n✓ MCP 客户端连接已关闭")
print("\n" + "="*60)
print("测试完成")
print("="*60 + "\n")
except KeyboardInterrupt:
print("\n\n测试被用户中断")
except Exception as e:
print(f"\n测试过程中发生错误: {e}")
logger.exception("测试失败")
if __name__ == "__main__":
asyncio.run(main())

View File

@@ -1,292 +0,0 @@
"""
三层记忆系统测试脚本
用于验证系统各组件是否正常工作
"""
import asyncio
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
async def test_perceptual_memory():
"""测试感知记忆层"""
print("\n" + "=" * 60)
print("测试1: 感知记忆层")
print("=" * 60)
from src.memory_graph.three_tier.perceptual_manager import get_perceptual_manager
manager = get_perceptual_manager()
await manager.initialize()
# 添加测试消息
test_messages = [
("user1", "今天天气真好", 1700000000.0),
("user2", "是啊,适合出去玩", 1700000001.0),
("user1", "我们去公园吧", 1700000002.0),
("user2", "好主意!", 1700000003.0),
("user1", "带上野餐垫", 1700000004.0),
]
for sender, content, timestamp in test_messages:
message = {
"message_id": f"msg_{timestamp}",
"sender": sender,
"content": content,
"timestamp": timestamp,
"platform": "test",
"stream_id": "test_stream",
}
await manager.add_message(message)
print(f"✅ 成功添加 {len(test_messages)} 条消息")
# 测试TopK召回
results = await manager.recall_blocks("公园野餐", top_k=2)
print(f"✅ TopK召回返回 {len(results)} 个块")
if results:
print(f" 第一个块包含 {len(results[0].messages)} 条消息")
# 获取统计信息
stats = manager.get_statistics() # 不是async方法
print(f"✅ 统计信息: {stats}")
return True
async def test_short_term_memory():
"""测试短期记忆层"""
print("\n" + "=" * 60)
print("测试2: 短期记忆层")
print("=" * 60)
from src.memory_graph.three_tier.models import MemoryBlock
from src.memory_graph.three_tier.short_term_manager import get_short_term_manager
manager = get_short_term_manager()
await manager.initialize()
# 创建测试块
test_block = MemoryBlock(
id="test_block_1",
messages=[
{
"message_id": "msg1",
"sender": "user1",
"content": "我明天要参加一个重要的面试",
"timestamp": 1700000000.0,
"platform": "test",
}
],
combined_text="我明天要参加一个重要的面试",
recall_count=3,
)
# 从感知块转换为短期记忆
try:
await manager.add_from_block(test_block)
print("✅ 成功将感知块转换为短期记忆")
except Exception as e:
print(f"⚠️ 转换失败可能需要LLM: {e}")
return False
# 测试搜索
results = await manager.search_memories("面试", top_k=3)
print(f"✅ 搜索返回 {len(results)} 条记忆")
# 获取统计
stats = manager.get_statistics()
print(f"✅ 统计信息: {stats}")
return True
async def test_long_term_memory():
"""测试长期记忆层"""
print("\n" + "=" * 60)
print("测试3: 长期记忆层")
print("=" * 60)
from src.memory_graph.three_tier.long_term_manager import get_long_term_manager
manager = get_long_term_manager()
await manager.initialize()
print("✅ 长期记忆管理器初始化成功")
print(" (需要现有记忆图系统支持)")
# 获取统计
stats = manager.get_statistics()
print(f"✅ 统计信息: {stats}")
return True
async def test_unified_manager():
"""测试统一管理器"""
print("\n" + "=" * 60)
print("测试4: 统一管理器")
print("=" * 60)
from src.memory_graph.three_tier.unified_manager import UnifiedMemoryManager
manager = UnifiedMemoryManager()
await manager.initialize()
# 添加测试消息
message = {
"message_id": "unified_test_1",
"sender": "user1",
"content": "这是一条测试消息",
"timestamp": 1700000000.0,
"platform": "test",
"stream_id": "test_stream",
}
await manager.add_message(message)
print("✅ 通过统一接口添加消息成功")
# 测试搜索
results = await manager.search_memories("测试")
print(f"✅ 统一搜索返回结果:")
print(f" 感知块: {len(results.get('perceptual_blocks', []))}")
print(f" 短期记忆: {len(results.get('short_term_memories', []))}")
print(f" 长期记忆: {len(results.get('long_term_memories', []))}")
# 获取统计
stats = manager.get_statistics() # 不是async方法
print(f"✅ 综合统计:")
print(f" 感知层: {stats.get('perceptual', {})}")
print(f" 短期层: {stats.get('short_term', {})}")
print(f" 长期层: {stats.get('long_term', {})}")
return True
async def test_configuration():
"""测试配置加载"""
print("\n" + "=" * 60)
print("测试5: 配置系统")
print("=" * 60)
from src.config.config import global_config
if not hasattr(global_config, "three_tier_memory"):
print("❌ 配置类中未找到 three_tier_memory 字段")
return False
config = global_config.three_tier_memory
if config is None:
print("⚠️ 三层记忆配置为 None可能未在 bot_config.toml 中配置)")
print(" 请在 bot_config.toml 中添加 [three_tier_memory] 配置")
return False
print(f"✅ 配置加载成功")
print(f" 启用状态: {config.enable}")
print(f" 数据目录: {config.data_dir}")
print(f" 感知层最大块数: {config.perceptual_max_blocks}")
print(f" 短期层最大记忆数: {config.short_term_max_memories}")
print(f" 激活阈值: {config.activation_threshold}")
return True
async def test_integration():
"""测试系统集成"""
print("\n" + "=" * 60)
print("测试6: 系统集成")
print("=" * 60)
# 首先需要确保配置启用
from src.config.config import global_config
if not global_config.three_tier_memory or not global_config.three_tier_memory.enable:
print("⚠️ 配置未启用,跳过集成测试")
return False
# 测试单例模式
from src.memory_graph.three_tier.manager_singleton import (
get_unified_memory_manager,
initialize_unified_memory_manager,
)
# 初始化
await initialize_unified_memory_manager()
manager = get_unified_memory_manager()
if manager is None:
print("❌ 统一管理器初始化失败")
return False
print("✅ 单例模式正常工作")
# 测试多次获取
manager2 = get_unified_memory_manager()
if manager is not manager2:
print("❌ 单例模式失败(返回不同实例)")
return False
print("✅ 单例一致性验证通过")
return True
async def run_all_tests():
"""运行所有测试"""
print("\n" + "🔬" * 30)
print("三层记忆系统集成测试")
print("🔬" * 30)
tests = [
("配置系统", test_configuration),
("感知记忆层", test_perceptual_memory),
("短期记忆层", test_short_term_memory),
("长期记忆层", test_long_term_memory),
("统一管理器", test_unified_manager),
("系统集成", test_integration),
]
results = []
for name, test_func in tests:
try:
result = await test_func()
results.append((name, result))
except Exception as e:
print(f"\n❌ 测试 {name} 失败: {e}")
import traceback
traceback.print_exc()
results.append((name, False))
# 打印测试总结
print("\n" + "=" * 60)
print("测试总结")
print("=" * 60)
passed = sum(1 for _, result in results if result)
total = len(results)
for name, result in results:
status = "✅ 通过" if result else "❌ 失败"
print(f"{status} - {name}")
print(f"\n总计: {passed}/{total} 测试通过")
if passed == total:
print("\n🎉 所有测试通过!三层记忆系统工作正常。")
else:
print("\n⚠️ 部分测试失败,请查看上方详细信息。")
return passed == total
if __name__ == "__main__":
success = asyncio.run(run_all_tests())
sys.exit(0 if success else 1)

View File

@@ -1,185 +0,0 @@
"""批量更新数据库导入语句的脚本
将旧的数据库导入路径更新为新的重构后的路径:
- sqlalchemy_models -> core, core.models
- sqlalchemy_database_api -> compatibility
- database.database -> core
"""
import re
from pathlib import Path
# 定义导入映射规则
IMPORT_MAPPINGS = {
# 模型导入
r"from src\.common\.database\.sqlalchemy_models import (.+)":
r"from src.common.database.core.models import \1",
# API导入 - 需要特殊处理
r"from src\.common\.database\.sqlalchemy_database_api import (.+)":
r"from src.common.database.compatibility import \1",
# get_db_session 从 sqlalchemy_database_api 导入
r"from src\.common\.database\.sqlalchemy_database_api import get_db_session":
r"from src.common.database.core import get_db_session",
# get_db_session 从 sqlalchemy_models 导入
r"from src\.common\.database\.sqlalchemy_models import (.*)get_db_session(.*)":
lambda m: f"from src.common.database.core import {m.group(1)}get_db_session{m.group(2)}"
if "get_db_session" in m.group(0) else m.group(0),
# get_engine 导入
r"from src\.common\.database\.sqlalchemy_models import (.*)get_engine(.*)":
lambda m: f"from src.common.database.core import {m.group(1)}get_engine{m.group(2)}",
# Base 导入
r"from src\.common\.database\.sqlalchemy_models import (.*)Base(.*)":
lambda m: f"from src.common.database.core.models import {m.group(1)}Base{m.group(2)}",
# initialize_database 导入
r"from src\.common\.database\.sqlalchemy_models import initialize_database":
r"from src.common.database.core import check_and_migrate_database as initialize_database",
# database.py 导入
r"from src\.common\.database\.database import stop_database":
r"from src.common.database.core import close_engine as stop_database",
r"from src\.common\.database\.database import initialize_sql_database":
r"from src.common.database.core import check_and_migrate_database as initialize_sql_database",
}
# 需要排除的文件
EXCLUDE_PATTERNS = [
"**/database_refactoring_plan.md", # 文档文件
"**/old/**", # 旧文件目录
"**/sqlalchemy_*.py", # 旧的数据库文件本身
"**/database.py", # 旧的database文件
"**/db_*.py", # 旧的db文件
]
def should_exclude(file_path: Path) -> bool:
"""检查文件是否应该被排除"""
for pattern in EXCLUDE_PATTERNS:
if file_path.match(pattern):
return True
return False
def update_imports_in_file(file_path: Path, dry_run: bool = True) -> tuple[int, list[str]]:
"""更新单个文件中的导入语句
Args:
file_path: 文件路径
dry_run: 是否只是预览而不实际修改
Returns:
(修改次数, 修改详情列表)
"""
try:
content = file_path.read_text(encoding="utf-8")
original_content = content
changes = []
# 应用每个映射规则
for pattern, replacement in IMPORT_MAPPINGS.items():
matches = list(re.finditer(pattern, content))
for match in matches:
old_line = match.group(0)
# 处理函数类型的替换
if callable(replacement):
new_line_result = replacement(match)
new_line = new_line_result if isinstance(new_line_result, str) else old_line
else:
new_line = re.sub(pattern, replacement, old_line)
if old_line != new_line and isinstance(new_line, str):
content = content.replace(old_line, new_line, 1)
changes.append(f" - {old_line}")
changes.append(f" + {new_line}")
# 如果有修改且不是dry_run写回文件
if content != original_content:
if not dry_run:
file_path.write_text(content, encoding="utf-8")
return len(changes) // 2, changes
return 0, []
except Exception as e:
print(f"❌ 处理文件 {file_path} 时出错: {e}")
return 0, []
def main():
"""主函数"""
print("🔍 搜索需要更新导入的文件...")
# 获取项目根目录
root_dir = Path(__file__).parent.parent
# 搜索所有Python文件
all_python_files = list(root_dir.rglob("*.py"))
# 过滤掉排除的文件
target_files = [f for f in all_python_files if not should_exclude(f)]
print(f"📊 找到 {len(target_files)} 个Python文件需要检查")
print("\n" + "="*80)
# 第一遍:预览模式
print("\n🔍 预览模式 - 检查需要更新的文件...\n")
files_to_update = []
for file_path in target_files:
count, changes = update_imports_in_file(file_path, dry_run=True)
if count > 0:
files_to_update.append((file_path, count, changes))
if not files_to_update:
print("✅ 没有文件需要更新!")
return
print(f"📝 发现 {len(files_to_update)} 个文件需要更新:\n")
total_changes = 0
for file_path, count, changes in files_to_update:
rel_path = file_path.relative_to(root_dir)
print(f"\n📄 {rel_path} ({count} 处修改)")
for change in changes[:10]: # 最多显示前5对修改
print(change)
if len(changes) > 10:
print(f" ... 还有 {len(changes) - 10}")
total_changes += count
print("\n" + "="*80)
print("\n📊 统计:")
print(f" - 需要更新的文件: {len(files_to_update)}")
print(f" - 总修改次数: {total_changes}")
# 询问是否继续
print("\n" + "="*80)
response = input("\n是否执行更新?(yes/no): ").strip().lower()
if response != "yes":
print("❌ 已取消更新")
return
# 第二遍:实际更新
print("\n✨ 开始更新文件...\n")
success_count = 0
for file_path, _, _ in files_to_update:
count, _ = update_imports_in_file(file_path, dry_run=False)
if count > 0:
rel_path = file_path.relative_to(root_dir)
print(f"{rel_path} ({count} 处修改)")
success_count += 1
print("\n" + "="*80)
print(f"\n🎉 完成!成功更新 {success_count} 个文件")
if __name__ == "__main__":
main()

View File

@@ -238,6 +238,14 @@ class BatchDatabaseWriter:
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=update_data
)
else:
# 默认使用SQLite语法
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
@@ -264,6 +272,14 @@ class BatchDatabaseWriter:
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in update_data.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
from sqlalchemy.dialects.postgresql import insert as pg_insert
stmt = pg_insert(ChatStreams).values(stream_id=stream_id, **update_data)
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=update_data
)
else:
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
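
Note: the change above adds a PostgreSQL upsert branch beside the existing MySQL and SQLite ones. The standalone sketch below is not project code (the chat_streams table here is a simplified stand-in); it only shows how the same upsert is spelled in each SQLAlchemy dialect:

# Minimal, self-contained comparison of dialect-specific upserts in SQLAlchemy.
from sqlalchemy import Column, MetaData, String, Table
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

metadata = MetaData()
chat_streams = Table(
    "chat_streams",
    metadata,
    Column("stream_id", String(64), primary_key=True),
    Column("group_name", String(255)),
)

data = {"stream_id": "abc", "group_name": "demo"}
update_data = {k: v for k, v in data.items() if k != "stream_id"}

# MySQL: INSERT ... ON DUPLICATE KEY UPDATE (conflict target is implicit)
stmt_mysql = mysql_insert(chat_streams).values(**data).on_duplicate_key_update(**update_data)

# PostgreSQL and SQLite: INSERT ... ON CONFLICT (stream_id) DO UPDATE (conflict target is explicit)
stmt_pg = pg_insert(chat_streams).values(**data).on_conflict_do_update(index_elements=["stream_id"], set_=update_data)
stmt_sqlite = sqlite_insert(chat_streams).values(**data).on_conflict_do_update(index_elements=["stream_id"], set_=update_data)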

View File

@@ -4,6 +4,7 @@ import time
from rich.traceback import install
from sqlalchemy.dialects.mysql import insert as mysql_insert
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseGroupInfo, DatabaseUserInfo
@@ -663,6 +664,13 @@ class ChatManager:
stmt = stmt.on_duplicate_key_update(
**{key: value for key, value in fields_to_save.items() if key != "stream_id"}
)
elif global_config.database.database_type == "postgresql":
stmt = pg_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
# PostgreSQL 需要使用 constraint 参数或正确的 index_elements
stmt = stmt.on_conflict_do_update(
index_elements=[ChatStreams.stream_id],
set_=fields_to_save
)
else:
stmt = sqlite_insert(ChatStreams).values(stream_id=s_data_dict["stream_id"], **fields_to_save)
stmt = stmt.on_conflict_do_update(index_elements=["stream_id"], set_=fields_to_save)

View File

@@ -142,8 +142,11 @@ class MessageStorageBatcher:
return None
# 将ORM对象转换为字典(只包含列字段)
# 排除 id 字段,让数据库自动生成(对于 PostgreSQL SERIAL 类型尤其重要)
message_dict = {}
for column in Messages.__table__.columns:
if column.name == "id":
continue # 跳过自增主键,让数据库自动生成
message_dict[column.name] = getattr(message_obj, column.name)
return message_dict

View File

@@ -1143,7 +1143,6 @@ class Prompt:
Returns:
str: 构建好的跨群聊上下文字符串。
"""
logger.info(f"Building cross context with target_user_info: {target_user_info}")
if not global_config.cross_context.enable:
return ""

View File

@@ -169,6 +169,11 @@ class ConnectionPoolManager:
self, session_factory: async_sessionmaker[AsyncSession]
) -> ConnectionInfo | None:
"""获取可复用的连接"""
# 导入方言适配器获取 ping 查询
from src.common.database.core.dialect_adapter import DialectAdapter
ping_query = DialectAdapter.get_ping_query()
async with self._lock:
# 清理过期连接
await self._cleanup_expired_connections_locked()
@@ -178,8 +183,8 @@ class ConnectionPoolManager:
if not connection_info.in_use and not connection_info.is_expired(self.max_lifetime, self.max_idle):
# 验证连接是否仍然有效
try:
# 执行一个简单的查询来验证连接
await connection_info.session.execute(text("SELECT 1"))
# 执行 ping 查询来验证连接
await connection_info.session.execute(text(ping_query))
return connection_info
except Exception as e:
logger.debug(f"连接验证失败,将移除: {e}")

View File

@@ -5,8 +5,22 @@
- 会话管理
- 模型定义
- 数据库迁移
- 方言适配
支持的数据库:
- SQLite (默认)
- MySQL
- PostgreSQL
"""
from .dialect_adapter import (
DatabaseDialect,
DialectAdapter,
DialectConfig,
get_dialect_adapter,
get_indexed_string_field,
get_text_field,
)
from .engine import close_engine, get_engine, get_engine_info
from .migration import check_and_migrate_database, create_all_tables, drop_all_tables
from .models import (
@@ -50,6 +64,10 @@ __all__ = [
"BotPersonalityInterests",
"CacheEntries",
"ChatStreams",
# Dialect Adapter
"DatabaseDialect",
"DialectAdapter",
"DialectConfig",
"Emoji",
"Expression",
"GraphEdges",
@@ -77,10 +95,13 @@ __all__ = [
# Session
"get_db_session",
"get_db_session_direct",
"get_dialect_adapter",
# Engine
"get_engine",
"get_engine_info",
"get_indexed_string_field",
"get_session_factory",
"get_string_field",
"get_text_field",
"reset_session_factory",
]

View File

@@ -0,0 +1,230 @@
"""数据库方言适配器
提供跨数据库兼容性支持,处理不同数据库之间的差异:
- SQLite: 轻量级本地数据库
- MySQL: 高性能关系型数据库
- PostgreSQL: 功能丰富的开源数据库
主要职责:
1. 提供数据库特定的类型映射
2. 处理方言特定的查询语法
3. 提供数据库特定的优化配置
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from sqlalchemy import String, Text
from sqlalchemy.types import TypeEngine
class DatabaseDialect(Enum):
"""数据库方言枚举"""
SQLITE = "sqlite"
MYSQL = "mysql"
POSTGRESQL = "postgresql"
@dataclass
class DialectConfig:
"""方言配置"""
dialect: DatabaseDialect
# 连接验证查询
ping_query: str
# 是否支持 RETURNING 子句
supports_returning: bool
# 是否支持原生 JSON 类型
supports_native_json: bool
# 是否支持数组类型
supports_arrays: bool
# 是否需要指定字符串长度用于索引
requires_length_for_index: bool
# 默认字符串长度(用于索引列)
default_string_length: int
# 事务隔离级别
isolation_level: str
# 额外的引擎参数
engine_kwargs: dict[str, Any] = field(default_factory=dict)
# 预定义的方言配置
DIALECT_CONFIGS: dict[DatabaseDialect, DialectConfig] = {
DatabaseDialect.SQLITE: DialectConfig(
dialect=DatabaseDialect.SQLITE,
ping_query="SELECT 1",
supports_returning=True, # SQLite 3.35+ 支持
supports_native_json=False,
supports_arrays=False,
requires_length_for_index=False,
default_string_length=255,
isolation_level="SERIALIZABLE",
engine_kwargs={
"connect_args": {
"check_same_thread": False,
"timeout": 60,
}
},
),
DatabaseDialect.MYSQL: DialectConfig(
dialect=DatabaseDialect.MYSQL,
ping_query="SELECT 1",
supports_returning=False, # MySQL 8.0.21+ 有限支持
supports_native_json=True, # MySQL 5.7+
supports_arrays=False,
requires_length_for_index=True, # MySQL 索引需要指定长度
default_string_length=255,
isolation_level="READ COMMITTED",
engine_kwargs={
"pool_pre_ping": True,
"pool_recycle": 3600,
},
),
DatabaseDialect.POSTGRESQL: DialectConfig(
dialect=DatabaseDialect.POSTGRESQL,
ping_query="SELECT 1",
supports_returning=True,
supports_native_json=True,
supports_arrays=True,
requires_length_for_index=False,
default_string_length=255,
isolation_level="READ COMMITTED",
engine_kwargs={
"pool_pre_ping": True,
"pool_recycle": 3600,
},
),
}
class DialectAdapter:
"""数据库方言适配器
根据当前配置的数据库类型,提供相应的类型映射和查询支持
"""
_current_dialect: DatabaseDialect | None = None
_config: DialectConfig | None = None
@classmethod
def initialize(cls, db_type: str) -> None:
"""初始化适配器
Args:
db_type: 数据库类型字符串 ("sqlite", "mysql", "postgresql")
"""
try:
cls._current_dialect = DatabaseDialect(db_type.lower())
cls._config = DIALECT_CONFIGS[cls._current_dialect]
except ValueError:
raise ValueError(f"不支持的数据库类型: {db_type},支持的类型: sqlite, mysql, postgresql")
@classmethod
def get_dialect(cls) -> DatabaseDialect:
"""获取当前数据库方言"""
if cls._current_dialect is None:
# 延迟初始化:从配置获取
from src.config.config import global_config
if global_config is None:
raise RuntimeError("配置尚未初始化,无法获取数据库方言")
cls.initialize(global_config.database.database_type)
return cls._current_dialect # type: ignore
@classmethod
def get_config(cls) -> DialectConfig:
"""获取当前方言配置"""
if cls._config is None:
cls.get_dialect() # 触发初始化
return cls._config # type: ignore
@classmethod
def get_string_type(cls, max_length: int = 255, indexed: bool = False) -> TypeEngine:
"""获取适合当前数据库的字符串类型
Args:
max_length: 最大长度
indexed: 是否用于索引
Returns:
SQLAlchemy 类型
"""
config = cls.get_config()
# MySQL 索引列需要指定长度
if config.requires_length_for_index and indexed:
return String(max_length)
# SQLite 和 PostgreSQL 可以使用 Text
if config.dialect in (DatabaseDialect.SQLITE, DatabaseDialect.POSTGRESQL):
return Text() if not indexed else String(max_length)
# MySQL 使用 VARCHAR
return String(max_length)
@classmethod
def get_ping_query(cls) -> str:
"""获取连接验证查询"""
return cls.get_config().ping_query
@classmethod
def supports_returning(cls) -> bool:
"""是否支持 RETURNING 子句"""
return cls.get_config().supports_returning
@classmethod
def supports_native_json(cls) -> bool:
"""是否支持原生 JSON 类型"""
return cls.get_config().supports_native_json
@classmethod
def get_engine_kwargs(cls) -> dict[str, Any]:
"""获取引擎额外参数"""
return cls.get_config().engine_kwargs.copy()
@classmethod
def is_sqlite(cls) -> bool:
"""是否为 SQLite"""
return cls.get_dialect() == DatabaseDialect.SQLITE
@classmethod
def is_mysql(cls) -> bool:
"""是否为 MySQL"""
return cls.get_dialect() == DatabaseDialect.MYSQL
@classmethod
def is_postgresql(cls) -> bool:
"""是否为 PostgreSQL"""
return cls.get_dialect() == DatabaseDialect.POSTGRESQL
def get_dialect_adapter() -> type[DialectAdapter]:
"""获取方言适配器类"""
return DialectAdapter
def get_indexed_string_field(max_length: int = 255) -> TypeEngine:
"""获取用于索引的字符串字段类型
这是一个便捷函数,用于在模型定义中获取适合当前数据库的字符串类型
Args:
max_length: 最大长度(对于 MySQL 是必需的)
Returns:
SQLAlchemy 类型
"""
return DialectAdapter.get_string_type(max_length, indexed=True)
def get_text_field() -> TypeEngine:
"""获取文本字段类型
用于不需要索引的大文本字段
Returns:
SQLAlchemy Text 类型
"""
return Text()
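Below is a minimal usage sketch of the adapter, not part of the diff itself; the import path is inferred from the engine module's relative import and may differ in the actual tree.

# Hypothetical usage sketch (import path assumed, not taken from this diff)
from src.common.database.core.dialect_adapter import DialectAdapter, get_indexed_string_field

DialectAdapter.initialize("postgresql")        # explicit init; get_dialect() would otherwise lazy-load from global_config
print(DialectAdapter.supports_returning())     # True for PostgreSQL
print(DialectAdapter.get_ping_query())         # "SELECT 1"
print(get_indexed_string_field(64))            # VARCHAR(64) — indexed fields get a length-bounded String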

View File

@@ -1,6 +1,11 @@
"""数据库引擎管理
单一职责:创建和管理SQLAlchemy异步引擎
支持的数据库类型:
- SQLite: 轻量级本地数据库,使用 aiosqlite 驱动
- MySQL: 高性能关系型数据库,使用 aiomysql 驱动
- PostgreSQL: 功能丰富的开源数据库,使用 asyncpg 驱动
"""
import asyncio
@@ -13,6 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine
from src.common.logger import get_logger
from ..utils.exceptions import DatabaseInitializationError
from .dialect_adapter import DialectAdapter
logger = get_logger("database.engine")
@@ -52,79 +58,27 @@ async def get_engine() -> AsyncEngine:
config = global_config.database
db_type = config.database_type
# 初始化方言适配器
DialectAdapter.initialize(db_type)
logger.info(f"正在初始化 {db_type.upper()} 数据库引擎...")
# 构建数据库URL和引擎参数
# 根据数据库类型构建URL和引擎参数
if db_type == "mysql":
# MySQL配置
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
)
else:
# TCP连接
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
"pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout,
"pool_recycle": 3600,
"pool_pre_ping": True,
"connect_args": {
"autocommit": config.mysql_autocommit,
"charset": config.mysql_charset,
"connect_timeout": config.connection_timeout,
},
}
logger.info(
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
url, engine_kwargs = _build_mysql_config(config)
elif db_type == "postgresql":
url, engine_kwargs = _build_postgresql_config(config)
else:
# SQLite配置
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = f"sqlite+aiosqlite:///{db_path}"
engine_kwargs = {
"echo": False,
"future": True,
"connect_args": {
"check_same_thread": False,
"timeout": 60,
},
}
logger.info(f"SQLite配置: {db_path}")
url, engine_kwargs = _build_sqlite_config(config)
# 创建异步引擎
_engine = create_async_engine(url, **engine_kwargs)
# SQLite特定优化
# 数据库特定优化
if db_type == "sqlite":
await _enable_sqlite_optimizations(_engine)
elif db_type == "postgresql":
await _enable_postgresql_optimizations(_engine)
logger.info(f"{db_type.upper()} 数据库引擎初始化成功")
return _engine
@@ -134,6 +88,141 @@ async def get_engine() -> AsyncEngine:
raise DatabaseInitializationError(f"引擎初始化失败: {e}") from e
def _build_sqlite_config(config) -> tuple[str, dict]:
"""构建 SQLite 配置
Args:
config: 数据库配置对象
Returns:
(url, engine_kwargs) 元组
"""
if not os.path.isabs(config.sqlite_path):
ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "..", ".."))
db_path = os.path.join(ROOT_PATH, config.sqlite_path)
else:
db_path = config.sqlite_path
# 确保数据库目录存在
os.makedirs(os.path.dirname(db_path), exist_ok=True)
url = f"sqlite+aiosqlite:///{db_path}"
engine_kwargs = {
"echo": False,
"future": True,
"connect_args": {
"check_same_thread": False,
"timeout": 60,
},
}
logger.info(f"SQLite配置: {db_path}")
return url, engine_kwargs
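As a quick sanity check of the path handling above, here is an illustrative sketch; the project root and relative path are example values only, not taken from a real deployment.

# Illustrative sketch of how a relative sqlite_path becomes the aiosqlite URL
import os

sqlite_path = "data/MaiBot.db"              # example value from the config template
root_path = "/opt/MoFox-Bot"                # hypothetical project root
db_path = sqlite_path if os.path.isabs(sqlite_path) else os.path.join(root_path, sqlite_path)
print(f"sqlite+aiosqlite:///{db_path}")     # sqlite+aiosqlite:////opt/MoFox-Bot/data/MaiBot.db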
def _build_mysql_config(config) -> tuple[str, dict]:
"""构建 MySQL 配置
Args:
config: 数据库配置对象
Returns:
(url, engine_kwargs) 元组
"""
encoded_user = quote_plus(config.mysql_user)
encoded_password = quote_plus(config.mysql_password)
if config.mysql_unix_socket:
# Unix socket连接
encoded_socket = quote_plus(config.mysql_unix_socket)
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@/{config.mysql_database}"
f"?unix_socket={encoded_socket}&charset={config.mysql_charset}"
)
else:
# TCP连接
url = (
f"mysql+aiomysql://{encoded_user}:{encoded_password}"
f"@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
f"?charset={config.mysql_charset}"
)
engine_kwargs = {
"echo": False,
"future": True,
"pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout,
"pool_recycle": 3600,
"pool_pre_ping": True,
"connect_args": {
"autocommit": config.mysql_autocommit,
"charset": config.mysql_charset,
"connect_timeout": config.connection_timeout,
},
}
logger.info(
f"MySQL配置: {config.mysql_user}@{config.mysql_host}:{config.mysql_port}/{config.mysql_database}"
)
return url, engine_kwargs
def _build_postgresql_config(config) -> tuple[str, dict]:
"""构建 PostgreSQL 配置
Args:
config: 数据库配置对象
Returns:
(url, engine_kwargs) 元组
"""
encoded_user = quote_plus(config.postgresql_user)
encoded_password = quote_plus(config.postgresql_password)
# 构建基本 URL
url = (
f"postgresql+asyncpg://{encoded_user}:{encoded_password}"
f"@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
)
# SSL 配置asyncpg 不接受 ssl_ca/ssl_cert/ssl_key 关键字参数,证书需通过 SSLContext 传入)
connect_args = {}
if config.postgresql_ssl_mode != "disable":
if config.postgresql_ssl_ca or config.postgresql_ssl_cert:
# 提供了证书时,构建 SSLContext 并交给 asyncpg 的 ssl 参数
import ssl
ssl_context = ssl.create_default_context(cafile=config.postgresql_ssl_ca or None)
if config.postgresql_ssl_cert:
ssl_context.load_cert_chain(config.postgresql_ssl_cert, config.postgresql_ssl_key or None)
connect_args["ssl"] = ssl_context
else:
# 仅指定模式时asyncpg 支持直接传入 "prefer"、"require" 等字符串
connect_args["ssl"] = config.postgresql_ssl_mode
# 设置 schema如果不是 public
if config.postgresql_schema and config.postgresql_schema != "public":
connect_args["server_settings"] = {"search_path": config.postgresql_schema}
engine_kwargs = {
"echo": False,
"future": True,
"pool_size": config.connection_pool_size,
"max_overflow": config.connection_pool_size * 2,
"pool_timeout": config.connection_timeout,
"pool_recycle": 3600,
"pool_pre_ping": True,
}
if connect_args:
engine_kwargs["connect_args"] = connect_args
logger.info(
f"PostgreSQL配置: {config.postgresql_user}@{config.postgresql_host}:{config.postgresql_port}/{config.postgresql_database}"
)
return url, engine_kwargs
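For reference, a hypothetical configuration would produce a URL like the following; note that quote_plus escapes special characters in the credentials. The values below are examples only, not defaults from this diff.

# Illustrative only: the asyncpg URL produced for example credentials
from urllib.parse import quote_plus

user, password = "maibot", "p@ss/word"
host, port, database = "db.internal", 5432, "maibot"
url = f"postgresql+asyncpg://{quote_plus(user)}:{quote_plus(password)}@{host}:{port}/{database}"
print(url)  # postgresql+asyncpg://maibot:p%40ss%2Fword@db.internal:5432/maibot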
async def close_engine():
"""关闭数据库引擎
@@ -181,6 +270,33 @@ async def _enable_sqlite_optimizations(engine: AsyncEngine):
logger.warning(f"⚠️ SQLite性能优化失败: {e},将使用默认配置")
async def _enable_postgresql_optimizations(engine: AsyncEngine):
"""启用PostgreSQL性能优化
优化项:
- 设置合适的 work_mem
- 启用 JIT 编译(如果可用)
- 设置合适的 statement_timeout
Args:
engine: SQLAlchemy异步引擎
"""
try:
async with engine.begin() as conn:
# 设置会话级别的参数
# work_mem: 排序和哈希操作的内存64MB
await conn.execute(text("SET work_mem = '64MB'"))
# 设置语句超时5分钟
await conn.execute(text("SET statement_timeout = '300000'"))
# 启用自动 EXPLAIN可选用于调试需要服务器已加载 auto_explain 扩展)
# await conn.execute(text("SET auto_explain.log_min_duration = '1000'"))
logger.info("✅ PostgreSQL性能优化已启用")
except Exception as e:
logger.warning(f"⚠️ PostgreSQL性能优化失败: {e},将使用默认配置")
async def get_engine_info() -> dict:
"""获取引擎信息(用于监控和调试)

View File

@@ -99,12 +99,17 @@ async def check_and_migrate_database(existing_engine=None):
def add_columns_sync(conn):
dialect = conn.dialect
compiler = dialect.ddl_compiler(dialect, None)
for column_name in missing_columns:
column = table.c[column_name]
column_type = compiler.get_column_specification(column)
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type}"
# 使用 compile 方法获取当前方言下列类型的 SQL 字符串
column_type_sql = column.type.compile(dialect=dialect)
# 构建 ALTER TABLE 语句
sql = f"ALTER TABLE {table.name} ADD COLUMN {column.name} {column_type_sql}"
if column.default:
# 手动处理不同方言的默认值
@@ -114,26 +119,18 @@ async def check_and_migrate_database(existing_engine=None):
):
# SQLite 将布尔值存储为 0 或 1
default_value = "1" if default_arg else "0"
elif hasattr(compiler, "render_literal_value"):
try:
# 尝试使用 render_literal_value
default_value = compiler.render_literal_value(
default_arg, column.type
)
except AttributeError:
# 如果失败,则回退到简单的字符串转换
default_value = (
f"'{default_arg}'"
if isinstance(default_arg, str)
else str(default_arg)
)
elif dialect.name == "mysql" and isinstance(default_arg, bool):
# MySQL 也使用 1/0 表示布尔值
default_value = "1" if default_arg else "0"
elif isinstance(default_arg, bool):
# PostgreSQL 使用 TRUE/FALSE
default_value = "TRUE" if default_arg else "FALSE"
elif isinstance(default_arg, str):
default_value = f"'{default_arg}'"
elif default_arg is None:
default_value = "NULL"
else:
# 对于没有 render_literal_value 的旧版或特定方言
default_value = (
f"'{default_arg}'"
if isinstance(default_arg, str)
else str(default_arg)
)
default_value = str(default_arg)
sql += f" DEFAULT {default_value}"

View File

@@ -3,6 +3,11 @@
本文件只包含纯模型定义使用SQLAlchemy 2.0的Mapped类型注解风格。
引擎和会话管理已移至core/engine.py和core/session.py。
支持的数据库类型:
- SQLite: 使用 Text 类型
- MySQL: 使用 VARCHAR(max_length) 用于索引字段
- PostgreSQL: 使用 Text 类型PostgreSQL 的 Text 类型性能与 VARCHAR 相当)
所有模型使用统一的类型注解风格:
field_name: Mapped[PyType] = mapped_column(Type, ...)
@@ -20,16 +25,34 @@ from sqlalchemy.orm import Mapped, mapped_column
Base = declarative_base()
# MySQL兼容的字段类型辅助函数
# 数据库兼容的字段类型辅助函数
def get_string_field(max_length=255, **kwargs):
"""
根据数据库类型返回合适的字符串字段
MySQL需要指定长度的VARCHAR用于索引SQLite可以使用Text
根据数据库类型返回合适的字符串字段类型
对于需要索引的字段:
- MySQL: 必须使用 VARCHAR(max_length),因为索引需要指定长度
- PostgreSQL: 可以使用 Text但为了兼容性使用 VARCHAR
- SQLite: 可以使用 Text无长度限制
Args:
max_length: 最大长度(对于 MySQL 是必需的)
**kwargs: 传递给 String/Text 的额外参数
Returns:
SQLAlchemy 类型
"""
from src.config.config import global_config
if global_config.database.database_type == "mysql":
db_type = global_config.database.database_type
# MySQL 索引需要指定长度的 VARCHAR
if db_type == "mysql":
return String(max_length, **kwargs)
# PostgreSQL 可以使用 Text但为了跨数据库迁移兼容性使用 VARCHAR
elif db_type == "postgresql":
return String(max_length, **kwargs)
# SQLite 使用 Text无长度限制
else:
return Text(**kwargs)
@@ -477,7 +500,7 @@ class BanUser(Base):
__tablename__ = "ban_users"
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
platform: Mapped[str] = mapped_column(Text, nullable=False)
platform: Mapped[str] = mapped_column(get_string_field(50), nullable=False) # 使用有限长度,以便创建索引
user_id: Mapped[str] = mapped_column(get_string_field(50), nullable=False, index=True)
violation_num: Mapped[int] = mapped_column(Integer, nullable=False, default=0, index=True)
reason: Mapped[str] = mapped_column(Text, nullable=False)

View File

@@ -1,6 +1,11 @@
"""数据库会话管理
单一职责:提供数据库会话工厂和上下文管理器
支持的数据库类型:
- SQLite: 设置 PRAGMA 参数优化并发
- MySQL: 无特殊会话设置
- PostgreSQL: 可选设置 schema 搜索路径
"""
import asyncio
@@ -53,12 +58,43 @@ async def get_session_factory() -> async_sessionmaker:
return _session_factory
async def _apply_session_settings(session: AsyncSession, db_type: str) -> None:
"""应用数据库特定的会话设置
Args:
session: 数据库会话
db_type: 数据库类型
"""
try:
if db_type == "sqlite":
# SQLite 特定的 PRAGMA 设置
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
elif db_type == "postgresql":
# PostgreSQL 特定设置(如果需要)
# 可以设置 schema 搜索路径等
from src.config.config import global_config
schema = global_config.database.postgresql_schema
if schema and schema != "public":
await session.execute(text(f"SET search_path TO {schema}"))
# MySQL 通常不需要会话级别的特殊设置
except Exception:
# 复用连接时设置可能已存在,忽略错误
pass
@asynccontextmanager
async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
"""获取数据库会话上下文管理器
这是数据库操作的主要入口点,通过连接池管理器提供透明的连接复用。
支持的数据库:
- SQLite: 自动设置 busy_timeout 和外键约束
- MySQL: 直接使用,无特殊设置
- PostgreSQL: 支持自定义 schema
使用示例:
async with get_db_session() as session:
result = await session.execute(select(User))
@@ -75,16 +111,10 @@ async def get_db_session() -> AsyncGenerator[AsyncSession, None]:
# 使用连接池管理器(透明复用连接)
async with pool_manager.get_session(session_factory) as session:
# 为SQLite设置特定的PRAGMA
# 获取数据库类型并应用特定设置
from src.config.config import global_config
if global_config.database.database_type == "sqlite":
try:
await session.execute(text("PRAGMA busy_timeout = 60000"))
await session.execute(text("PRAGMA foreign_keys = ON"))
except Exception:
# 复用连接时PRAGMA可能已设置忽略错误
pass
await _apply_session_settings(session, global_config.database.database_type)
yield session
@@ -103,6 +133,11 @@ async def get_db_session_direct() -> AsyncGenerator[AsyncSession, None]:
async with session_factory() as session:
try:
# 应用数据库特定设置
from src.config.config import global_config
await _apply_session_settings(session, global_config.database.database_type)
yield session
except Exception:
await session.rollback()

View File

@@ -373,8 +373,14 @@ class AdaptiveBatchScheduler:
"""批量执行插入操作"""
async with get_db_session() as session:
try:
# 收集数据
all_data = [op.data for op in operations if op.data]
# 收集数据,并过滤掉 id=None 的情况(让数据库自动生成)
all_data = []
for op in operations:
if op.data:
# 过滤掉 id 为 None 的键,让数据库自动生成主键
filtered_data = {k: v for k, v in op.data.items() if not (k == "id" and v is None)}
all_data.append(filtered_data)
if not all_data:
return
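The id filter above can be checked in isolation; the rows below are hypothetical payloads, not real operation data.

# Illustrative: rows with id=None lose the key so the database assigns the primary key
rows = [
    {"id": None, "chat_id": "g1", "style": "casual"},
    {"id": 42, "chat_id": "g2", "style": "formal"},
]
filtered = [{k: v for k, v in row.items() if not (k == "id" and v is None)} for row in rows]
print(filtered)  # [{'chat_id': 'g1', 'style': 'casual'}, {'id': 42, 'chat_id': 'g2', 'style': 'formal'}]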

View File

@@ -16,8 +16,10 @@ from src.config.config_base import ValidatedConfigBase
class DatabaseConfig(ValidatedConfigBase):
"""数据库配置类"""
database_type: Literal["sqlite", "mysql"] = Field(default="sqlite", description="数据库类型")
database_type: Literal["sqlite", "mysql", "postgresql"] = Field(default="sqlite", description="数据库类型")
sqlite_path: str = Field(default="data/MaiBot.db", description="SQLite数据库文件路径")
# MySQL 配置
mysql_host: str = Field(default="localhost", description="MySQL服务器地址")
mysql_port: int = Field(default=3306, ge=1, le=65535, description="MySQL服务器端口")
mysql_database: str = Field(default="maibot", description="MySQL数据库名")
@@ -33,6 +35,22 @@ class DatabaseConfig(ValidatedConfigBase):
mysql_ssl_key: str = Field(default="", description="SSL密钥路径")
mysql_autocommit: bool = Field(default=True, description="自动提交事务")
mysql_sql_mode: str = Field(default="TRADITIONAL", description="SQL模式")
# PostgreSQL 配置
postgresql_host: str = Field(default="localhost", description="PostgreSQL服务器地址")
postgresql_port: int = Field(default=5432, ge=1, le=65535, description="PostgreSQL服务器端口")
postgresql_database: str = Field(default="maibot", description="PostgreSQL数据库名")
postgresql_user: str = Field(default="postgres", description="PostgreSQL用户名")
postgresql_password: str = Field(default="", description="PostgreSQL密码")
postgresql_schema: str = Field(default="public", description="PostgreSQL模式名")
postgresql_ssl_mode: Literal["disable", "allow", "prefer", "require", "verify-ca", "verify-full"] = Field(
default="prefer", description="PostgreSQL SSL模式"
)
postgresql_ssl_ca: str = Field(default="", description="PostgreSQL SSL CA证书路径")
postgresql_ssl_cert: str = Field(default="", description="PostgreSQL SSL客户端证书路径")
postgresql_ssl_key: str = Field(default="", description="PostgreSQL SSL密钥路径")
# 通用连接池配置
connection_pool_size: int = Field(default=10, ge=1, description="连接池大小")
connection_timeout: int = Field(default=10, ge=1, description="连接超时时间")
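Assuming ValidatedConfigBase behaves like a standard pydantic model, a PostgreSQL configuration could be constructed as sketched below; the host, user, and password are illustrative values only.

# Hypothetical construction of the config above (values are examples only)
cfg = DatabaseConfig(
    database_type="postgresql",
    postgresql_host="db.internal",
    postgresql_user="maibot",
    postgresql_password="secret",
    postgresql_ssl_mode="require",
)
assert cfg.postgresql_port == 5432 and cfg.connection_pool_size == 10  # defaults still apply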

View File

@@ -1,5 +1,5 @@
[inner]
version = "7.8.3"
version = "7.9.0"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值
@@ -12,7 +12,7 @@ version = "7.8.3"
#----以上是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
[database]# 数据库配置
database_type = "sqlite" # 数据库类型,支持 "sqlite" 或 "mysql"
database_type = "sqlite" # 数据库类型,支持 "sqlite"、"mysql" 或 "postgresql"
# SQLite 配置(当 database_type = "sqlite" 时使用)
sqlite_path = "data/MaiBot.db" # SQLite数据库文件路径
@@ -36,8 +36,22 @@ mysql_ssl_key = "" # SSL客户端密钥路径
mysql_autocommit = true # 自动提交事务
mysql_sql_mode = "TRADITIONAL" # SQL模式
# 连接池配置
connection_pool_size = 10 # 连接池大小仅MySQL有效
# PostgreSQL 配置(当 database_type = "postgresql" 时使用)
postgresql_host = "localhost" # PostgreSQL服务器地址
postgresql_port = 5432 # PostgreSQL服务器端口
postgresql_database = "maibot" # PostgreSQL数据库名
postgresql_user = "postgres" # PostgreSQL用户名
postgresql_password = "" # PostgreSQL密码
postgresql_schema = "public" # PostgreSQL模式名schema
# PostgreSQL SSL 配置
postgresql_ssl_mode = "prefer" # SSL模式: disable, allow, prefer, require, verify-ca, verify-full
postgresql_ssl_ca = "" # SSL CA证书路径
postgresql_ssl_cert = "" # SSL客户端证书路径
postgresql_ssl_key = "" # SSL客户端密钥路径
# 连接池配置MySQL 和 PostgreSQL 有效)
connection_pool_size = 10 # 连接池大小
connection_timeout = 10 # 连接超时时间(秒)
# 批量动作记录存储配置