明天好像没什么
2025-11-07 21:01:45 +08:00
parent 80b040da2f
commit c8d7c09625
49 changed files with 854 additions and 872 deletions

View File

@@ -17,19 +17,19 @@
import argparse
import json
-from pathlib import Path
-from typing import Dict, Any, List, Tuple
import logging
+from pathlib import Path
+from typing import Any
+import orjson
# 配置日志
logging.basicConfig(
level=logging.INFO,
-format='%(asctime)s - %(levelname)s - %(message)s',
+format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[
logging.StreamHandler(),
-logging.FileHandler('embedding_cleanup.log', encoding='utf-8')
+logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
]
)
logger = logging.getLogger(__name__)
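The import hunk above swaps `typing.Dict`/`List`/`Tuple` for builtin generics (PEP 585) and standardizes on double quotes. A minimal sketch of the annotation migration, assuming Python 3.9+ (function names here are illustrative, not from the repo):

```python
from pathlib import Path
from typing import Any

# Before (the removed style): from typing import Dict, List, Tuple
#   def find_json_files(...) -> List[Path]: ...
#   def clean(data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: ...

def find_json_files(root: Path) -> list[Path]:
    """list/dict/tuple are directly subscriptable since Python 3.9."""
    return sorted(root.glob("**/*.json"))

def clean(data: dict[str, Any]) -> tuple[dict[str, Any], int]:
    removed = 1 if data.pop("embedding", None) is not None else 0
    return data, removed
```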
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
self.cleaned_files = []
self.errors = []
self.stats = {
-'files_processed': 0,
-'embedings_removed': 0,
-'bytes_saved': 0,
-'nodes_processed': 0
+"files_processed": 0,
+"embedings_removed": 0,
+"bytes_saved": 0,
+"nodes_processed": 0
}
-def find_json_files(self) -> List[Path]:
+def find_json_files(self) -> list[Path]:
"""查找可能包含向量数据的 JSON 文件"""
json_files = []
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
json_files.append(memory_graph_file)
# 测试数据文件
-test_dir = self.data_dir / "test_*"
+self.data_dir / "test_*"
for test_path in self.data_dir.glob("test_*/memory_graph.json"):
if test_path.exists():
json_files.append(test_path)
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
return json_files
-def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int:
+def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
"""
分析数据中的 embedding 字段数量
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
def count_embeddings(obj):
nonlocal embedding_count
if isinstance(obj, dict):
-if 'embedding' in obj:
+if "embedding" in obj:
embedding_count += 1
for value in obj.values():
count_embeddings(value)
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
count_embeddings(data)
return embedding_count
-def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
+def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
"""
从数据中移除 embedding 字段
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
def remove_embeddings(obj):
nonlocal removed_count
if isinstance(obj, dict):
-if 'embedding' in obj:
-del obj['embedding']
+if "embedding" in obj:
+del obj["embedding"]
removed_count += 1
for value in obj.values():
remove_embeddings(value)
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
data = orjson.loads(original_content)
except orjson.JSONDecodeError:
# 回退到标准 json
-with open(file_path, 'r', encoding='utf-8') as f:
+with open(file_path, encoding="utf-8") as f:
data = json.load(f)
# 分析 embedding 数据
embedding_count = self.analyze_embedding_in_data(data)
if embedding_count == 0:
logger.info(f" ✓ 文件中没有 embedding 数据,跳过")
logger.info(" ✓ 文件中没有 embedding 数据,跳过")
return True
logger.info(f" 发现 {embedding_count} 个 embedding 字段")
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
cleaned_data,
indent=2,
ensure_ascii=False
-).encode('utf-8')
+).encode("utf-8")
cleaned_size = len(cleaned_content)
bytes_saved = original_size - cleaned_size
# 原子写入
-temp_file = file_path.with_suffix('.tmp')
+temp_file = file_path.with_suffix(".tmp")
temp_file.write_bytes(cleaned_content)
temp_file.replace(file_path)
logger.info(f" ✓ 清理完成:")
logger.info(" ✓ 清理完成:")
logger.info(f" - 移除 embedding 字段: {removed_count}")
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
# 更新统计
-self.stats['embedings_removed'] += removed_count
-self.stats['bytes_saved'] += bytes_saved
+self.stats["embedings_removed"] += removed_count
+self.stats["bytes_saved"] += bytes_saved
else:
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
-self.stats['embedings_removed'] += embedding_count
+self.stats["embedings_removed"] += embedding_count
-self.stats['files_processed'] += 1
+self.stats["files_processed"] += 1
self.cleaned_files.append(file_path)
return True
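The write path above serializes to bytes first, writes to a `.tmp` sibling, then renames over the original. A standard-library-only sketch of the same atomic-replace pattern (the helper name is hypothetical):

```python
import json
from pathlib import Path

def write_json_atomically(path: Path, data: dict) -> None:
    # Serialize fully before touching disk; a failure here leaves the target untouched.
    payload = json.dumps(data, indent=2, ensure_ascii=False).encode("utf-8")
    tmp = path.with_suffix(".tmp")
    tmp.write_bytes(payload)
    # Path.replace() maps to os.replace(), an atomic rename on the same filesystem:
    # readers observe either the old file or the new one, never a partial write.
    tmp.replace(path)
```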
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
节点数量
"""
try:
-with open(file_path, 'r', encoding='utf-8') as f:
+with open(file_path, encoding="utf-8") as f:
data = json.load(f)
node_count = 0
-if 'nodes' in data and isinstance(data['nodes'], list):
-node_count = len(data['nodes'])
+if "nodes" in data and isinstance(data["nodes"], list):
+node_count = len(data["nodes"])
return node_count
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
# 统计总节点数
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
-self.stats['nodes_processed'] = total_nodes
+self.stats["nodes_processed"] = total_nodes
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
if not dry_run:
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
-if self.stats['bytes_saved'] > 0:
-mb_saved = self.stats['bytes_saved'] / 1024 / 1024
+if self.stats["bytes_saved"] > 0:
+mb_saved = self.stats["bytes_saved"] / 1024 / 1024
logger.info(f"节省空间: {mb_saved:.2f} MB")
if self.errors:
@@ -342,7 +342,7 @@ def main():
print(" 请确保向量数据库正在正常工作。")
print()
response = input("确认继续?(yes/no): ")
-if response.lower() not in ['yes', 'y', '']:
+if response.lower() not in ["yes", "y", ""]:
print("操作已取消")
return
@@ -352,4 +352,4 @@ def main():
if __name__ == "__main__":
-main()
+main()
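Condensed, the cleaner's core is a recursive walk that deletes one key from arbitrarily nested dicts and lists. A self-contained sketch of that traversal (not the repo's exact code):

```python
from typing import Any

def strip_key(obj: Any, key: str = "embedding") -> int:
    """Remove `key` everywhere in a nested structure; return how many were removed."""
    removed = 0
    if isinstance(obj, dict):
        if key in obj:
            del obj[key]
            removed += 1
        for value in obj.values():
            removed += strip_key(value, key)
    elif isinstance(obj, list):
        for item in obj:
            removed += strip_key(item, key)
    return removed

graph = {"nodes": [{"id": 1, "embedding": [0.1] * 768}, {"id": 2}]}
assert strip_key(graph) == 1 and "embedding" not in graph["nodes"][0]
```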

View File

@@ -10,10 +10,10 @@
示例:
# 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15
"""
@@ -22,7 +22,6 @@ import argparse
import asyncio
import gc
import json
-import os
import subprocess
import sys
import threading
@@ -30,7 +29,6 @@ import time
from collections import defaultdict
from datetime import datetime
from pathlib import Path
-from typing import Dict, List, Optional
import psutil
@@ -56,29 +54,29 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
if bot_process.pid is None:
print("❌ Bot 进程 PID 为空")
return
print(f"🔍 开始监控 Bot 内存PID: {bot_process.pid}")
print(f"监控间隔: {interval}")
print("按 Ctrl+C 停止监控和 Bot\n")
try:
process = psutil.Process(bot_process.pid)
except psutil.NoSuchProcess:
print("❌ 无法找到 Bot 进程")
return
history = []
iteration = 0
try:
while bot_process.poll() is None:
try:
mem_info = process.memory_info()
mem_percent = process.memory_percent()
children = process.children(recursive=True)
children_mem = sum(child.memory_info().rss for child in children)
info = {
"timestamp": time.strftime("%H:%M:%S"),
"rss_mb": mem_info.rss / 1024 / 1024,
@@ -87,24 +85,24 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
"children_count": len(children),
"children_mem_mb": children_mem / 1024 / 1024,
}
history.append(info)
iteration += 1
print(f"{'=' * 80}")
print(f"检查点 #{iteration} - {info['timestamp']}")
print(f"Bot 进程 (PID: {bot_process.pid})")
print(f" RSS: {info['rss_mb']:.2f} MB")
print(f" VMS: {info['vms_mb']:.2f} MB")
print(f" 占比: {info['percent']:.2f}%")
if children:
print(f" 子进程: {info['children_count']}")
print(f" 子进程内存: {info['children_mem_mb']:.2f} MB")
-total_mem = info['rss_mb'] + info['children_mem_mb']
+total_mem = info["rss_mb"] + info["children_mem_mb"]
print(f" 总内存: {total_mem:.2f} MB")
print(f"\n 📋 子进程详情:")
print("\n 📋 子进程详情:")
for idx, child in enumerate(children, 1):
try:
child_mem = child.memory_info().rss / 1024 / 1024
@@ -116,30 +114,30 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
print(f" 命令: {child_cmdline}")
except (psutil.NoSuchProcess, psutil.AccessDenied):
print(f" [{idx}] 无法访问进程信息")
if len(history) > 1:
prev = history[-2]
-rss_diff = info['rss_mb'] - prev['rss_mb']
-print(f"\n变化:")
+rss_diff = info["rss_mb"] - prev["rss_mb"]
+print("\n变化:")
print(f" RSS: {rss_diff:+.2f} MB")
if rss_diff > 10:
print(f" ⚠️ 内存增长较快!")
if info['rss_mb'] > 1000:
print(f" ⚠️ 内存使用超过 1GB")
print(" ⚠️ 内存增长较快!")
if info["rss_mb"] > 1000:
print(" ⚠️ 内存使用超过 1GB")
print(f"{'=' * 80}\n")
await asyncio.sleep(interval)
except psutil.NoSuchProcess:
print("\n❌ Bot 进程已结束")
break
except Exception as e:
print(f"\n❌ 监控出错: {e}")
break
except KeyboardInterrupt:
print("\n\n⚠️ 用户中断监控")
finally:
if history and bot_process.pid:
save_process_history(history, bot_process.pid)
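The monitor loop above boils down to periodic psutil sampling of the bot process and its children. A minimal sketch of one sample (handling for vanished processes omitted; `psutil.Process` raises `NoSuchProcess` once the bot exits):

```python
import time
import psutil

def sample_process(pid: int) -> dict:
    proc = psutil.Process(pid)
    mem = proc.memory_info()
    children = proc.children(recursive=True)
    return {
        "timestamp": time.strftime("%H:%M:%S"),
        "rss_mb": mem.rss / 1024 / 1024,   # resident set size
        "vms_mb": mem.vms / 1024 / 1024,   # virtual size
        "percent": proc.memory_percent(),
        "children_count": len(children),
        "children_mem_mb": sum(c.memory_info().rss for c in children) / 1024 / 1024,
    }
```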
@@ -149,25 +147,25 @@ def save_process_history(history: list, pid: int):
"""保存进程监控历史"""
output_dir = Path("data/memory_diagnostics")
output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = output_dir / f"process_monitor_{timestamp}_pid{pid}.txt"
with open(output_file, "w", encoding="utf-8") as f:
f.write("Bot 进程内存监控历史记录\n")
f.write("=" * 80 + "\n\n")
f.write(f"Bot PID: {pid}\n\n")
for info in history:
f.write(f"时间: {info['timestamp']}\n")
f.write(f"RSS: {info['rss_mb']:.2f} MB\n")
f.write(f"VMS: {info['vms_mb']:.2f} MB\n")
f.write(f"占比: {info['percent']:.2f}%\n")
-if info['children_count'] > 0:
+if info["children_count"] > 0:
f.write(f"子进程: {info['children_count']}\n")
f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n")
f.write("\n")
print(f"\n✅ 监控历史已保存到: {output_file}")
@@ -182,28 +180,28 @@ async def run_monitor_mode(interval: int):
print(" 3. 显示子进程详细信息")
print(" 4. 自动保存监控历史")
print("=" * 80 + "\n")
project_root = Path(__file__).parent.parent
bot_file = project_root / "bot.py"
if not bot_file.exists():
print(f"❌ 找不到 bot.py: {bot_file}")
return 1
# 检测虚拟环境
venv_python = project_root / ".venv" / "Scripts" / "python.exe"
if not venv_python.exists():
venv_python = project_root / ".venv" / "bin" / "python"
if venv_python.exists():
python_exe = str(venv_python)
print(f"🐍 使用虚拟环境: {venv_python}")
else:
python_exe = sys.executable
print(f"⚠️ 未找到虚拟环境,使用当前 Python: {python_exe}")
print(f"🤖 启动 Bot: {bot_file}")
bot_process = subprocess.Popen(
[python_exe, str(bot_file)],
cwd=str(project_root),
@@ -212,9 +210,9 @@ async def run_monitor_mode(interval: int):
text=True,
bufsize=1,
)
await asyncio.sleep(2)
if bot_process.poll() is not None:
print("❌ Bot 启动失败")
if bot_process.stdout:
@@ -222,9 +220,9 @@ async def run_monitor_mode(interval: int):
if output:
print(f"\nBot 输出:\n{output}")
return 1
print(f"✅ Bot 已启动 (PID: {bot_process.pid})\n")
# 启动输出读取线程
def read_bot_output():
if bot_process.stdout:
@@ -233,15 +231,15 @@ async def run_monitor_mode(interval: int):
print(f"[Bot] {line}", end="")
except Exception:
pass
output_thread = threading.Thread(target=read_bot_output, daemon=True)
output_thread.start()
try:
await monitor_bot_process(bot_process, interval)
except KeyboardInterrupt:
print("\n\n⚠️ 用户中断")
if bot_process.poll() is None:
print("\n正在停止 Bot...")
bot_process.terminate()
@@ -251,9 +249,9 @@ async def run_monitor_mode(interval: int):
print("⚠️ 强制终止 Bot...")
bot_process.kill()
bot_process.wait()
print("✅ Bot 已停止")
return 0
@@ -263,8 +261,8 @@ async def run_monitor_mode(interval: int):
class ObjectMemoryProfiler:
"""对象级内存分析器"""
-def __init__(self, interval: int = 10, output_file: Optional[str] = None, object_limit: int = 20):
+def __init__(self, interval: int = 10, output_file: str | None = None, object_limit: int = 20):
self.interval = interval
self.output_file = output_file
self.object_limit = object_limit
@@ -273,23 +271,23 @@ class ObjectMemoryProfiler:
if PYMPLER_AVAILABLE:
self.tracker = tracker.SummaryTracker()
self.iteration = 0
-def get_object_stats(self) -> Dict:
+def get_object_stats(self) -> dict:
"""获取当前进程的对象统计(所有线程)"""
if not PYMPLER_AVAILABLE:
return {}
try:
gc.collect()
all_objects = muppy.get_objects()
sum_data = summary.summarize(all_objects)
# 按总大小第3个元素降序排序
sorted_sum_data = sorted(sum_data, key=lambda x: x[2], reverse=True)
# 按模块统计内存
module_stats = self._get_module_stats(all_objects)
threads = threading.enumerate()
thread_info = [
{
@@ -299,13 +297,13 @@ class ObjectMemoryProfiler:
}
for t in threads
]
gc_stats = {
"collections": gc.get_count(),
"garbage": len(gc.garbage),
"tracked": len(gc.get_objects()),
}
return {
"summary": sorted_sum_data[:self.object_limit],
"module_stats": module_stats,
@@ -316,52 +314,52 @@ class ObjectMemoryProfiler:
except Exception as e:
print(f"❌ 获取对象统计失败: {e}")
return {}
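`get_object_stats` relies on pympler's muppy/summary pair: enumerate every live object, summarize per type, sort by total size. A trimmed sketch of that core:

```python
import gc
from pympler import muppy, summary

def top_object_types(limit: int = 20) -> list:
    gc.collect()                           # drop collectable garbage before measuring
    all_objects = muppy.get_objects()      # every object the interpreter tracks
    rows = summary.summarize(all_objects)  # rows of [type_name, count, total_bytes]
    return sorted(rows, key=lambda row: row[2], reverse=True)[:limit]
```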
-def _get_module_stats(self, all_objects: list) -> Dict:
+def _get_module_stats(self, all_objects: list) -> dict:
"""统计各模块的内存占用"""
module_mem = defaultdict(lambda: {"count": 0, "size": 0})
for obj in all_objects:
try:
# 获取对象所属模块
obj_type = type(obj)
module_name = obj_type.__module__
if module_name:
# 获取顶级模块名(例如 src.chat.xxx -> src
-top_module = module_name.split('.')[0]
+top_module = module_name.split(".")[0]
obj_size = sys.getsizeof(obj)
module_mem[top_module]["count"] += 1
module_mem[top_module]["size"] += obj_size
except Exception:
# 忽略无法获取大小的对象
continue
# 转换为列表并按大小排序
sorted_modules = sorted(
[(mod, stats["count"], stats["size"])
[(mod, stats["count"], stats["size"])
for mod, stats in module_mem.items()],
key=lambda x: x[2],
reverse=True
)
return {
"top_modules": sorted_modules[:20], # 前20个模块
"total_modules": len(module_mem)
}
-def print_stats(self, stats: Dict, iteration: int):
+def print_stats(self, stats: dict, iteration: int):
"""打印统计信息"""
print("\n" + "=" * 80)
print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}")
print("=" * 80)
if "summary" in stats:
print(f"\n📦 对象统计 (前 {self.object_limit} 个类型):\n")
print(f"{'类型':<50} {'数量':>12} {'总大小':>15}")
print("-" * 80)
for obj_type, obj_count, obj_size in stats["summary"]:
if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -371,14 +369,14 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB"
else:
size_str = f"{obj_size} B"
print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}")
if "module_stats" in stats and stats["module_stats"]:
print(f"\n📚 模块内存占用 (前 20 个模块):\n")
if stats.get("module_stats"):
print("\n📚 模块内存占用 (前 20 个模块):\n")
print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}")
print("-" * 80)
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -388,46 +386,46 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB"
else:
size_str = f"{obj_size} B"
print(f"{module_name:<40} {obj_count:>12,} {size_str:>15}")
print(f"\n 总模块数: {stats['module_stats']['total_modules']}")
if "threads" in stats:
print(f"\n🧵 线程信息 ({len(stats['threads'])} 个):")
for idx, t in enumerate(stats["threads"], 1):
status = "" if t["alive"] else ""
daemon = "(守护)" if t["daemon"] else ""
print(f" [{idx}] {status} {t['name']} {daemon}")
if "gc_stats" in stats:
gc_stats = stats["gc_stats"]
print(f"\n🗑️ 垃圾回收:")
print("\n🗑️ 垃圾回收:")
print(f" 代 0: {gc_stats['collections'][0]:,}")
print(f" 代 1: {gc_stats['collections'][1]:,}")
print(f" 代 2: {gc_stats['collections'][2]:,}")
print(f" 追踪对象: {gc_stats['tracked']:,}")
if "total_objects" in stats:
print(f"\n📊 总对象数: {stats['total_objects']:,}")
print("=" * 80 + "\n")
def print_diff(self):
"""打印对象变化"""
if not PYMPLER_AVAILABLE or not self.tracker:
return
print("\n📈 对象变化分析:")
print("-" * 80)
self.tracker.print_diff()
print("-" * 80)
-def save_to_file(self, stats: Dict):
+def save_to_file(self, stats: dict):
"""保存统计信息到文件"""
if not self.output_file:
return
try:
# 保存文本
with open(self.output_file, "a", encoding="utf-8") as f:
@@ -435,91 +433,91 @@ class ObjectMemoryProfiler:
f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"迭代: #{self.iteration}\n")
f.write(f"{'=' * 80}\n\n")
if "summary" in stats:
f.write("对象统计:\n")
for obj_type, obj_count, obj_size in stats["summary"]:
f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n")
if "module_stats" in stats and stats["module_stats"]:
if stats.get("module_stats"):
f.write("\n模块统计 (前 20 个):\n")
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n")
f.write(f"\n总对象数: {stats.get('total_objects', 0):,}\n")
f.write(f"线程数: {len(stats.get('threads', []))}\n")
# 保存 JSONL
jsonl_path = str(self.output_file) + ".jsonl"
record = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"iteration": self.iteration,
"total_objects": stats.get("total_objects", 0),
"threads": stats.get("threads", []),
"gc_stats": stats.get("gc_stats", {}),
"summary": [
{"type": t, "count": c, "size": s}
{"type": t, "count": c, "size": s}
for (t, c, s) in stats.get("summary", [])
],
"module_stats": stats.get("module_stats", {}),
}
with open(jsonl_path, "a", encoding="utf-8") as jf:
jf.write(json.dumps(record, ensure_ascii=False) + "\n")
if self.iteration == 1:
print(f"💾 数据保存到: {self.output_file}")
print(f"💾 结构化数据: {jsonl_path}")
except Exception as e:
print(f"⚠️ 保存文件失败: {e}")
def start_monitoring(self):
"""启动监控线程"""
self.running = True
def monitor_loop():
print(f"🚀 对象分析器已启动")
print("🚀 对象分析器已启动")
print(f" 监控间隔: {self.interval}")
print(f" 对象类型限制: {self.object_limit}")
print(f" 输出文件: {self.output_file or ''}")
print()
while self.running:
try:
self.iteration += 1
stats = self.get_object_stats()
self.print_stats(stats, self.iteration)
if self.iteration % 3 == 0 and self.tracker:
self.print_diff()
if self.output_file:
self.save_to_file(stats)
time.sleep(self.interval)
except Exception as e:
print(f"❌ 监控出错: {e}")
import traceback
traceback.print_exc()
monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
monitor_thread.start()
print(f"✓ 监控线程已启动\n")
print("✓ 监控线程已启动\n")
def stop(self):
"""停止监控"""
self.running = False
-def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
+def run_objects_mode(interval: int, output: str | None, object_limit: int):
"""对象分析模式主函数"""
if not PYMPLER_AVAILABLE:
print("❌ pympler 未安装,无法使用对象分析模式")
print(" 安装: pip install pympler")
return 1
print("=" * 80)
print("🔬 对象分析模式")
print("=" * 80)
@@ -529,38 +527,38 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
print(" 3. 显示对象变化diff")
print(" 4. 保存 JSONL 数据用于可视化")
print("=" * 80 + "\n")
# 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root))
print(f"✓ 已添加项目根目录到 Python 路径: {project_root}\n")
profiler = ObjectMemoryProfiler(
interval=interval,
output_file=output,
object_limit=object_limit
)
profiler.start_monitoring()
print("🤖 正在启动 Bot...\n")
try:
import bot
-if hasattr(bot, 'main_async'):
+if hasattr(bot, "main_async"):
asyncio.run(bot.main_async())
-elif hasattr(bot, 'main'):
+elif hasattr(bot, "main"):
bot.main()
else:
print("⚠️ bot.py 未找到 main_async() 或 main() 函数")
print(" Bot 模块已导入,监控线程在后台运行")
print(" 按 Ctrl+C 停止\n")
while profiler.running:
time.sleep(1)
except KeyboardInterrupt:
print("\n\n⚠️ 用户中断")
except Exception as e:
@@ -569,7 +567,7 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
traceback.print_exc()
finally:
profiler.stop()
return 0
@@ -577,10 +575,10 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
# 可视化模式
# ============================================================================
-def load_jsonl(path: Path) -> List[Dict]:
+def load_jsonl(path: Path) -> list[dict]:
"""加载 JSONL 文件"""
snapshots = []
with open(path, "r", encoding="utf-8") as f:
with open(path, encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
@@ -592,7 +590,7 @@ def load_jsonl(path: Path) -> List[Dict]:
return snapshots
-def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
+def aggregate_top_types(snapshots: list[dict], top_n: int = 10):
"""聚合前 N 个对象类型的时间序列"""
type_max = defaultdict(int)
for snap in snapshots:
@@ -600,37 +598,37 @@ def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
t = item.get("type")
s = int(item.get("size", 0))
type_max[t] = max(type_max[t], s)
top_types = sorted(type_max.items(), key=lambda kv: kv[1], reverse=True)[:top_n]
top_names = [t for t, _ in top_types]
times = []
series = {t: [] for t in top_names}
for snap in snapshots:
ts = snap.get("timestamp")
try:
times.append(datetime.strptime(ts, "%Y-%m-%d %H:%M:%S"))
except Exception:
times.append(None)
summary = {item.get("type"): int(item.get("size", 0))
summary = {item.get("type"): int(item.get("size", 0))
for item in snap.get("summary", [])}
for t in top_names:
series[t].append(summary.get(t, 0) / 1024.0 / 1024.0)
return times, series
-def plot_series(times: List, series: Dict, output: Path, top_n: int):
+def plot_series(times: list, series: dict, output: Path, top_n: int):
"""绘制时间序列图"""
plt.figure(figsize=(14, 8))
for name, values in series.items():
if all(v == 0 for v in values):
continue
plt.plot(times, values, marker="o", label=name, linewidth=2)
plt.xlabel("时间", fontsize=12)
plt.ylabel("内存 (MB)", fontsize=12)
plt.title(f"对象类型随时间的内存占用 (前 {top_n} 类型)", fontsize=14)
@@ -647,31 +645,31 @@ def run_visualize_mode(input_file: str, output_file: str, top: int):
print("❌ matplotlib 未安装,无法使用可视化模式")
print(" 安装: pip install matplotlib")
return 1
print("=" * 80)
print("📊 可视化模式")
print("=" * 80)
path = Path(input_file)
if not path.exists():
print(f"❌ 找不到输入文件: {path}")
return 1
print(f"📂 读取数据: {path}")
snaps = load_jsonl(path)
if not snaps:
print("❌ 未读取到任何快照数据")
return 1
print(f"✓ 读取 {len(snaps)} 个快照")
times, series = aggregate_top_types(snaps, top_n=top)
print(f"✓ 提取前 {top} 个对象类型")
output_path = Path(output_file)
plot_series(times, series, output_path, top)
return 0
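The visualize mode plots one line per object type over time. A minimal matplotlib sketch of the same idea (the all-zero filter mirrors `plot_series` above; axis labels translated for illustration):

```python
from datetime import datetime
import matplotlib.pyplot as plt

def plot_memory_series(times: list[datetime], series: dict[str, list[float]], out: str) -> None:
    plt.figure(figsize=(14, 8))
    for name, values in series.items():
        if any(values):  # skip types that never used memory
            plt.plot(times, values, marker="o", label=name, linewidth=2)
    plt.xlabel("time")
    plt.ylabel("memory (MB)")
    plt.legend(loc="upper left", fontsize=8)
    plt.tight_layout()
    plt.savefig(out, dpi=150)
```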
@@ -693,10 +691,10 @@ def main():
使用示例:
# 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 --output plot.png
@@ -705,26 +703,26 @@ def main():
- 可视化模式需要: pip install matplotlib
""",
)
# 模式选择
mode_group = parser.add_mutually_exclusive_group(required=True)
mode_group.add_argument("--monitor", "-m", action="store_true",
mode_group.add_argument("--monitor", "-m", action="store_true",
help="进程监控模式(外部监控 bot 进程)")
mode_group.add_argument("--objects", "-o", action="store_true",
mode_group.add_argument("--objects", "-o", action="store_true",
help="对象分析模式(内部统计所有对象)")
mode_group.add_argument("--visualize", "-v", action="store_true",
mode_group.add_argument("--visualize", "-v", action="store_true",
help="可视化模式(绘制 JSONL 数据)")
# 通用参数
parser.add_argument("--interval", "-i", type=int, default=10,
help="监控间隔(秒),默认 10")
# 对象分析参数
parser.add_argument("--output", type=str,
help="输出文件路径(对象分析模式)")
parser.add_argument("--object-limit", "-l", type=int, default=20,
help="对象类型显示数量,默认 20")
# 可视化参数
parser.add_argument("--input", type=str,
help="输入 JSONL 文件(可视化模式)")
@@ -732,24 +730,24 @@ def main():
help="展示前 N 个类型(可视化模式),默认 10")
parser.add_argument("--plot-output", type=str, default="memory_analysis_plot.png",
help="图表输出文件,默认 memory_analysis_plot.png")
args = parser.parse_args()
# 根据模式执行
if args.monitor:
return asyncio.run(run_monitor_mode(args.interval))
elif args.objects:
if not args.output:
print("⚠️ 建议使用 --output 指定输出文件以保存数据")
return run_objects_mode(args.interval, args.output, args.object_limit)
elif args.visualize:
if not args.input:
print("❌ 可视化模式需要 --input 参数指定 JSONL 文件")
return 1
return run_visualize_mode(args.input, args.plot_output, args.top)
return 0
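The CLI uses a required mutually exclusive group, so exactly one of the three modes must be chosen. A runnable sketch of that dispatch skeleton:

```python
import argparse

parser = argparse.ArgumentParser()
mode = parser.add_mutually_exclusive_group(required=True)  # exactly one mode flag
mode.add_argument("--monitor", "-m", action="store_true")
mode.add_argument("--objects", "-o", action="store_true")
mode.add_argument("--visualize", "-v", action="store_true")
parser.add_argument("--interval", "-i", type=int, default=10)

args = parser.parse_args(["--objects", "--interval", "5"])
assert args.objects and not (args.monitor or args.visualize)
```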

View File

@@ -680,9 +680,9 @@ class EmojiManager:
try:
# 🔧 使用 QueryBuilder 以启用数据库缓存
from src.common.database.api.query import QueryBuilder
logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = await QueryBuilder(Emoji).all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -802,7 +802,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存)
try:
from src.common.database.api.query import QueryBuilder
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
@@ -966,7 +966,7 @@ class EmojiManager:
existing_description = None
try:
from src.common.database.api.query import QueryBuilder
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first()
if existing_image and existing_image.description:
existing_description = existing_image.description

View File

@@ -1,5 +1,4 @@
-import os
import random
import time
from datetime import datetime
from typing import Any
@@ -135,20 +134,20 @@ class ExpressionLearner:
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
"""
清理过期的表达方式
Args:
expiration_days: 过期天数,超过此天数未激活的表达方式将被删除(不指定则从配置读取)
Returns:
int: 删除的表达方式数量
"""
# 从配置读取过期天数
if expiration_days is None:
expiration_days = global_config.expression.expiration_days
current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600)
try:
deleted_count = 0
async with get_db_session() as session:
@@ -160,15 +159,15 @@ class ExpressionLearner:
)
)
expired_expressions = list(query.scalars())
if expired_expressions:
for expr in expired_expressions:
await session.delete(expr)
deleted_count += 1
await session.commit()
logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)")
# 清除缓存
from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key
@@ -176,7 +175,7 @@ class ExpressionLearner:
await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else:
logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)")
return deleted_count
except Exception as e:
logger.error(f"清理过期表达方式失败: {e}")
@@ -460,7 +459,7 @@ class ExpressionLearner:
)
)
same_situation_expr = query_same_situation.scalar()
# 情况2相同 chat_id + type + style相同表达不同情景
query_same_style = await session.execute(
select(Expression).where(
@@ -470,7 +469,7 @@ class ExpressionLearner:
)
)
same_style_expr = query_same_style.scalar()
# 情况3完全相同相同情景+相同表达)
query_exact_match = await session.execute(
select(Expression).where(
@@ -481,7 +480,7 @@ class ExpressionLearner:
)
)
exact_match_expr = query_exact_match.scalar()
# 优先处理完全匹配的情况
if exact_match_expr:
# 完全相同增加count更新时间

View File

@@ -72,21 +72,21 @@ class ExpressorModel:
是否删除成功
"""
removed = False
if cid in self._candidates:
del self._candidates[cid]
removed = True
if cid in self._situations:
del self._situations[cid]
# 从nb模型中删除
if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid]
if cid in self.nb.token_counts:
del self.nb.token_counts[cid]
return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:

View File

@@ -72,7 +72,7 @@ class StyleLearner:
# 检查是否需要清理
current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger:
if current_count >= self.max_styles:
# 已经达到最大限制,必须清理
@@ -109,7 +109,7 @@ class StyleLearner:
def _cleanup_styles(self):
"""
清理低价值的风格,为新风格腾出空间
清理策略:
1. 综合考虑使用次数和最后使用时间
2. 删除得分最低的风格
@@ -118,34 +118,34 @@ class StyleLearner:
try:
current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# 计算每个风格的价值分数
style_scores = []
for style_id in self.style_to_id.values():
# 使用次数
usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# 最后使用时间(越近越好)
last_used = self.learning_stats["style_last_used"].get(style_id, 0)
-time_since_used = current_time - last_used if last_used > 0 else float('inf')
+time_since_used = current_time - last_used if last_used > 0 else float("inf")
# 综合分数:使用次数越多越好,距离上次使用时间越短越好
# 使用对数来平滑使用次数的影响
import math
usage_score = math.log1p(usage_count) # log(1 + count)
# 时间分数:转换为天数,使用指数衰减
days_unused = time_since_used / 86400 # 转换为天
time_score = math.exp(-days_unused / 30) # 30天衰减因子
# 综合分数80%使用频率 + 20%时间新鲜度
total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused))
# 按分数排序,分数低的先删除
style_scores.sort(key=lambda x: x[1])
# 删除分数最低的风格
deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]:
@@ -156,27 +156,27 @@ class StyleLearner:
del self.id_to_style[style_id]
if style_id in self.id_to_situation:
del self.id_to_situation[style_id]
# 从统计中删除
if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id]
# 从expressor模型中删除
self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格"
)
# 记录前5个被删除的风格用于调试
if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
except Exception as e:
logger.error(f"清理风格失败: {e}", exc_info=True)
@@ -303,10 +303,10 @@ class StyleLearner:
def cleanup_old_styles(self, ratio: float | None = None) -> int:
"""
手动清理旧风格
Args:
ratio: 清理比例如果为None则使用默认的cleanup_ratio
Returns:
清理的风格数量
"""
@@ -318,7 +318,7 @@ class StyleLearner:
self.cleanup_ratio = old_cleanup_ratio
else:
self._cleanup_styles()
new_count = len(self.style_to_id)
cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
@@ -357,11 +357,11 @@ class StyleLearner:
import pickle
meta_path = os.path.join(save_dir, "meta.pkl")
# 确保 learning_stats 包含所有必要字段
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
meta_data = {
"style_to_id": self.style_to_id,
"id_to_style": self.id_to_style,
@@ -416,7 +416,7 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"]
# 确保旧数据兼容:如果没有 style_last_used 字段,添加它
if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {}
@@ -526,10 +526,10 @@ class StyleLearnerManager:
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
"""
对所有学习器清理旧风格
Args:
ratio: 清理比例
Returns:
{chat_id: 清理数量}
"""
@@ -538,7 +538,7 @@ class StyleLearnerManager:
cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0:
cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values())
logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
return cleanup_results

View File

@@ -8,7 +8,6 @@ from datetime import datetime
from typing import Any
import numpy as np
import orjson
from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension
@@ -124,7 +123,7 @@ class BotInterestManager:
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
tags_str = "\n".join(tags_info)
logger.info(f"当前兴趣标签:\n{tags_str}")
# 为加载的标签生成embedding数据库不存储embedding启动时动态生成
logger.info("🧠 为加载的标签生成embedding向量...")
await self._generate_embeddings_for_tags(loaded_interests)
@@ -326,13 +325,13 @@ class BotInterestManager:
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags)
# 尝试从文件加载缓存
file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
if file_cache:
logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存")
self.embedding_cache.update(file_cache)
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
memory_cached_count = 0
@@ -477,14 +476,14 @@ class BotInterestManager:
self, message_text: str, keywords: list[str] | None = None
) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题
原问题:
- 消息: "今天天气不错" (完整句子)
-- 标签: "蹭人治愈" (2-4字短语)
+- 标签: "蹭人治愈" (2-4字短语)
- 结果: 误匹配,因为短标签的 embedding 过于抽象
解决方案:
- 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容"
- 现在是: 句子 vs 句子,匹配更准确
@@ -527,18 +526,18 @@ class BotInterestManager:
if tag.embedding:
# 🔧 优化:获取扩展标签的 embedding带缓存
expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
if expanded_embedding:
# 使用扩展标签的 embedding 进行匹配
similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
# 同时计算原始标签的相似度作为参考
original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
# 混合策略扩展标签权重更高70%原始标签作为补充30%
# 这样可以兼顾准确性(扩展)和灵活性(原始)
final_similarity = similarity * 0.7 + original_similarity * 0.3
logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}")
else:
# 如果扩展 embedding 获取失败,使用原始 embedding
@@ -603,27 +602,27 @@ class BotInterestManager:
logger.debug(
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
)
# 如果有新生成的扩展embedding保存到缓存文件
-if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated:
+if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
await self._save_embedding_cache_to_file(self.current_interests.personality_id)
self._new_expanded_embeddings_generated = False
logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
return result
async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
"""获取扩展标签的 embedding带缓存
优先使用缓存,如果没有则生成并缓存
"""
# 检查缓存
if tag_name in self.expanded_embedding_cache:
return self.expanded_embedding_cache[tag_name]
# 扩展标签
expanded_tag = self._expand_tag_for_matching(tag_name)
# 生成 embedding
try:
embedding = await self._get_embedding(expanded_tag)
@@ -636,19 +635,19 @@ class BotInterestManager:
return embedding
except Exception as e:
logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}")
return None
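Matching blends two cosine similarities: 70% against the expanded-tag embedding (sentence vs sentence) and 30% against the raw tag. A sketch with numpy (the vectors are placeholders):

```python
import numpy as np

def cosine(a: list[float], b: list[float]) -> float:
    va, vb = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = float(np.linalg.norm(va) * np.linalg.norm(vb))
    return float(va @ vb) / denom if denom else 0.0

def blended_similarity(msg, expanded_tag, raw_tag) -> float:
    # Expanded description carries most weight; the raw tag is a safety net.
    return 0.7 * cosine(msg, expanded_tag) + 0.3 * cosine(msg, raw_tag)

assert abs(blended_similarity([1, 0], [1, 0], [0, 1]) - 0.7) < 1e-9
```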
def _expand_tag_for_matching(self, tag_name: str) -> str:
"""将短标签扩展为完整的描述性句子
这是解决"标签太短导致误匹配"的核心方法
策略:
1. 优先使用 LLM 生成的 expanded 字段(最准确)
2. 如果没有,使用基于规则的回退方案
3. 最后使用通用模板
示例:
- "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
- "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
@@ -656,7 +655,7 @@ class BotInterestManager:
# 使用缓存
if tag_name in self.expanded_tag_cache:
return self.expanded_tag_cache[tag_name]
# 🎯 优先策略:使用 LLM 生成的 expanded 字段
if self.current_interests:
for tag in self.current_interests.interest_tags:
@@ -664,66 +663,66 @@ class BotInterestManager:
logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
self.expanded_tag_cache[tag_name] = tag.expanded
return tag.expanded
# 🔧 回退策略基于规则的扩展用于兼容旧数据或LLM未生成扩展的情况
logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案")
tag_lower = tag_name.lower()
# 技术编程类标签(具体化描述)
-if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']):
-if 'python' in tag_lower:
-return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
-elif '算法' in tag_lower:
-return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
-elif '代码' in tag_lower or '被窝' in tag_lower:
-return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧"
+if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
+if "python" in tag_lower:
+return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
+elif "算法" in tag_lower:
+return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
+elif "代码" in tag_lower or "被窝" in tag_lower:
+return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
else:
return f"讨论编程开发、软件技术、代码编写、技术实现"
return "讨论编程开发、软件技术、代码编写、技术实现"
# 情感表达类标签(具体化为真实对话场景)
-elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '蹭', '卖萌']):
-return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
+elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "蹭", "卖萌"]):
+return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
# 游戏娱乐类标签(具体游戏场景)
-elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '帕', '魔']):
-return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
+elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "帕", "魔"]):
+return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
# 动漫影视类标签(具体观看行为)
elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']):
elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
# 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西"
if '' in tag_lower or '新番' in tag_lower:
return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
if "" in tag_lower or "新番" in tag_lower:
return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
else:
return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
# 社交平台类标签(具体平台行为)
-elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']):
-if '吃瓜' in tag_lower:
-return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
+elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
+if "吃瓜" in tag_lower:
+return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
else:
return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
# 生活日常类标签(具体萌宠场景)
elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']):
return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
# 状态心情类标签(具体情绪状态)
-elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']):
-if '社恐' in tag_lower:
-return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
-elif '深夜' in tag_lower:
-return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
+elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
+if "社恐" in tag_lower:
+return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
+elif "深夜" in tag_lower:
+return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
else:
return f"表达当前心情状态、个人感受、生活状态"
return "表达当前心情状态、个人感受、生活状态"
# 物品装备类标签(具体使用场景)
-elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']):
-return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
+elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
+return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
# 互动关系类标签
-elif any(word in tag_lower for word in ['拾风', '互怼', '互动']):
-return f"聊天互动、开玩笑、友好互怼、日常对话交流"
+elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
+return "聊天互动、开玩笑、友好互怼、日常对话交流"
# 默认:尽量具体化
else:
return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题"
@@ -1011,56 +1010,58 @@ class BotInterestManager:
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
"""从文件加载embedding缓存"""
try:
-import orjson
from pathlib import Path
+import orjson
cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json"
if not cache_file.exists():
logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}")
return None
# 读取缓存文件
with open(cache_file, "rb") as f:
cache_data = orjson.loads(f.read())
# 验证缓存版本和embedding模型
cache_version = cache_data.get("version", 1)
cache_embedding_model = cache_data.get("embedding_model", "")
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else ""
if cache_embedding_model != current_embedding_model:
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存")
return None
embeddings = cache_data.get("embeddings", {})
# 同时加载扩展标签的embedding缓存
expanded_embeddings = cache_data.get("expanded_embeddings", {})
if expanded_embeddings:
self.expanded_embedding_cache.update(expanded_embeddings)
logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存")
logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
return embeddings
except Exception as e:
logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}")
return None
async def _save_embedding_cache_to_file(self, personality_id: str):
"""保存embedding缓存到文件包括扩展标签的embedding"""
try:
-import orjson
-from pathlib import Path
from datetime import datetime
+from pathlib import Path
+import orjson
cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json"
# 准备缓存数据
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else ""
cache_data = {
@@ -1071,13 +1072,13 @@ class BotInterestManager:
"embeddings": self.embedding_cache,
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
}
# 写入文件
with open(cache_file, "wb") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2))
logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
except Exception as e:
logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}")

View File

@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
__all__ = [
"MessageManager",
"SingleStreamContextManager",
"SchedulerDispatcher",
"SingleStreamContextManager",
"message_manager",
"scheduler_dispatcher",
]

View File

@@ -73,7 +73,7 @@ class SingleStreamContextManager:
cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用")
logger.debug("消息缓存系统已在配置中禁用")
except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False
@@ -129,13 +129,13 @@ class SingleStreamContextManager:
await self._calculate_message_interest(message)
self.total_messages += 1
self.last_access_time = time.time()
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True
# 不应该到达这里,但为了类型检查添加返回值
return True
except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False

View File

@@ -4,13 +4,11 @@
"""
import asyncio
import random
import time
from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
@@ -77,7 +75,7 @@ class MessageManager:
# 启动基于 scheduler 的消息分发器
await scheduler_dispatcher.start()
scheduler_dispatcher.set_chatter_manager(self.chatter_manager)
# 保留旧的流循环管理器(暂时)以便平滑过渡
# TODO: 在确认新机制稳定后移除
# await stream_loop_manager.start()
@@ -108,7 +106,7 @@ class MessageManager:
# 停止基于 scheduler 的消息分发器
await scheduler_dispatcher.stop()
# 停止旧的流循环管理器(如果启用)
# await stream_loop_manager.stop()
@@ -116,7 +114,7 @@ class MessageManager:
async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流
新的流程:
1. 检查 notice 消息
2. 将消息添加到上下文(缓存)
@@ -149,10 +147,10 @@ class MessageManager:
if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return
# 将消息添加到上下文
await chat_stream.context_manager.add_message(message)
# 通知 scheduler_dispatcher 处理消息接收事件
# dispatcher 会检查是否需要打断、创建或更新 schedule
await scheduler_dispatcher.on_message_received(stream_id)

View File

@@ -20,7 +20,7 @@ logger = get_logger("scheduler_dispatcher")
class SchedulerDispatcher:
"""基于 scheduler 的消息分发器
工作流程:
1. 接收消息时,将消息添加到聊天流上下文
2. 检查是否有活跃的 schedule如果没有则创建
@@ -32,13 +32,13 @@ class SchedulerDispatcher:
def __init__(self):
# 追踪每个流的 schedule_id
self.stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
# 用于保护 schedule 创建/删除的锁,避免竞态条件
self.schedule_locks: dict[str, asyncio.Lock] = {} # stream_id -> Lock
# Chatter 管理器
self.chatter_manager: ChatterManager | None = None
# 统计信息
self.stats = {
"total_schedules_created": 0,
@@ -48,9 +48,9 @@ class SchedulerDispatcher:
"total_failures": 0,
"start_time": time.time(),
}
self.is_running = False
logger.info("基于 Scheduler 的消息分发器初始化完成")
async def start(self) -> None:
@@ -58,7 +58,7 @@ class SchedulerDispatcher:
if self.is_running:
logger.warning("分发器已在运行")
return
self.is_running = True
logger.info("基于 Scheduler 的消息分发器已启动")
@@ -66,9 +66,9 @@ class SchedulerDispatcher:
"""停止分发器"""
if not self.is_running:
return
self.is_running = False
# 取消所有活跃的 schedule
schedule_ids = list(self.stream_schedules.values())
for schedule_id in schedule_ids:
@@ -76,7 +76,7 @@ class SchedulerDispatcher:
await unified_scheduler.remove_schedule(schedule_id)
except Exception as e:
logger.error(f"移除 schedule {schedule_id} 失败: {e}")
self.stream_schedules.clear()
logger.info("基于 Scheduler 的消息分发器已停止")
@@ -84,7 +84,7 @@ class SchedulerDispatcher:
"""设置 Chatter 管理器"""
self.chatter_manager = chatter_manager
logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}")
def _get_schedule_lock(self, stream_id: str) -> asyncio.Lock:
"""获取流的 schedule 锁"""
if stream_id not in self.schedule_locks:
@@ -93,40 +93,40 @@ class SchedulerDispatcher:
async def on_message_received(self, stream_id: str) -> None:
"""消息接收时的处理逻辑
Args:
stream_id: 聊天流ID
"""
if not self.is_running:
logger.warning("分发器未运行,忽略消息")
return
try:
# 1. 获取流上下文
context = await self._get_stream_context(stream_id)
if not context:
logger.warning(f"无法获取流上下文: {stream_id}")
return
# 2. 检查是否有活跃的 schedule
has_active_schedule = stream_id in self.stream_schedules
if not has_active_schedule:
# 4. 创建新的 schedule在锁内避免重复创建
await self._create_schedule(stream_id, context)
return
# 3. 检查打断判定
if has_active_schedule:
should_interrupt = await self._check_interruption(stream_id, context)
if should_interrupt:
# 移除旧 schedule 并创建新的(内部有锁保护)
await self._cancel_and_recreate_schedule(stream_id, context)
logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule")
else:
logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...")
except Exception as e:
logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True)
@@ -144,18 +144,18 @@ class SchedulerDispatcher:
async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool:
"""检查是否应该打断当前处理
Args:
stream_id: 流ID
context: 流上下文
Returns:
bool: 是否应该打断
"""
# 检查是否启用打断
if not global_config.chat.interruption_enabled:
return False
# 检查是否正在回复,以及是否允许在回复时打断
if context.is_replying:
if not global_config.chat.allow_reply_interruption:
@@ -163,49 +163,49 @@ class SchedulerDispatcher:
return False
else:
logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断")
# 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing:
logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断")
return False
# 检查最后一条消息
last_message = context.get_last_message()
if not last_message:
return False
# 检查是否为表情包消息
if last_message.is_picid or last_message.is_emoji:
logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查")
return False
# 检查触发用户ID
triggering_user_id = context.triggering_user_id
if triggering_user_id and last_message.user_info.user_id != triggering_user_id:
logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查")
return False
# 检查是否已达到最大打断次数
if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.debug(
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}"
)
return False
# 计算打断概率
interruption_probability = context.calculate_interruption_probability(
global_config.chat.interruption_max_limit
)
# 根据概率决定是否打断
import random
if random.random() < interruption_probability:
logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
# 增加打断计数
await context.increment_interruption_count()
self.stats["total_interruptions"] += 1
# 检查是否已达到最大次数
if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.warning(
@@ -215,7 +215,7 @@ class SchedulerDispatcher:
logger.info(
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}"
)
return True
else:
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
@@ -223,7 +223,7 @@ class SchedulerDispatcher:
async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None:
"""取消旧的 schedule 并创建新的(打断模式,使用极短延迟)
Args:
stream_id: 流ID
context: 流上下文
@@ -244,13 +244,13 @@ class SchedulerDispatcher:
)
# 移除失败,不创建新 schedule避免重复
return
# 创建新的 schedule使用即时处理模式极短延迟
await self._create_schedule(stream_id, context, immediate_mode=True)
async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None:
"""为聊天流创建新的 schedule
Args:
stream_id: 流ID
context: 流上下文
@@ -266,7 +266,7 @@ class SchedulerDispatcher:
)
await unified_scheduler.remove_schedule(old_schedule_id)
del self.stream_schedules[stream_id]
# 如果是即时处理模式打断时使用固定的1秒延迟立即重新处理
if immediate_mode:
delay = 1.0 # 硬编码1秒延迟确保打断后能快速重新处理
@@ -277,10 +277,10 @@ class SchedulerDispatcher:
else:
# 常规模式:计算初始延迟
delay = await self._calculate_initial_delay(stream_id, context)
# 获取未读消息数量用于日志
unread_count = len(context.unread_messages) if context.unread_messages else 0
# 创建 schedule
schedule_id = await unified_scheduler.create_schedule(
callback=self._on_schedule_triggered,
@@ -290,41 +290,41 @@ class SchedulerDispatcher:
task_name=f"dispatch_{stream_id[:8]}",
callback_args=(stream_id,),
)
# 追踪 schedule
self.stream_schedules[stream_id] = schedule_id
self.stats["total_schedules_created"] += 1
mode_indicator = "⚡打断" if immediate_mode else "📅常规"
logger.info(
f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., "
f"延迟={delay:.3f}s, 未读={unread_count}, "
f"ID={schedule_id[:8]}..."
)
except Exception as e:
logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True)
async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float:
"""计算初始延迟时间
Args:
stream_id: 流ID
context: 流上下文
Returns:
float: 延迟时间(秒)
"""
# 基础间隔
base_interval = getattr(global_config.chat, "distribution_interval", 5.0)
# 检查是否有未读消息
unread_count = len(context.unread_messages) if context.unread_messages else 0
# 强制分发阈值
force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20)
# 如果未读消息过多,使用最小间隔
if force_dispatch_threshold and unread_count > force_dispatch_threshold:
min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
@@ -334,24 +334,24 @@ class SchedulerDispatcher:
f"使用最小间隔={min_interval}s"
)
return min_interval
# 尝试使用能量管理器计算间隔
try:
# 更新能量值
await self._update_stream_energy(stream_id, context)
# 获取当前 focus_energy
focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0]
# 使用能量管理器计算间隔
interval = energy_manager.get_distribution_interval(focus_energy)
logger.info(
f"📊 动态间隔计算: 流={stream_id[:8]}..., "
f"能量={focus_energy:.3f}, 间隔={interval:.2f}s"
)
return interval
except Exception as e:
logger.info(
f"📊 使用默认间隔: 流={stream_id[:8]}..., "
@@ -361,96 +361,96 @@ class SchedulerDispatcher:
async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None:
"""更新流的能量值
Args:
stream_id: 流ID
context: 流上下文
"""
try:
from src.chat.message_receive.chat_stream import get_chat_manager
# 获取聊天流
chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return
# 合并未读消息和历史消息
all_messages = []
# 添加历史消息
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages)
# 添加未读消息
unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages)
# 按时间排序并限制数量
all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:]
# 获取用户ID
user_id = context.triggering_user_id
# 使用能量管理器计算并缓存能量值
energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id,
messages=messages,
user_id=user_id
)
# 同步更新到 ChatStream
chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
async def _on_schedule_triggered(self, stream_id: str) -> None:
"""schedule 触发时的回调
Args:
stream_id: 流ID
"""
try:
old_schedule_id = self.stream_schedules.get(stream_id)
logger.info(
f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., "
f"开始处理消息"
)
# 获取流上下文
context = await self._get_stream_context(stream_id)
if not context:
logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}")
return
# 检查是否有未读消息
if not context.unread_messages:
logger.debug(f"{stream_id} 没有未读消息,跳过处理")
return
# 激活 chatter 处理(不需要锁,允许并发处理)
success = await self._process_stream(stream_id, context)
# 更新统计
self.stats["total_process_cycles"] += 1
if not success:
self.stats["total_failures"] += 1
self.stream_schedules.pop(stream_id, None)
# 检查缓存中是否有待处理的消息
from src.chat.message_manager.message_manager import message_manager
has_cached = message_manager.has_cached_messages(stream_id)
if has_cached:
# 有缓存消息,立即创建新 schedule 继续处理
logger.info(
@@ -464,60 +464,60 @@ class SchedulerDispatcher:
f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., "
f"等待新消息到达"
)
except Exception as e:
logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True)
async def _process_stream(self, stream_id: str, context: StreamContext) -> bool:
"""处理流消息
Args:
stream_id: 流ID
context: 流上下文
Returns:
bool: 是否处理成功
"""
if not self.chatter_manager:
logger.warning(f"Chatter 管理器未设置: {stream_id}")
return False
# 设置处理状态
self._set_stream_processing_status(stream_id, True)
try:
start_time = time.time()
# 设置触发用户ID
last_message = context.get_last_message()
if last_message:
context.triggering_user_id = last_message.user_info.user_id
# 创建异步任务刷新能量(不阻塞主流程)
energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id))
# 设置 Chatter 正在处理的标志
context.is_chatter_processing = True
logger.debug(f"设置 Chatter 处理标志: {stream_id}")
try:
# 调用 chatter_manager 处理流上下文
results = await self.chatter_manager.process_stream_context(stream_id, context)
success = results.get("success", False)
if success:
process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success
finally:
# 清除 Chatter 处理标志
context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# 等待能量刷新任务完成
try:
await asyncio.wait_for(energy_task, timeout=5.0)
@@ -525,11 +525,11 @@ class SchedulerDispatcher:
logger.warning(f"等待能量刷新超时: {stream_id}")
except Exception as e:
logger.debug(f"能量刷新任务异常: {e}")
except Exception as e:
logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True)
return False
finally:
# 设置处理状态为未处理
self._set_stream_processing_status(stream_id, False)
@@ -538,11 +538,11 @@ class SchedulerDispatcher:
"""设置流的处理状态"""
try:
from src.chat.message_manager.message_manager import message_manager
if message_manager.is_running:
message_manager.set_stream_processing_status(stream_id, is_processing)
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except ImportError:
logger.debug("MessageManager 不可用,跳过状态设置")
except Exception as e:
@@ -556,7 +556,7 @@ class SchedulerDispatcher:
if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return
await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e:

View File

@@ -367,7 +367,7 @@ class ChatBot:
message_segment = message_data.get("message_segment")
if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response":
logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理")
logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data"))
return

View File

@@ -205,7 +205,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
return result
else:
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
return f"@{segment.data}"
return f"@{segment.data}"
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"

View File

@@ -542,7 +542,7 @@ class DefaultReplyer:
all_memories = []
try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized():
manager = get_memory_manager()
if manager:
@@ -552,12 +552,12 @@ class DefaultReplyer:
sender_name = ""
if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# 获取参与者信息
participants = []
try:
# 尝试从聊天流中获取参与者信息
-if hasattr(stream, 'chat_history_manager'):
+if hasattr(stream, "chat_history_manager"):
history_manager = stream.chat_history_manager
# 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history(
@@ -586,16 +586,16 @@ class DefaultReplyer:
formatted_history = ""
if chat_history:
# 移除过长的历史记录,只保留最近部分
-lines = chat_history.strip().split('\n')
+lines = chat_history.strip().split("\n")
recent_lines = lines[-10:] if len(lines) > 10 else lines
-formatted_history = '\n'.join(recent_lines)
+formatted_history = "\n".join(recent_lines)
query_context = {
"chat_history": formatted_history,
"sender": sender_name,
"participants": participants,
}
# 使用记忆管理器的智能检索(多查询策略)
memories = await manager.search_memories(
query=target,
@@ -605,23 +605,23 @@ class DefaultReplyer:
use_multi_query=True,
context=query_context,
)
if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# 使用新的格式化工具构建完整的记忆描述
from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt,
get_memory_type_label,
)
for memory in memories:
# 使用格式化工具生成完整的主谓宾描述
content = format_memory_for_prompt(memory, include_metadata=False)
# 获取记忆类型
mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content:
all_memories.append({
"content": content,
@@ -636,7 +636,7 @@ class DefaultReplyer:
except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = []
# 构建记忆字符串,使用方括号格式
memory_str = ""
has_any_memory = False
@@ -725,7 +725,7 @@ class DefaultReplyer:
for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result")
tool_result.get("type", "tool_result")
# 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}")
@@ -744,7 +744,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}")
return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt
@@ -1897,7 +1897,7 @@ class DefaultReplyer:
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
"""
[已废弃] 异步存储聊天记忆从build_memory_block迁移而来
此函数已被记忆图系统的工具调用方式替代。
记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。
@@ -1906,14 +1906,13 @@ class DefaultReplyer:
reply_message: 回复的原始消息
"""
return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行
try:
if not global_config.memory.enable_memory:
return
# 使用统一记忆系统存储记忆
from src.chat.memory_system import get_memory_system
stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None)
@@ -2036,7 +2035,7 @@ class DefaultReplyer:
timestamp=time.time(),
limit=int(global_config.chat.max_context_size),
)
chat_history = await build_readable_messages(
await build_readable_messages(
message_list_before_short,
replace_bot_name=True,
merge_messages=False,
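The query context assembled in these hunks deliberately feeds the memory search only the tail of the chat history. A minimal sketch of that trimming step, matching the 10-line window above:

```python
def tail_lines(chat_history: str, keep: int = 10) -> str:
    """只保留最近 keep 行,避免把过长历史塞进检索查询"""
    lines = chat_history.strip().split("\n")
    return "\n".join(lines[-keep:])

history = "\n".join(f"msg {i}" for i in range(25))
ctx = {"chat_history": tail_lines(history), "sender": "alice", "participants": ["alice", "bob"]}
assert ctx["chat_history"].splitlines()[0] == "msg 15"
```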

View File

@@ -400,7 +400,7 @@ class Prompt:
# 初始化预构建参数字典
pre_built_params = {}
try:
# --- 步骤 1: 准备构建任务 ---
tasks = []

View File

@@ -87,20 +87,18 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
)
processed_text = message.processed_plain_text or ""
# 1. 判断是否为私聊(强提及)
group_info = getattr(message, "group_info", None)
if not group_info or not getattr(group_info, "group_id", None):
is_private = True
mention_type = 2
logger.debug("检测到私聊消息 - 强提及")
# 2. 判断是否被@(强提及)
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
is_at = True
mention_type = 2
logger.debug("检测到@提及 - 强提及")
# 3. 判断是否被回复(强提及)
if re.match(
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text
@@ -108,10 +106,9 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
processed_text,
):
is_replied = True
mention_type = 2
logger.debug("检测到回复消息 - 强提及")
# 4. 判断文本中是否提及bot名字或别名弱提及
if mention_type == 0: # 只有在没有强提及时才检查弱提及
# 移除@和回复标记后再检查
@@ -119,21 +116,19 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
# 检查bot主名字
if global_config.bot.nickname in message_content:
is_text_mentioned = True
mention_type = 1
logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
# 如果主名字没匹配,再检查别名
elif nicknames:
for alias_name in nicknames:
if alias_name in message_content:
is_text_mentioned = True
mention_type = 1
logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
break
# 返回结果
is_mentioned = mention_type > 0
return is_mentioned, float(mention_type)
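The hunk above folds private chat, @-mentions, and reply-mentions into mention_type 2 (strong), falls back to name or alias matching for mention_type 1 (weak), and returns (bool, float). A stripped-down model of the tiering, with simplified patterns and made-up account data; the real regexes are the ones in the hunk:

```python
import re

BOT_QQ = "10001"              # 假设的账号与名字,仅用于演示
BOT_NAMES = ["MoFox", "狐狸"]

def mention_level(text: str, is_private: bool) -> tuple[bool, float]:
    level = 0
    if is_private or re.search(rf"@<.+?:{BOT_QQ}>", text):
        level = 2                      # 强提及:私聊或被@
    elif any(name in text for name in BOT_NAMES):
        level = 1                      # 弱提及:文本中出现名字/别名
    return level > 0, float(level)

assert mention_level("@<foo:10001> 在吗", False) == (True, 2.0)
assert mention_level("MoFox 看这个", False) == (True, 1.0)
assert mention_level("无关消息", False) == (False, 0.0)
```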

View File

@@ -368,13 +368,13 @@ class CacheManager:
if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息"""
# 简化的健康统计,不包含内存监控(因为相关属性未定义)
return {
"l1_count": len(self.l1_kv_cache),
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
"tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
@@ -397,7 +397,7 @@ class CacheManager:
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -66,7 +66,7 @@ class BatchStats:
last_batch_duration: float = 0.0
last_batch_size: int = 0
congestion_score: float = 0.0 # 拥塞评分 (0-1)
# 🔧 新增:缓存统计
cache_size: int = 0 # 缓存条目数
cache_memory_mb: float = 0.0 # 缓存内存占用MB
@@ -539,8 +539,7 @@ class AdaptiveBatchScheduler:
def _set_cache(self, cache_key: str, result: Any) -> None:
"""设置缓存(改进版,带大小限制和内存统计)"""
import sys
# 🔧 检查缓存大小限制
if len(self._result_cache) >= self._cache_max_size:
# 首先清理过期条目
@@ -549,18 +548,18 @@ class AdaptiveBatchScheduler:
k for k, (_, ts) in self._result_cache.items()
if current_time - ts >= self.cache_ttl
]
for k in expired_keys:
# 更新内存统计
if k in self._cache_size_map:
self._cache_memory_estimate -= self._cache_size_map[k]
del self._cache_size_map[k]
del self._result_cache[k]
# 如果还是太大清理最老的条目LRU
if len(self._result_cache) >= self._cache_max_size:
oldest_key = min(
self._result_cache.keys(),
self._result_cache.keys(),
key=lambda k: self._result_cache[k][1]
)
# 更新内存统计
@@ -569,7 +568,7 @@ class AdaptiveBatchScheduler:
del self._cache_size_map[oldest_key]
del self._result_cache[oldest_key]
logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}")
# 🔧 使用准确的内存估算方法
try:
total_size = estimate_size_smart(cache_key) + estimate_size_smart(result)
@@ -580,7 +579,7 @@ class AdaptiveBatchScheduler:
# 使用默认值
self._cache_size_map[cache_key] = 1024
self._cache_memory_estimate += 1024
self._result_cache[cache_key] = (result, time.time())
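Eviction in `_set_cache` is two-staged: first sweep entries whose TTL has lapsed, and only if the cache is still full evict the entry with the oldest timestamp (an LRU approximation keyed on write time, as the comment in the hunk notes). A self-contained sketch, with illustrative names rather than the scheduler's API:

```python
import time

class TinyCache:
    def __init__(self, max_size: int = 2, ttl: float = 60.0):
        self._store: dict[str, tuple[object, float]] = {}  # key -> (value, 写入时间)
        self.max_size, self.ttl = max_size, ttl

    def set(self, key: str, value: object) -> None:
        now = time.time()
        if len(self._store) >= self.max_size:
            # 第一阶段:清理过期条目
            for k in [k for k, (_, ts) in self._store.items() if now - ts >= self.ttl]:
                del self._store[k]
            # 第二阶段:仍然满,则按写入时间淘汰最老的一条(LRU 的近似)
            if len(self._store) >= self.max_size:
                oldest = min(self._store, key=lambda k: self._store[k][1])
                del self._store[oldest]
        self._store[key] = (value, now)

c = TinyCache(max_size=2)
c.set("a", 1); c.set("b", 2); c.set("c", 3)
assert "a" not in c._store and len(c._store) == 2
```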
async def get_stats(self) -> BatchStats:

View File

@@ -171,7 +171,7 @@ class LRUCache(Generic[T]):
)
else:
adjusted_created_at = now
entry = CacheEntry(
value=value,
created_at=adjusted_created_at,
@@ -345,7 +345,7 @@ class MultiLevelCache:
# 估算数据大小(如果未提供)
if size is None:
size = estimate_size_smart(value)
# 检查单个条目大小是否超过限制
if size > self.max_item_size_bytes:
logger.warning(
@@ -354,7 +354,7 @@ class MultiLevelCache:
f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB"
)
return
# 根据TTL决定写入哪个缓存层
if ttl is not None:
# 有自定义TTL根据TTL大小决定写入层级
@@ -394,37 +394,37 @@ class MultiLevelCache:
"""获取所有缓存层的统计信息(修正版,避免重复计数)"""
l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats()
# 🔧 修复计算实际独占的内存避免L1和L2共享数据的重复计数
l1_keys = set(self.l1_cache._cache.keys())
l2_keys = set(self.l2_cache._cache.keys())
shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys
l2_only_keys = l2_keys - l1_keys
# 计算实际总内存(避免重复计数)
# L1独占内存
l1_only_size = sum(
self.l1_cache._cache[k].size
for k in l1_only_keys
self.l1_cache._cache[k].size
for k in l1_only_keys
if k in self.l1_cache._cache
)
# L2独占内存
l2_only_size = sum(
self.l2_cache._cache[k].size
for k in l2_only_keys
self.l2_cache._cache[k].size
for k in l2_only_keys
if k in self.l2_cache._cache
)
# 共享内存只计算一次使用L1的数据
shared_size = sum(
self.l1_cache._cache[k].size
for k in shared_keys
self.l1_cache._cache[k].size
for k in shared_keys
if k in self.l1_cache._cache
)
actual_total_size = l1_only_size + l2_only_size + shared_size
return {
"l1": l1_stats,
"l2": l2_stats,
@@ -442,7 +442,7 @@ class MultiLevelCache:
"""检查并强制清理超出内存限制的缓存"""
stats = await self.get_stats()
total_size = stats["l1"].total_size + stats["l2"].total_size
if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024)
@@ -452,14 +452,14 @@ class MultiLevelCache:
)
# 优先清理L2缓存温数据
await self.l2_cache.clear()
# 如果清理L2后仍超限清理L1
stats_after_l2 = await self.get_stats()
total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size
if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存")
await self.l1_cache.clear()
logger.info("缓存强制清理完成")
async def start_cleanup_task(self, interval: float = 60) -> None:
@@ -476,10 +476,10 @@ class MultiLevelCache:
while not self._is_closing:
try:
await asyncio.sleep(interval)
if self._is_closing:
break
stats = await self.get_stats()
l1_stats = stats["l1"]
l2_stats = stats["l2"]
@@ -493,13 +493,13 @@ class MultiLevelCache:
f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB "
f"(去重节省{stats['dedup_savings_mb']:.2f}MB)"
)
# 🔧 清理过期条目
await self._clean_expired_entries()
# 检查内存限制
await self.check_memory_limit()
except asyncio.CancelledError:
break
except Exception as e:
@@ -511,7 +511,7 @@ class MultiLevelCache:
async def stop_cleanup_task(self) -> None:
"""停止清理任务"""
self._is_closing = True
if self._cleanup_task is not None:
self._cleanup_task.cancel()
try:
@@ -520,43 +520,43 @@ class MultiLevelCache:
pass
self._cleanup_task = None
logger.info("缓存清理任务已停止")
async def _clean_expired_entries(self) -> None:
"""清理过期的缓存条目"""
try:
current_time = time.time()
# 清理 L1 过期条目
async with self.l1_cache._lock:
expired_keys = [
key for key, entry in self.l1_cache._cache.items()
if current_time - entry.created_at > self.l1_cache.ttl
]
for key in expired_keys:
entry = self.l1_cache._cache.pop(key, None)
if entry:
self.l1_cache._stats.evictions += 1
self.l1_cache._stats.item_count -= 1
self.l1_cache._stats.total_size -= entry.size
# 清理 L2 过期条目
async with self.l2_cache._lock:
expired_keys = [
key for key, entry in self.l2_cache._cache.items()
if current_time - entry.created_at > self.l2_cache.ttl
]
for key in expired_keys:
entry = self.l2_cache._cache.pop(key, None)
if entry:
self.l2_cache._stats.evictions += 1
self.l2_cache._stats.item_count -= 1
self.l2_cache._stats.total_size -= entry.size
if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目")
except Exception as e:
logger.error(f"清理过期条目失败: {e}", exc_info=True)
@@ -568,7 +568,7 @@ _cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例)
从配置文件读取缓存参数,如果配置未加载则使用默认值
如果配置中禁用了缓存返回一个最小化的缓存实例容量为1
"""
@@ -580,9 +580,9 @@ async def get_cache() -> MultiLevelCache:
# 尝试从配置读取参数
try:
from src.config.config import global_config
db_config = global_config.database
# 检查是否启用缓存
if not db_config.enable_database_cache:
logger.info("数据库缓存已禁用,使用最小化缓存实例")
@@ -594,7 +594,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb=1,
)
return _global_cache
l1_max_size = db_config.cache_l1_max_size
l1_ttl = db_config.cache_l1_ttl
l2_max_size = db_config.cache_l2_max_size
@@ -602,7 +602,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = db_config.cache_max_memory_mb
max_item_size_mb = db_config.cache_max_item_size_mb
cleanup_interval = db_config.cache_cleanup_interval
logger.info(
f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), "
f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), "
@@ -618,7 +618,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = 100
max_item_size_mb = 1
cleanup_interval = 60
_global_cache = MultiLevelCache(
l1_max_size=l1_max_size,
l1_ttl=l1_ttl,

View File

@@ -4,73 +4,74 @@
提供比 sys.getsizeof() 更准确的内存占用估算方法
"""
import sys
import pickle
import sys
from typing import Any
import numpy as np
def get_accurate_size(obj: Any, seen: set | None = None) -> int:
"""
准确估算对象的内存大小(递归计算所有引用对象)
比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。
Args:
obj: 要估算大小的对象
seen: 已访问对象的集合(用于避免循环引用)
Returns:
估算的字节数
"""
if seen is None:
seen = set()
obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)
size = sys.getsizeof(obj)
# NumPy 数组特殊处理
if isinstance(obj, np.ndarray):
size += obj.nbytes
return size
# 字典:递归计算所有键值对
if isinstance(obj, dict):
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen)
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen)
for k, v in obj.items())
# 列表、元组、集合:递归计算所有元素
elif isinstance(obj, (list, tuple, set, frozenset)):
elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj)
# 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, '__dict__'):
elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen)
# 其他可迭代对象
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try:
size += sum(get_accurate_size(item, seen) for item in obj)
except Exception:  # 避免裸 except 连 KeyboardInterrupt 等信号一并吞掉
pass
return size
def get_pickle_size(obj: Any) -> int:
"""
使用 pickle 序列化大小作为参考
通常比 sys.getsizeof() 更接近实际内存占用,
但可能略小于真实内存占用(不包括 Python 对象开销)
Args:
obj: 要估算大小的对象
Returns:
pickle 序列化后的字节数,失败返回 0
"""
@@ -83,17 +84,17 @@ def get_pickle_size(obj: Any) -> int:
def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int:
"""
智能估算对象大小(平衡准确性和性能)
使用深度受限的递归估算+采样策略,平衡准确性和性能:
- 深度5层足以覆盖99%的缓存数据结构
- 对大型容器(>100项进行采样估算
- 性能开销约60倍于sys.getsizeof但准确度提升1000+倍
Args:
obj: 要估算大小的对象
max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构
sample_large: 对大型容器是否采样默认True提升性能
Returns:
估算的字节数
"""
@@ -105,24 +106,24 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 检查深度限制
if depth <= 0:
return sys.getsizeof(obj)
# 检查循环引用
obj_id = id(obj)
if obj_id in seen:
return 0
seen.add(obj_id)
# 基本大小
size = sys.getsizeof(obj)
# 简单类型直接返回
if isinstance(obj, (int, float, bool, type(None), str, bytes, bytearray)):
if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
return size
# NumPy 数组特殊处理
if isinstance(obj, np.ndarray):
return size + obj.nbytes
# 字典递归
if isinstance(obj, dict):
items = list(obj.items())
@@ -130,7 +131,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 大字典采样前50 + 中间50 + 最后50
sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:]
sampled_size = sum(
_estimate_recursive(k, depth - 1, seen, sample_large) +
_estimate_recursive(k, depth - 1, seen, sample_large) +
_estimate_recursive(v, depth - 1, seen, sample_large)
for k, v in sample_items
)
@@ -142,9 +143,9 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
size += _estimate_recursive(k, depth - 1, seen, sample_large)
size += _estimate_recursive(v, depth - 1, seen, sample_large)
return size
# 列表、元组、集合递归
if isinstance(obj, (list, tuple, set, frozenset)):
if isinstance(obj, list | tuple | set | frozenset):
items = list(obj)
if sample_large and len(items) > 100:
# 大容器采样前50 + 中间50 + 最后50
@@ -160,21 +161,21 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
for item in items:
size += _estimate_recursive(item, depth - 1, seen, sample_large)
return size
# 有 __dict__ 的对象
if hasattr(obj, '__dict__'):
if hasattr(obj, "__dict__"):
size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
return size
def format_size(size_bytes: int) -> str:
"""
格式化字节数为人类可读的格式
Args:
size_bytes: 字节数
Returns:
格式化后的字符串,如 "1.23 MB"
"""

View File

@@ -2,7 +2,6 @@ import os
import shutil
import sys
from datetime import datetime
from typing import Optional
import tomlkit
from pydantic import Field
@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置")
memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置")
memory: MemoryConfig | None = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")

View File

@@ -401,16 +401,16 @@ class MemoryConfig(ValidatedConfigBase):
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
# === 记忆图系统配置 (Memory Graph System) ===
# 新一代记忆系统的配置项
enable: bool = Field(default=True, description="启用记忆图系统")
data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录")
# 向量存储配置
vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称")
vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径")
# 检索配置
search_top_k: int = Field(default=10, description="默认检索返回数量")
search_min_importance: float = Field(default=0.3, description="最小重要性阈值")
@@ -418,13 +418,13 @@ class MemoryConfig(ValidatedConfigBase):
search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3")
search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)")
enable_query_optimization: bool = Field(default=True, description="启用查询优化")
# 检索权重配置 (记忆图系统)
search_vector_weight: float = Field(default=0.4, description="向量相似度权重")
search_graph_distance_weight: float = Field(default=0.2, description="图距离权重")
search_importance_weight: float = Field(default=0.2, description="重要性权重")
search_recency_weight: float = Field(default=0.2, description="时效性权重")
# 记忆整合配置
consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合")
consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)")
@@ -442,21 +442,21 @@ class MemoryConfig(ValidatedConfigBase):
consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值")
consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数")
consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度")
# 遗忘配置 (记忆图系统)
forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘")
forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值")
forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性")
# 激活配置
activation_decay_rate: float = Field(default=0.9, description="激活度衰减率")
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量")
# 节点去重合并配置
node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值")
node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配")
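The four retrieval weights above default to 0.4/0.2/0.2/0.2 and sum to 1.0, so the blended score stays in [0, 1] as long as each component is normalized. A sketch of the combination; the keys mirror the config fields, while the scoring function itself is assumed:

```python
WEIGHTS = {"vector": 0.4, "graph_distance": 0.2, "importance": 0.2, "recency": 0.2}
assert abs(sum(WEIGHTS.values()) - 1.0) < 1e-9

def combined_score(scores: dict[str, float]) -> float:
    """scores 的每个分量应已归一化到 [0, 1]"""
    return sum(WEIGHTS[k] * scores.get(k, 0.0) for k in WEIGHTS)

s = combined_score({"vector": 0.9, "graph_distance": 0.5, "importance": 0.8, "recency": 0.2})
assert 0.0 <= s <= 1.0 and abs(s - 0.66) < 1e-9
```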

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
model_name = model_info.name
retry_interval = api_provider.retry_interval
if isinstance(e, (NetworkConnectionError, ReqAbortException)):
if isinstance(e, NetworkConnectionError | ReqAbortException):
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException):
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)

View File

@@ -100,10 +100,10 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items():
if isinstance(value, (list, dict)):
if isinstance(value, list | dict):
import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value
else:
metadata[key] = str(value)
@@ -149,9 +149,9 @@ class VectorStore:
"created_at": n.created_at.isoformat(),
}
for key, value in n.metadata.items():
if isinstance(value, (list, dict)):
if isinstance(value, list | dict):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None:
elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value # type: ignore
else:
metadata[key] = str(value)
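Both hunks in this file apply the same sanitization before node metadata reaches the vector store: containers are serialized to JSON strings, JSON-native scalars pass through, and anything else is stringified (Chroma-style stores accept only scalar metadata, which is presumably the motivation). As a standalone helper:

```python
from typing import Any

import orjson

def sanitize_metadata(meta: dict[str, Any]) -> dict[str, Any]:
    out: dict[str, Any] = {}
    for key, value in meta.items():
        if isinstance(value, list | dict):
            # 容器 -> JSON 字符串(向量库元数据只接受标量)
            out[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        elif isinstance(value, str | int | float | bool) or value is None:
            out[key] = value
        else:
            out[key] = str(value)  # 兜底:其余类型转为字符串
    return out

clean = sanitize_metadata({"tags": ["a", "b"], "score": 0.5, "obj": object()})
assert clean["tags"] == '["a","b"]' and clean["score"] == 0.5 and isinstance(clean["obj"], str)
```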

View File

@@ -4,8 +4,6 @@
from __future__ import annotations
import asyncio
import numpy as np
from src.common.logger import get_logger
@@ -72,7 +70,7 @@ class EmbeddingGenerator:
logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
self._api_available = False
async def generate(self, text: str) -> np.ndarray | None:
"""
生成单个文本的嵌入向量
@@ -130,7 +128,7 @@ class EmbeddingGenerator:
logger.debug(f"API 嵌入生成失败: {e}")
return None
def _get_dimension(self) -> int:
"""获取嵌入维度"""
# 优先使用 API 维度

View File

@@ -7,11 +7,12 @@
"""
import atexit
import orjson
import os
import threading
from typing import Any, ClassVar
import orjson
from src.common.logger import get_logger
# 获取日志记录器
@@ -125,7 +126,7 @@ class PluginStorage:
try:
with open(self.file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e:

View File

@@ -5,12 +5,12 @@ MCP Client Manager
"""
import asyncio
import orjson
import shutil
from pathlib import Path
from typing import Any
import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
"""
import time
from typing import Any, Optional
from dataclasses import dataclass, asdict, field
from dataclasses import dataclass, field
from typing import Any
import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录"""
tool_name: str
args: dict[str, Any]
result: Optional[dict[str, Any]] = None
result: dict[str, Any] | None = None
status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒)
execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
content = self.result.get("content", "")
if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)):
elif isinstance(content, list | dict):
try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
except Exception:
self.result_preview = str(content)[:500] + "..."
else:
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果
Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None,
tool_file_path: Optional[str] = None,
ttl: Optional[int] = None) -> None:
execution_time: float | None = None,
tool_file_path: str | None = None,
ttl: int | None = None) -> None:
"""缓存工具调用结果
Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
"""获取最近的历史记录
Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""在内存历史记录中搜索缓存
Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
"""提取语义查询参数
Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return ""
try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length:
args_str = args_str[:max_length] + "..."
return args_str
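Serializing the arguments with `OPT_SORT_KEYS` makes the derived cache key independent of dict insertion order, so logically identical tool calls map to the same entry; the truncation merely bounds key length. Sketch:

```python
import orjson

def args_key(args: dict, max_length: int = 200) -> str:
    s = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
    return s[:max_length] + "..." if len(s) > max_length else s

# 键与参数的书写顺序无关
assert args_key({"q": "cats", "n": 5}) == args_key({"n": 5, "q": "cats"})
```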
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
"""
if chat_id in _stream_managers:
del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -1,5 +1,6 @@
import inspect
import time
from dataclasses import asdict
from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
from dataclasses import asdict
from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
logger = get_logger("tool_use")
@@ -140,7 +140,7 @@ class ToolExecutor:
# 构建工具调用历史文本
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息
personality_core = global_config.personality.personality_core
personality_side = global_config.personality.personality_side
@@ -197,7 +197,7 @@ class ToolExecutor:
return tool_definitions
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用
@@ -338,9 +338,8 @@ class ToolExecutor:
if tool_instance and result and tool_instance.enable_cache:
try:
tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
function_args.get(tool_instance.semantic_cache_query_key)
await self.history_manager.cache_result(
tool_name=tool_call.func_name,

View File

@@ -122,7 +122,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
+ relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"]
)
# 限制总分上限为1.0,确保分数在合理范围内
total_score = min(raw_total_score, 1.0)
@@ -131,7 +131,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
)
if raw_total_score > 1.0:
logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}")
@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return 0.0
except asyncio.TimeoutError:
logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}")
@@ -251,19 +251,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
"""计算提及分 - 区分强提及和弱提及
强提及(被@、被回复、私聊): 使用 strong_mention_interest_score
弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score
"""
from src.chat.utils.utils import is_mentioned_bot_in_message
# 使用统一的提及检测函数
is_mentioned, mention_type = is_mentioned_bot_in_message(message)
if not is_mentioned:
logger.debug("[提及分计算] 未提及机器人返回0.0")
return 0.0
# mention_type: 0=未提及, 1=弱提及, 2=强提及
if mention_type >= 2:
# 强提及:被@、被回复、私聊
@@ -281,22 +281,22 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]:
"""应用阈值调整(包括连续不回复和回复后降低机制)
Returns:
tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值)
"""
# 基础阈值
base_reply_threshold = self.reply_threshold
base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
total_reduction = 0.0
# 1. 连续不回复的阈值降低
if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count:
no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply
total_reduction += no_reply_reduction
logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})")
# 2. 回复后的阈值降低使bot更容易连续对话
if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0:
# 计算衰减后的降低值
@@ -309,16 +309,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} "
f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})"
)
# 应用总降低量
adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction)
adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction)
return adjusted_reply_threshold, adjusted_action_threshold
def _apply_no_reply_boost(self, base_score: float) -> float:
"""【已弃用】应用连续不回复的概率提升
注意:此方法已被 _apply_no_reply_threshold_adjustment 替代
保留用于向后兼容
"""
@@ -388,7 +388,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.no_reply_count = 0
else:
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
def on_reply_sent(self):
"""当机器人发送回复后调用,激活回复后阈值降低机制"""
if self.enable_post_reply_boost:
@@ -399,16 +399,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
)
# 同时重置不回复计数
self.no_reply_count = 0
def on_message_processed(self, replied: bool):
"""消息处理完成后调用,更新各种计数器
Args:
replied: 是否回复了此消息
"""
# 更新不回复计数
self.update_no_reply_count(replied)
# 如果已回复,激活回复后降低机制
if replied:
self.on_reply_sent()
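Two mechanisms pull the thresholds down in these hunks: a linear reduction per consecutive non-reply, plus a decaying post-reply reduction that makes follow-up replies easier. A compact model of the arithmetic; the parameter values and the exact decay form are assumptions:

```python
def adjusted_thresholds(
    base_reply: float,
    base_action: float,
    no_reply_count: int,
    boost_per_no_reply: float = 0.05,
    post_reply_remaining: int = 0,
    post_reply_reduction: float = 0.1,
    decay: float = 0.8,
) -> tuple[float, float]:
    reduction = no_reply_count * boost_per_no_reply
    if post_reply_remaining > 0:
        # 回复后降低:随剩余次数衰减(衰减形式为假设)
        reduction += post_reply_reduction * (decay ** (3 - post_reply_remaining))
    return max(0.0, base_reply - reduction), max(0.0, base_action - reduction)

reply_t, action_t = adjusted_thresholds(0.5, 0.4, no_reply_count=3, post_reply_remaining=2)
assert reply_t < 0.5 and action_t < 0.4 and reply_t >= 0.0
```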

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
包含计划生成、过滤、执行等规划相关功能
"""
from . import planner_prompts
from .plan_executor import ChatterPlanExecutor
from .plan_filter import ChatterPlanFilter
from .plan_generator import ChatterPlanGenerator
from .planner import ChatterActionPlanner
from . import planner_prompts
__all__ = ["ChatterActionPlanner", "planner_prompts", "ChatterPlanGenerator", "ChatterPlanFilter", "ChatterPlanExecutor"]
__all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
)
from src.chat.utils.prompt import global_prompt_manager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
@@ -646,7 +644,7 @@ class ChatterPlanFilter:
memory_manager = get_memory_manager()
if not memory_manager:
return "记忆系统未初始化。"
# 将关键词转换为查询字符串
query = " ".join(keywords)
enhanced_memories = await memory_manager.search_memories(

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts
logger = get_logger("planner")
@@ -159,10 +158,10 @@ class ChatterActionPlanner:
action_data={},
action_message=None,
)
# 更新连续不回复计数
await self._update_interest_calculator_state(replied=False)
initial_plan = await self.generator.generate(chat_mode)
filtered_plan = initial_plan
filtered_plan.decided_actions = [no_action]
@@ -270,7 +269,7 @@ class ChatterActionPlanner:
try:
# Normal模式开始时刷新缓存消息到未读列表
await self._flush_cached_messages_to_unread(context)
unread_messages = context.get_unread_messages() if context else []
if not unread_messages:
@@ -347,7 +346,7 @@ class ChatterActionPlanner:
self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成")
# 更新兴趣计算器状态(回复成功,重置不回复计数)
await self._update_interest_calculator_state(replied=True)
@@ -465,7 +464,7 @@ class ChatterActionPlanner:
async def _update_interest_calculator_state(self, replied: bool) -> None:
"""更新兴趣计算器状态(连续不回复计数和回复后降低机制)
Args:
replied: 是否回复了消息
"""
@@ -504,36 +503,36 @@ class ChatterActionPlanner:
async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list:
"""在planner开始时将缓存消息刷新到未读消息列表
此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。
Args:
context: 流上下文
Returns:
list: 刷新的消息列表
"""
if not context:
return []
try:
from src.chat.message_manager.message_manager import message_manager
stream_id = context.stream_id
if message_manager.is_running and message_manager.has_cached_messages(stream_id):
# 获取缓存消息
cached_messages = message_manager.flush_cached_messages(stream_id)
if cached_messages:
# 直接添加到上下文的未读消息列表
for message in cached_messages:
context.unread_messages.append(message)
logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}")
return cached_messages
return []
except ImportError:
logger.debug("MessageManager不可用跳过缓存刷新")
return []

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler
__all__ = [
"ProactiveThinkingReplyHandler",
"ProactiveThinkingMessageHandler",
"execute_proactive_thinking",
"ProactiveThinkingReplyHandler",
"ProactiveThinkingScheduler",
"execute_proactive_thinking",
"proactive_thinking_scheduler",
]

View File

@@ -3,7 +3,6 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
"""
import orjson
from datetime import datetime
from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api
@@ -320,7 +319,7 @@ class ContentService:
- 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。
2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。
**其他的禁止的内容以及说明**
- 绝对禁止提及当下具体几点几分的时间戳。
- 绝对禁止攻击性内容和过度的负面情绪。

View File

@@ -136,10 +136,10 @@ class QZoneService:
logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
api_client = await self._get_api_client(qq_account, stream_id)
if not api_client:
logger.error(f"[DEBUG] API客户端获取失败返回错误")
logger.error("[DEBUG] API客户端获取失败返回错误")
return {"success": False, "message": "获取QZone API客户端失败"}
logger.info(f"[DEBUG] API客户端获取成功准备读取说说")
logger.info("[DEBUG] API客户端获取成功准备读取说说")
num_to_read = self.get_config("read.read_number", 5)
# 尝试执行如果Cookie失效则自动重试一次
@@ -186,7 +186,7 @@ class QZoneService:
# 检查是否是Cookie失效-3000错误
if "错误码: -3000" in error_msg and retry_count == 0:
logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...")
logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")
# 删除Cookie缓存文件
cookie_file = self.cookie_service._get_cookie_file_path(qq_account)
@@ -623,7 +623,7 @@ class QZoneService:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None
logger.info(f"[DEBUG] p_skey获取成功")
logger.info("[DEBUG] p_skey获取成功")
gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o")
@@ -1230,7 +1230,7 @@ class QZoneService:
logger.error(f"监控好友动态失败: {e}", exc_info=True)
return []
logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典")
logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
return {
"publish": _publish,
"list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
负责记录和管理已回复过的评论ID避免重复回复
"""
import orjson
import time
from pathlib import Path
from typing import Any
import orjson
from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService")
@@ -117,8 +118,8 @@ class ReplyTrackerService:
temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f:
orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')
with open(temp_file, "w", encoding="utf-8") as f:
    f.write(orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
# 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功
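The save path is the classic atomic-write sequence: serialize to a sibling temp file, sanity-check the result, then rename over the target so readers never see a half-written file. A generic helper in the same spirit:

```python
import tempfile
from pathlib import Path

import orjson

def atomic_write_json(path: Path, data) -> None:
    tmp = path.with_suffix(".tmp")
    tmp.write_bytes(orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS))
    if tmp.stat().st_size == 0:  # 简单校验,对应 hunk 中的非空检查
        raise OSError("空的临时文件,放弃替换")
    tmp.replace(path)            # 同一文件系统内的原子替换

with tempfile.TemporaryDirectory() as d:
    p = Path(d) / "replied_comments.json"
    atomic_write_json(p, {"feed1": ["c1", "c2"]})
    assert orjson.loads(p.read_bytes())["feed1"] == ["c1", "c2"]
```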

View File

@@ -1,7 +1,6 @@
import orjson
import random
import time
import random
import websockets as Server
import uuid
from maim_message import (
@@ -205,7 +204,7 @@ class SendHandler:
# 发送响应回MoFox-Bot
logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
await self.send_adapter_command_response(raw_message_base, response, request_id)
logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
"""
Metaso Search Engine (Chat Completions Mode)
"""
import orjson
from typing import Any
import httpx
import orjson
from src.common.logger import get_logger
from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
Google Search via Serper.dev API
"""
import aiohttp
from typing import Any
import aiohttp
from src.common.logger import get_logger
from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
"""
from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
search_tasks.append(engine.answer_search(custom_args))
else:
search_tasks.append(engine.search(custom_args))
@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args)
else:
@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args)
else:

View File

@@ -266,13 +266,13 @@ class UnifiedScheduler:
name=f"execute_{task.task_name}"
)
execution_tasks.append(execution_task)
# 追踪正在执行的任务,以便在 remove_schedule 时可以取消
self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务)
results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪
for task in tasks_to_trigger:
self._executing_tasks.pop(task.schedule_id, None)
@@ -515,7 +515,7 @@ class UnifiedScheduler:
async def remove_schedule(self, schedule_id: str) -> bool:
"""移除调度任务
如果任务正在执行,会取消执行中的任务
"""
async with self._lock:
@@ -524,7 +524,7 @@ class UnifiedScheduler:
return False
task = self._tasks[schedule_id]
# 检查是否有正在执行的任务
executing_task = self._executing_tasks.get(schedule_id)
if executing_task and not executing_task.done():
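Holding the asyncio.Task handles in `_executing_tasks` is what allows `remove_schedule` to cancel a job mid-flight, while `gather(..., return_exceptions=True)` keeps one failing or cancelled task from tearing down the whole batch. A minimal model:

```python
import asyncio

executing: dict[str, asyncio.Task] = {}

async def job(schedule_id: str, delay: float) -> str:
    await asyncio.sleep(delay)
    return schedule_id

async def main() -> None:
    executing["a"] = asyncio.create_task(job("a", 0.01), name="execute_a")
    executing["b"] = asyncio.create_task(job("b", 10.0), name="execute_b")
    executing["b"].cancel()  # remove_schedule 的关键:可取消执行中的任务
    results = await asyncio.gather(*executing.values(), return_exceptions=True)
    # 单个任务被取消不影响其他任务的结果收集
    assert results[0] == "a" and isinstance(results[1], asyncio.CancelledError)
    executing.clear()        # 对应 hunk 中的“清理执行追踪”

asyncio.run(main())
```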

View File

@@ -19,42 +19,42 @@ logger = get_logger(__name__)
def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None:
"""
从 LLM 响应中提取并解析 JSON
处理策略:
1. 清理 Markdown 代码块标记(```json 和 ```
2. 提取 JSON 对象或数组
3. 使用 json_repair 修复格式问题
4. 解析为 Python 对象
Args:
response: LLM 响应字符串
strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理
Returns:
解析后的 dict 或 list失败时返回 None
Examples:
>>> extract_and_parse_json('```json\\n{"key": "value"}\\n```')
{'key': 'value'}
>>> extract_and_parse_json('Some text {"key": "value"} more text')
{'key': 'value'}
>>> extract_and_parse_json('[{"a": 1}, {"b": 2}]')
[{'a': 1}, {'b': 2}]
"""
if not response:
logger.debug("空响应,无法解析 JSON")
return None
try:
# 步骤 1: 清理响应
cleaned = _clean_llm_response(response)
if not cleaned:
logger.warning("清理后的响应为空")
return None
# 步骤 2: 尝试直接解析
try:
result = orjson.loads(cleaned)
@@ -62,11 +62,11 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
return result
except Exception as direct_error:
logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}")
# 步骤 3: 使用 json_repair 修复并解析
try:
repaired = repair_json(cleaned)
# repair_json 可能返回字符串或已解析的对象
if isinstance(repaired, str):
result = orjson.loads(repaired)
@@ -74,16 +74,16 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else:
result = repaired
logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}")
return result
except Exception as repair_error:
logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}")
if strict:
logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}")
return None
# 最后的容错尝试:返回空字典或空列表
if cleaned.strip().startswith("["):
logger.warning("返回空列表作为容错")
@@ -91,7 +91,7 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else:
logger.warning("返回空字典作为容错")
return {}
except Exception as e:
logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}")
if strict:
@@ -102,37 +102,37 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
def _clean_llm_response(response: str) -> str:
"""
清理 LLM 响应,提取 JSON 部分
处理步骤:
1. 移除 Markdown 代码块标记(```json 和 ```
2. 提取第一个完整的 JSON 对象 {...} 或数组 [...]
3. 清理多余的空格和换行
Args:
response: 原始 LLM 响应
Returns:
清理后的 JSON 字符串
"""
if not response:
return ""
cleaned = response.strip()
# 移除 Markdown 代码块标记
# 匹配 ```json ... ``` 或 ``` ... ```
code_block_patterns = [
r"```json\s*(.*?)```", # ```json ... ```
r"```\s*(.*?)```", # ``` ... ```
]
for pattern in code_block_patterns:
match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL)
if match:
cleaned = match.group(1).strip()
logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}")
break
# 提取 JSON 对象或数组
# 优先查找对象 {...},其次查找数组 [...]
for start_char, end_char in [("{", "}"), ("[", "]")]:
@@ -143,7 +143,7 @@ def _clean_llm_response(response: str) -> str:
if extracted:
logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}")
return extracted
# 如果没有找到明确的 JSON 结构,返回清理后的原始内容
logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容")
return cleaned
@@ -152,39 +152,39 @@ def _clean_llm_response(response: str) -> str:
def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None:
"""
从指定位置提取平衡的 JSON 结构
使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符
Args:
text: 源文本
start_idx: 起始字符的索引
start_char: 起始字符({ 或 [
end_char: 结束字符(} 或 ]
Returns:
提取的 JSON 字符串,失败时返回 None
"""
depth = 0
in_string = False
escape_next = False
for i in range(start_idx, len(text)):
char = text[i]
# 处理转义字符
if escape_next:
escape_next = False
continue
if char == "\\":
escape_next = True
continue
# 处理字符串
if char == '"':
in_string = not in_string
continue
# 只在非字符串内处理括号
if not in_string:
if char == start_char:
@@ -194,7 +194,7 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
if depth == 0:
# 找到匹配的结束符
return text[start_idx : i + 1].strip()
# 没有找到匹配的结束符
logger.debug(f"未找到匹配的 {end_char},深度: {depth}")
return None
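The two-stage parse in this file (strict `orjson.loads` first, `repair_json` as fallback) is what makes the helper tolerant of typical LLM output, and `repair_json` may return either a fixed string or an already-parsed object, hence the isinstance branch. A condensed version of that fallback:

```python
import orjson
from json_repair import repair_json

def lenient_loads(text: str):
    try:
        return orjson.loads(text)      # 快路径:本来就是合法 JSON
    except Exception:
        repaired = repair_json(text)   # 慢路径:修复常见格式问题
        return orjson.loads(repaired) if isinstance(repaired, str) else repaired

# 尾逗号是 LLM 输出里常见的毛病
assert lenient_loads('{"a": 1, "b": [2, 3,],}') == {"a": 1, "b": [2, 3]}
```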
@@ -203,11 +203,11 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
def safe_parse_json(json_str: str, default: Any = None) -> Any:
"""
安全解析 JSON失败时返回默认值
Args:
json_str: JSON 字符串
default: 解析失败时返回的默认值
Returns:
解析结果或默认值
"""
@@ -222,19 +222,19 @@ def safe_parse_json(json_str: str, default: Any = None) -> Any:
def extract_json_field(response: str, field_name: str, default: Any = None) -> Any:
"""
从 LLM 响应中提取特定字段的值
Args:
response: LLM 响应
field_name: 字段名
default: 字段不存在时的默认值
Returns:
字段值或默认值
"""
parsed = extract_and_parse_json(response, strict=False)
if isinstance(parsed, dict):
return parsed.get(field_name, default)
logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'")
return default

View File

@@ -14,7 +14,7 @@ sys.path.insert(0, str(project_root))
from tools.memory_visualizer.visualizer_server import run_server
if __name__ == '__main__':
if __name__ == "__main__":
print("=" * 60)
print("🦊 MoFox Bot - 记忆图可视化工具")
print("=" * 60)
@@ -24,10 +24,10 @@ if __name__ == '__main__':
print("⏹️ 按 Ctrl+C 停止服务器")
print()
print("=" * 60)
try:
run_server(
host='127.0.0.1',
host="127.0.0.1",
port=5000,
debug=True
)

View File

@@ -15,7 +15,7 @@ from pathlib import Path
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
if __name__ == '__main__':
if __name__ == "__main__":
print("=" * 70)
print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
print("=" * 70)
@@ -26,10 +26,10 @@ if __name__ == '__main__':
print(" • 快速启动,无需完整初始化")
print()
print("=" * 70)
try:
from tools.memory_visualizer.visualizer_simple import run_server
run_server(host='127.0.0.1', port=5001, debug=True)
run_server(host="127.0.0.1", port=5001, debug=True)
except KeyboardInterrupt:
print("\n\n👋 服务器已停止")
except Exception as e:

View File

@@ -11,7 +11,6 @@ import logging
import sys
from datetime import datetime
from pathlib import Path
from typing import Optional
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS
@@ -28,7 +27,7 @@ app = Flask(__name__)
CORS(app) # 允许跨域请求
# 全局记忆管理器
memory_manager: Optional[MemoryManager] = None
memory_manager: MemoryManager | None = None
def init_memory_manager():
@@ -189,7 +188,7 @@ def search_memories():
init_memory_manager()
query = request.args.get("q", "")
memory_type = request.args.get("type", None)
request.args.get("type", None)
limit = int(request.args.get("limit", 50))
loop = asyncio.new_event_loop()

View File

@@ -4,20 +4,18 @@
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
"""
import orjson
import sys
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Set
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
from typing import Any
import orjson
# 添加项目根目录
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from flask import Flask, jsonify, render_template_string, request, send_from_directory
from flask import Flask, jsonify, render_template_string, request
from flask_cors import CORS
app = Flask(__name__)
@@ -29,38 +27,38 @@ data_dir = project_root / "data" / "memory_graph"
current_data_file = None # 当前选择的数据文件
def find_available_data_files() -> List[Path]:
def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
return files
# 查找多种可能的文件名
possible_files = [
"graph_store.json",
"memory_graph.json",
"graph_data.json",
]
for filename in possible_files:
file_path = data_dir / filename
if file_path.exists():
files.append(file_path)
# 查找所有备份文件
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
# 查找backups子目录
backups_dir = data_dir / "backups"
if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files:
files.append(backup_file)
# 查找data/backup目录
backup_dir = data_dir.parent / "backup"
if backup_dir.exists():
@@ -70,22 +68,22 @@ def find_available_data_files() -> List[Path]:
for backup_file in backup_dir.glob("**/memory_*.json"):
if backup_file not in files:
files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
def load_graph_data(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
# 如果指定了新文件,清除缓存
if file_path is not None and file_path != current_data_file:
graph_data_cache = None
current_data_file = file_path
if graph_data_cache is not None:
return graph_data_cache
try:
# 确定要加载的文件
if current_data_file is not None:
@@ -94,115 +92,115 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
# 尝试查找可用的数据文件
available_files = find_available_data_files()
if not available_files:
print(f"⚠️ 未找到任何图数据文件")
print("⚠️ 未找到任何图数据文件")
print(f"📂 搜索目录: {data_dir}")
return {
"nodes": [],
"edges": [],
"nodes": [],
"edges": [],
"memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": "未找到数据文件",
"available_files": []
}
# 使用最新的文件
graph_file = available_files[0]
current_data_file = graph_file
print(f"📂 自动选择最新文件: {graph_file}")
if not graph_file.exists():
print(f"⚠️ 图数据文件不存在: {graph_file}")
return {
"nodes": [],
"edges": [],
"nodes": [],
"edges": [],
"memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": f"文件不存在: {graph_file}"
}
print(f"📂 加载图数据: {graph_file}")
with open(graph_file, 'r', encoding='utf-8') as f:
with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read())
# 解析数据
nodes_dict = {}
edges_list = []
memory_info = []
# 实际文件格式是 {nodes: [], edges: [], metadata: {}}
# 不是 {memories: [{nodes: [], edges: []}]}
nodes = data.get("nodes", [])
edges = data.get("edges", [])
metadata = data.get("metadata", {})
print(f"✅ 找到 {len(nodes)} 个节点, {len(edges)} 条边")
# 处理节点
for node in nodes:
node_id = node.get('id', '')
node_id = node.get("id", "")
if node_id and node_id not in nodes_dict:
memory_ids = node.get('metadata', {}).get('memory_ids', [])
memory_ids = node.get("metadata", {}).get("memory_ids", [])
nodes_dict[node_id] = {
'id': node_id,
'label': node.get('content', ''),
'type': node.get('node_type', ''),
'group': extract_group_from_type(node.get('node_type', '')),
'title': f"{node.get('node_type', '')}: {node.get('content', '')}",
'metadata': node.get('metadata', {}),
'created_at': node.get('created_at', ''),
'memory_ids': memory_ids,
"id": node_id,
"label": node.get("content", ""),
"type": node.get("node_type", ""),
"group": extract_group_from_type(node.get("node_type", "")),
"title": f"{node.get('node_type', '')}: {node.get('content', '')}",
"metadata": node.get("metadata", {}),
"created_at": node.get("created_at", ""),
"memory_ids": memory_ids,
}
# 处理边 - 使用集合去重避免重复的边ID
existing_edge_ids = set()
for edge in edges:
# 边的ID字段可能是 'id' 或 'edge_id'
edge_id = edge.get('edge_id') or edge.get('id', '')
edge_id = edge.get("edge_id") or edge.get("id", "")
# 如果ID为空或已存在跳过这条边
if not edge_id or edge_id in existing_edge_ids:
continue
existing_edge_ids.add(edge_id)
memory_id = edge.get('metadata', {}).get('memory_id', '')
memory_id = edge.get("metadata", {}).get("memory_id", "")
# 注意: GraphStore 保存的格式使用 'source'/'target', 不是 'source_id'/'target_id'
edges_list.append({
'id': edge_id,
'from': edge.get('source', edge.get('source_id', '')),
'to': edge.get('target', edge.get('target_id', '')),
'label': edge.get('relation', ''),
'type': edge.get('edge_type', ''),
'importance': edge.get('importance', 0.5),
'title': f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
'arrows': 'to',
'memory_id': memory_id,
"id": edge_id,
"from": edge.get("source", edge.get("source_id", "")),
"to": edge.get("target", edge.get("target_id", "")),
"label": edge.get("relation", ""),
"type": edge.get("edge_type", ""),
"importance": edge.get("importance", 0.5),
"title": f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
"arrows": "to",
"memory_id": memory_id,
})
# 从元数据中获取统计信息
stats = metadata.get('statistics', {})
total_memories = stats.get('total_memories', 0)
stats = metadata.get("statistics", {})
total_memories = stats.get("total_memories", 0)
# TODO: 如果需要记忆详细信息,需要从其他地方加载
# 目前只有节点和边的数据
graph_data_cache = {
'nodes': list(nodes_dict.values()),
'edges': edges_list,
'memories': memory_info, # 空列表,因为文件中没有记忆详情
'stats': {
'total_nodes': len(nodes_dict),
'total_edges': len(edges_list),
'total_memories': total_memories,
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": memory_info, # 空列表,因为文件中没有记忆详情
"stats": {
"total_nodes": len(nodes_dict),
"total_edges": len(edges_list),
"total_memories": total_memories,
},
'current_file': str(graph_file),
'file_size': graph_file.stat().st_size,
'file_modified': datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
"current_file": str(graph_file),
"file_size": graph_file.stat().st_size,
"file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
print(f"📊 统计: {len(nodes_dict)} 个节点, {len(edges_list)} 条边, {total_memories} 条记忆")
print(f"📄 数据文件: {graph_file} ({graph_file.stat().st_size / 1024:.2f} KB)")
return graph_data_cache
except Exception as e:
print(f"❌ 加载失败: {e}")
import traceback
@@ -214,246 +212,246 @@ def extract_group_from_type(node_type: str) -> str:
"""从节点类型提取分组名"""
# 假设类型格式为 "主体" 或 "SUBJECT"
type_mapping = {
'主体': 'SUBJECT',
'主题': 'TOPIC',
'客体': 'OBJECT',
'属性': 'ATTRIBUTE',
'': 'VALUE',
"主体": "SUBJECT",
"主题": "TOPIC",
"客体": "OBJECT",
"属性": "ATTRIBUTE",
"": "VALUE",
}
return type_mapping.get(node_type, node_type)
def generate_memory_text(memory: Dict[str, Any]) -> str:
def generate_memory_text(memory: dict[str, Any]) -> str:
"""生成记忆的文本描述"""
try:
nodes = {n['id']: n for n in memory.get('nodes', [])}
edges = memory.get('edges', [])
subject_id = memory.get('subject_id', '')
nodes = {n["id"]: n for n in memory.get("nodes", [])}
edges = memory.get("edges", [])
subject_id = memory.get("subject_id", "")
if not subject_id or subject_id not in nodes:
return f"[记忆 {memory.get('id', '')[:8]}]"
parts = [nodes[subject_id]['content']]
parts = [nodes[subject_id]["content"]]
# 找主题节点
for edge in edges:
if edge.get('edge_type') == '记忆类型' and edge.get('source_id') == subject_id:
topic_id = edge.get('target_id', '')
if edge.get("edge_type") == "记忆类型" and edge.get("source_id") == subject_id:
topic_id = edge.get("target_id", "")
if topic_id in nodes:
parts.append(nodes[topic_id]['content'])
parts.append(nodes[topic_id]["content"])
# 找客体
for e2 in edges:
if e2.get('edge_type') == '核心关系' and e2.get('source_id') == topic_id:
obj_id = e2.get('target_id', '')
if e2.get("edge_type") == "核心关系" and e2.get("source_id") == topic_id:
obj_id = e2.get("target_id", "")
if obj_id in nodes:
parts.append(f"{e2.get('relation', '')} {nodes[obj_id]['content']}")
break
break
return " ".join(parts)
except Exception:
return f"[记忆 {memory.get('id', '')[:8]}]"
# 使用内嵌的HTML模板(与之前相同)
HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", 'r', encoding='utf-8').read()
HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", encoding="utf-8").read()
@app.route('/')
@app.route("/")
def index():
"""主页面"""
return render_template_string(HTML_TEMPLATE)
@app.route('/api/graph/full')
@app.route("/api/graph/full")
def get_full_graph():
"""获取完整记忆图数据"""
try:
data = load_graph_data()
return jsonify({
'success': True,
'data': data
"success": True,
"data": data
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
@app.route('/api/memory/<memory_id>')
@app.route("/api/memory/<memory_id>")
def get_memory_detail(memory_id: str):
"""获取记忆详情"""
try:
data = load_graph_data()
memory = next((m for m in data['memories'] if m['id'] == memory_id), None)
memory = next((m for m in data["memories"] if m["id"] == memory_id), None)
if memory is None:
return jsonify({
'success': False,
'error': '记忆不存在'
"success": False,
"error": "记忆不存在"
}), 404
return jsonify({
'success': True,
'data': memory
"success": True,
"data": memory
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
@app.route('/api/search')
@app.route("/api/search")
def search_memories():
"""搜索记忆"""
try:
query = request.args.get('q', '').lower()
limit = int(request.args.get('limit', 50))
query = request.args.get("q", "").lower()
limit = int(request.args.get("limit", 50))
data = load_graph_data()
# 简单的文本匹配搜索
results = []
for memory in data['memories']:
text = memory.get('text', '').lower()
for memory in data["memories"]:
text = memory.get("text", "").lower()
if query in text:
results.append(memory)
return jsonify({
'success': True,
'data': {
'results': results[:limit],
'count': len(results),
"success": True,
"data": {
"results": results[:limit],
"count": len(results),
}
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
@app.route('/api/stats')
@app.route("/api/stats")
def get_statistics():
"""获取统计信息"""
try:
data = load_graph_data()
# 扩展统计信息
node_types = {}
memory_types = {}
for node in data['nodes']:
node_type = node.get('type', 'Unknown')
for node in data["nodes"]:
node_type = node.get("type", "Unknown")
node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data['memories']:
mem_type = memory.get('type', 'Unknown')
for memory in data["memories"]:
mem_type = memory.get("type", "Unknown")
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get('stats', {})
stats['node_types'] = node_types
stats['memory_types'] = memory_types
stats = data.get("stats", {})
stats["node_types"] = node_types
stats["memory_types"] = memory_types
return jsonify({
'success': True,
'data': stats
"success": True,
"data": stats
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
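The two counting loops above are hand-rolled frequency counts; an equivalent, more idiomatic sketch (not a change this commit makes) would use collections.Counter, which jsonify serializes like any dict:

    from collections import Counter

    node_types = Counter(n.get("type", "Unknown") for n in data["nodes"])
    memory_types = Counter(m.get("type", "Unknown") for m in data["memories"])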
@app.route('/api/reload')
@app.route("/api/reload")
def reload_data():
"""重新加载数据"""
global graph_data_cache
graph_data_cache = None
data = load_graph_data()
return jsonify({
'success': True,
'message': '数据已重新加载',
'stats': data.get('stats', {})
"success": True,
"message": "数据已重新加载",
"stats": data.get("stats", {})
})
@app.route('/api/files')
@app.route("/api/files")
def list_files():
"""列出所有可用的数据文件"""
try:
files = find_available_data_files()
file_list = []
for f in files:
stat = f.stat()
file_list.append({
'path': str(f),
'name': f.name,
'size': stat.st_size,
'size_kb': round(stat.st_size / 1024, 2),
'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
'modified_readable': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S'),
'is_current': str(f) == str(current_data_file) if current_data_file else False
"path": str(f),
"name": f.name,
"size": stat.st_size,
"size_kb": round(stat.st_size / 1024, 2),
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
"modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
"is_current": str(f) == str(current_data_file) if current_data_file else False
})
return jsonify({
'success': True,
'files': file_list,
'count': len(file_list),
'current_file': str(current_data_file) if current_data_file else None
"success": True,
"files": file_list,
"count": len(file_list),
"current_file": str(current_data_file) if current_data_file else None
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
@app.route('/api/select_file', methods=['POST'])
@app.route("/api/select_file", methods=["POST"])
def select_file():
"""选择要加载的数据文件"""
global graph_data_cache, current_data_file
try:
data = request.get_json()
file_path = data.get('file_path')
file_path = data.get("file_path")
if not file_path:
return jsonify({
'success': False,
'error': '未提供文件路径'
"success": False,
"error": "未提供文件路径"
}), 400
file_path = Path(file_path)
if not file_path.exists():
return jsonify({
'success': False,
'error': f'文件不存在: {file_path}'
"success": False,
"error": f"文件不存在: {file_path}"
}), 404
# Clear the cache and load the new file
graph_data_cache = None
current_data_file = file_path
graph_data = load_graph_data(file_path)
return jsonify({
'success': True,
'message': f'已切换到文件: {file_path.name}',
'stats': graph_data.get('stats', {})
"success": True,
"message": f"已切换到文件: {file_path.name}",
"stats": graph_data.get("stats", {})
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
"success": False,
"error": str(e)
}), 500
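Together, /api/files and /api/select_file let a client discover the available data files and point the server at one of them. A hedged end-to-end sketch (requests assumed; the "pick the newest file" policy is illustrative, not something the server enforces):

    import requests

    BASE = "http://127.0.0.1:5001"
    files = requests.get(f"{BASE}/api/files").json()["files"]
    if files:
        # ISO-8601 timestamps sort lexicographically, so max() picks the newest file
        newest = max(files, key=lambda f: f["modified"])
        resp = requests.post(f"{BASE}/api/select_file", json={"file_path": newest["path"]})
        print(resp.json().get("message") or resp.json().get("error"))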
def run_server(host: str = '127.0.0.1', port: int = 5001, debug: bool = False):
def run_server(host: str = "127.0.0.1", port: int = 5001, debug: bool = False):
"""启动服务器"""
print("=" * 60)
print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
@@ -463,14 +461,14 @@ def run_server(host: str = '127.0.0.1', port: int = 5001, debug: bool = False):
print("⏹️ 按 Ctrl+C 停止服务器")
print("=" * 60)
print()
# Preload the data
load_graph_data()
app.run(host=host, port=port, debug=debug)
if __name__ == '__main__':
if __name__ == "__main__":
try:
run_server(debug=True)
except KeyboardInterrupt: