Commit by Windpicker-owo, 2025-11-07 21:16:45 +08:00
69 changed files with 1061 additions and 3246 deletions

View File

@@ -17,19 +17,19 @@
import argparse import argparse
import json import json
from pathlib import Path
from typing import Dict, Any, List, Tuple
import logging import logging
from pathlib import Path
from typing import Any
import orjson import orjson
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[ handlers=[
logging.StreamHandler(), logging.StreamHandler(),
logging.FileHandler('embedding_cleanup.log', encoding='utf-8') logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
self.cleaned_files = [] self.cleaned_files = []
self.errors = [] self.errors = []
self.stats = { self.stats = {
'files_processed': 0, "files_processed": 0,
'embedings_removed': 0, "embedings_removed": 0,
'bytes_saved': 0, "bytes_saved": 0,
'nodes_processed': 0 "nodes_processed": 0
} }
def find_json_files(self) -> List[Path]: def find_json_files(self) -> list[Path]:
"""查找可能包含向量数据的 JSON 文件""" """查找可能包含向量数据的 JSON 文件"""
json_files = [] json_files = []
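A recurring change in this commit replaces typing-module generics (Dict, List, Tuple, Optional) with the built-in generics of PEP 585 and the union syntax of PEP 604. A minimal before/after sketch on a hypothetical function (assumes Python 3.10+ for the `X | None` form):

    # Before: typing-module generics
    from typing import Dict, List, Optional

    def summarize(counts: Dict[str, int], names: Optional[List[str]] = None) -> List[str]:
        names = names or list(counts)
        return [f"{name}: {counts.get(name, 0)}" for name in names]

    # After: built-in generics and union syntax, as used throughout this commit
    def summarize_builtin(counts: dict[str, int], names: list[str] | None = None) -> list[str]:
        names = names or list(counts)
        return [f"{name}: {counts.get(name, 0)}" for name in names]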
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
json_files.append(memory_graph_file) json_files.append(memory_graph_file)
# 测试数据文件 # 测试数据文件
test_dir = self.data_dir / "test_*" self.data_dir / "test_*"
for test_path in self.data_dir.glob("test_*/memory_graph.json"): for test_path in self.data_dir.glob("test_*/memory_graph.json"):
if test_path.exists(): if test_path.exists():
json_files.append(test_path) json_files.append(test_path)
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件") logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
return json_files return json_files
def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int: def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
""" """
分析数据中的 embedding 字段数量 分析数据中的 embedding 字段数量
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
def count_embeddings(obj): def count_embeddings(obj):
nonlocal embedding_count nonlocal embedding_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
embedding_count += 1 embedding_count += 1
for value in obj.values(): for value in obj.values():
count_embeddings(value) count_embeddings(value)
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
count_embeddings(data) count_embeddings(data)
return embedding_count return embedding_count
def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
""" """
从数据中移除 embedding 字段 从数据中移除 embedding 字段
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
def remove_embeddings(obj): def remove_embeddings(obj):
nonlocal removed_count nonlocal removed_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
del obj['embedding'] del obj["embedding"]
removed_count += 1 removed_count += 1
for value in obj.values(): for value in obj.values():
remove_embeddings(value) remove_embeddings(value)
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
data = orjson.loads(original_content) data = orjson.loads(original_content)
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
# 回退到标准 json # 回退到标准 json
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# 分析 embedding 数据 # 分析 embedding 数据
embedding_count = self.analyze_embedding_in_data(data) embedding_count = self.analyze_embedding_in_data(data)
if embedding_count == 0: if embedding_count == 0:
logger.info(f" ✓ 文件中没有 embedding 数据,跳过") logger.info(" ✓ 文件中没有 embedding 数据,跳过")
return True return True
logger.info(f" 发现 {embedding_count} 个 embedding 字段") logger.info(f" 发现 {embedding_count} 个 embedding 字段")
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
cleaned_data, cleaned_data,
indent=2, indent=2,
ensure_ascii=False ensure_ascii=False
).encode('utf-8') ).encode("utf-8")
cleaned_size = len(cleaned_content) cleaned_size = len(cleaned_content)
bytes_saved = original_size - cleaned_size bytes_saved = original_size - cleaned_size
# 原子写入 # 原子写入
temp_file = file_path.with_suffix('.tmp') temp_file = file_path.with_suffix(".tmp")
temp_file.write_bytes(cleaned_content) temp_file.write_bytes(cleaned_content)
temp_file.replace(file_path) temp_file.replace(file_path)
logger.info(f" ✓ 清理完成:") logger.info(" ✓ 清理完成:")
logger.info(f" - 移除 embedding 字段: {removed_count}") logger.info(f" - 移除 embedding 字段: {removed_count}")
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)") logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
logger.info(f" - 新文件大小: {cleaned_size:,} 字节") logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
# 更新统计 # 更新统计
self.stats['embedings_removed'] += removed_count self.stats["embedings_removed"] += removed_count
self.stats['bytes_saved'] += bytes_saved self.stats["bytes_saved"] += bytes_saved
else: else:
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段") logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
self.stats['embedings_removed'] += embedding_count self.stats["embedings_removed"] += embedding_count
self.stats['files_processed'] += 1 self.stats["files_processed"] += 1
self.cleaned_files.append(file_path) self.cleaned_files.append(file_path)
return True return True
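The write path above serializes to a sibling .tmp file and then replaces the target, so a crash mid-write never leaves a truncated memory_graph.json. A minimal sketch of that pattern (helper name is illustrative):

    import json
    from pathlib import Path

    def write_json_atomic(path: Path, data: dict) -> int:
        content = json.dumps(data, indent=2, ensure_ascii=False).encode("utf-8")
        temp_file = path.with_suffix(".tmp")
        temp_file.write_bytes(content)
        # Path.replace() uses os.replace(), which is atomic on both POSIX and Windows
        temp_file.replace(path)
        return len(content)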
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
节点数量 节点数量
""" """
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
node_count = 0 node_count = 0
if 'nodes' in data and isinstance(data['nodes'], list): if "nodes" in data and isinstance(data["nodes"], list):
node_count = len(data['nodes']) node_count = len(data["nodes"])
return node_count return node_count
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
# 统计总节点数 # 统计总节点数
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files) total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
self.stats['nodes_processed'] = total_nodes self.stats["nodes_processed"] = total_nodes
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点") logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
if not dry_run: if not dry_run:
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节") logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
if self.stats['bytes_saved'] > 0: if self.stats["bytes_saved"] > 0:
mb_saved = self.stats['bytes_saved'] / 1024 / 1024 mb_saved = self.stats["bytes_saved"] / 1024 / 1024
logger.info(f"节省空间: {mb_saved:.2f} MB") logger.info(f"节省空间: {mb_saved:.2f} MB")
if self.errors: if self.errors:
@@ -342,7 +342,7 @@ def main():
print(" 请确保向量数据库正在正常工作。") print(" 请确保向量数据库正在正常工作。")
print() print()
response = input("确认继续?(yes/no): ") response = input("确认继续?(yes/no): ")
if response.lower() not in ['yes', 'y', '']: if response.lower() not in ["yes", "y", ""]:
print("操作已取消") print("操作已取消")
return return
@@ -352,4 +352,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
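For reference, the core traversal of the cleanup script in self-contained form. The original uses a nested closure with a nonlocal counter; this sketch returns the count instead, which is easier to test:

    from typing import Any

    def remove_embeddings(obj: Any) -> int:
        """Recursively delete 'embedding' keys from nested dicts/lists; returns how many were removed."""
        removed = 0
        if isinstance(obj, dict):
            if "embedding" in obj:
                del obj["embedding"]
                removed += 1
            for value in obj.values():
                removed += remove_embeddings(value)
        elif isinstance(obj, list):
            for item in obj:
                removed += remove_embeddings(item)
        return removed

    data = {"nodes": [{"id": "n1", "embedding": [0.1, 0.2]}, {"id": "n2"}]}
    assert remove_embeddings(data) == 1 and "embedding" not in data["nodes"][0]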

View File

@@ -71,9 +71,10 @@ lunar_python
fuzzywuzzy fuzzywuzzy
python-multipart python-multipart
aiofiles aiofiles
jinja2
inkfox inkfox
soundfile soundfile
pedalboard pedalboard
# For local speech-to-text functionality (stt_whisper_plugin) # For local speech-to-text functionality (stt_whisper_plugin)
openai-whisper openai-whisper

View File

@@ -10,10 +10,10 @@
示例: 示例:
# 进程监控(启动 bot 并监控) # 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10 python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计) # 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表 # 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15
""" """
@@ -22,7 +22,6 @@ import argparse
import asyncio import asyncio
import gc import gc
import json import json
import os
import subprocess import subprocess
import sys import sys
import threading import threading
@@ -30,7 +29,6 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional
import psutil import psutil
@@ -56,29 +54,29 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
if bot_process.pid is None: if bot_process.pid is None:
print("❌ Bot 进程 PID 为空") print("❌ Bot 进程 PID 为空")
return return
print(f"🔍 开始监控 Bot 内存PID: {bot_process.pid}") print(f"🔍 开始监控 Bot 内存PID: {bot_process.pid}")
print(f"监控间隔: {interval}") print(f"监控间隔: {interval}")
print("按 Ctrl+C 停止监控和 Bot\n") print("按 Ctrl+C 停止监控和 Bot\n")
try: try:
process = psutil.Process(bot_process.pid) process = psutil.Process(bot_process.pid)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
print("❌ 无法找到 Bot 进程") print("❌ 无法找到 Bot 进程")
return return
history = [] history = []
iteration = 0 iteration = 0
try: try:
while bot_process.poll() is None: while bot_process.poll() is None:
try: try:
mem_info = process.memory_info() mem_info = process.memory_info()
mem_percent = process.memory_percent() mem_percent = process.memory_percent()
children = process.children(recursive=True) children = process.children(recursive=True)
children_mem = sum(child.memory_info().rss for child in children) children_mem = sum(child.memory_info().rss for child in children)
info = { info = {
"timestamp": time.strftime("%H:%M:%S"), "timestamp": time.strftime("%H:%M:%S"),
"rss_mb": mem_info.rss / 1024 / 1024, "rss_mb": mem_info.rss / 1024 / 1024,
@@ -87,24 +85,24 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
"children_count": len(children), "children_count": len(children),
"children_mem_mb": children_mem / 1024 / 1024, "children_mem_mb": children_mem / 1024 / 1024,
} }
history.append(info) history.append(info)
iteration += 1 iteration += 1
print(f"{'=' * 80}") print(f"{'=' * 80}")
print(f"检查点 #{iteration} - {info['timestamp']}") print(f"检查点 #{iteration} - {info['timestamp']}")
print(f"Bot 进程 (PID: {bot_process.pid})") print(f"Bot 进程 (PID: {bot_process.pid})")
print(f" RSS: {info['rss_mb']:.2f} MB") print(f" RSS: {info['rss_mb']:.2f} MB")
print(f" VMS: {info['vms_mb']:.2f} MB") print(f" VMS: {info['vms_mb']:.2f} MB")
print(f" 占比: {info['percent']:.2f}%") print(f" 占比: {info['percent']:.2f}%")
if children: if children:
print(f" 子进程: {info['children_count']}") print(f" 子进程: {info['children_count']}")
print(f" 子进程内存: {info['children_mem_mb']:.2f} MB") print(f" 子进程内存: {info['children_mem_mb']:.2f} MB")
total_mem = info['rss_mb'] + info['children_mem_mb'] total_mem = info["rss_mb"] + info["children_mem_mb"]
print(f" 总内存: {total_mem:.2f} MB") print(f" 总内存: {total_mem:.2f} MB")
print(f"\n 📋 子进程详情:") print("\n 📋 子进程详情:")
for idx, child in enumerate(children, 1): for idx, child in enumerate(children, 1):
try: try:
child_mem = child.memory_info().rss / 1024 / 1024 child_mem = child.memory_info().rss / 1024 / 1024
@@ -116,30 +114,30 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
print(f" 命令: {child_cmdline}") print(f" 命令: {child_cmdline}")
except (psutil.NoSuchProcess, psutil.AccessDenied): except (psutil.NoSuchProcess, psutil.AccessDenied):
print(f" [{idx}] 无法访问进程信息") print(f" [{idx}] 无法访问进程信息")
if len(history) > 1: if len(history) > 1:
prev = history[-2] prev = history[-2]
rss_diff = info['rss_mb'] - prev['rss_mb'] rss_diff = info["rss_mb"] - prev["rss_mb"]
print(f"\n变化:") print("\n变化:")
print(f" RSS: {rss_diff:+.2f} MB") print(f" RSS: {rss_diff:+.2f} MB")
if rss_diff > 10: if rss_diff > 10:
print(f" ⚠️ 内存增长较快!") print(" ⚠️ 内存增长较快!")
if info['rss_mb'] > 1000: if info["rss_mb"] > 1000:
print(f" ⚠️ 内存使用超过 1GB") print(" ⚠️ 内存使用超过 1GB")
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
await asyncio.sleep(interval) await asyncio.sleep(interval)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
print("\n❌ Bot 进程已结束") print("\n❌ Bot 进程已结束")
break break
except Exception as e: except Exception as e:
print(f"\n❌ 监控出错: {e}") print(f"\n❌ 监控出错: {e}")
break break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断监控") print("\n\n⚠️ 用户中断监控")
finally: finally:
if history and bot_process.pid: if history and bot_process.pid:
save_process_history(history, bot_process.pid) save_process_history(history, bot_process.pid)
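The checkpoint loop above builds one record per interval from psutil. A condensed sketch of that sampling step (function name is illustrative):

    import psutil

    def snapshot_memory(pid: int) -> dict:
        """Collect RSS/VMS for a process and the summed RSS of its children, in MB."""
        proc = psutil.Process(pid)
        mem = proc.memory_info()
        children = proc.children(recursive=True)
        children_rss = 0
        for child in children:
            try:
                children_rss += child.memory_info().rss
            except (psutil.NoSuchProcess, psutil.AccessDenied):
                continue  # a child may exit between enumeration and inspection
        return {
            "rss_mb": mem.rss / 1024 / 1024,
            "vms_mb": mem.vms / 1024 / 1024,
            "percent": proc.memory_percent(),
            "children_count": len(children),
            "children_mem_mb": children_rss / 1024 / 1024,
        }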
@@ -149,25 +147,25 @@ def save_process_history(history: list, pid: int):
"""保存进程监控历史""" """保存进程监控历史"""
output_dir = Path("data/memory_diagnostics") output_dir = Path("data/memory_diagnostics")
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = output_dir / f"process_monitor_{timestamp}_pid{pid}.txt" output_file = output_dir / f"process_monitor_{timestamp}_pid{pid}.txt"
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, "w", encoding="utf-8") as f:
f.write("Bot 进程内存监控历史记录\n") f.write("Bot 进程内存监控历史记录\n")
f.write("=" * 80 + "\n\n") f.write("=" * 80 + "\n\n")
f.write(f"Bot PID: {pid}\n\n") f.write(f"Bot PID: {pid}\n\n")
for info in history: for info in history:
f.write(f"时间: {info['timestamp']}\n") f.write(f"时间: {info['timestamp']}\n")
f.write(f"RSS: {info['rss_mb']:.2f} MB\n") f.write(f"RSS: {info['rss_mb']:.2f} MB\n")
f.write(f"VMS: {info['vms_mb']:.2f} MB\n") f.write(f"VMS: {info['vms_mb']:.2f} MB\n")
f.write(f"占比: {info['percent']:.2f}%\n") f.write(f"占比: {info['percent']:.2f}%\n")
if info['children_count'] > 0: if info["children_count"] > 0:
f.write(f"子进程: {info['children_count']}\n") f.write(f"子进程: {info['children_count']}\n")
f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n") f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n")
f.write("\n") f.write("\n")
print(f"\n✅ 监控历史已保存到: {output_file}") print(f"\n✅ 监控历史已保存到: {output_file}")
@@ -182,28 +180,28 @@ async def run_monitor_mode(interval: int):
print(" 3. 显示子进程详细信息") print(" 3. 显示子进程详细信息")
print(" 4. 自动保存监控历史") print(" 4. 自动保存监控历史")
print("=" * 80 + "\n") print("=" * 80 + "\n")
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
bot_file = project_root / "bot.py" bot_file = project_root / "bot.py"
if not bot_file.exists(): if not bot_file.exists():
print(f"❌ 找不到 bot.py: {bot_file}") print(f"❌ 找不到 bot.py: {bot_file}")
return 1 return 1
# 检测虚拟环境 # 检测虚拟环境
venv_python = project_root / ".venv" / "Scripts" / "python.exe" venv_python = project_root / ".venv" / "Scripts" / "python.exe"
if not venv_python.exists(): if not venv_python.exists():
venv_python = project_root / ".venv" / "bin" / "python" venv_python = project_root / ".venv" / "bin" / "python"
if venv_python.exists(): if venv_python.exists():
python_exe = str(venv_python) python_exe = str(venv_python)
print(f"🐍 使用虚拟环境: {venv_python}") print(f"🐍 使用虚拟环境: {venv_python}")
else: else:
python_exe = sys.executable python_exe = sys.executable
print(f"⚠️ 未找到虚拟环境,使用当前 Python: {python_exe}") print(f"⚠️ 未找到虚拟环境,使用当前 Python: {python_exe}")
print(f"🤖 启动 Bot: {bot_file}") print(f"🤖 启动 Bot: {bot_file}")
bot_process = subprocess.Popen( bot_process = subprocess.Popen(
[python_exe, str(bot_file)], [python_exe, str(bot_file)],
cwd=str(project_root), cwd=str(project_root),
@@ -212,9 +210,9 @@ async def run_monitor_mode(interval: int):
text=True, text=True,
bufsize=1, bufsize=1,
) )
await asyncio.sleep(2) await asyncio.sleep(2)
if bot_process.poll() is not None: if bot_process.poll() is not None:
print("❌ Bot 启动失败") print("❌ Bot 启动失败")
if bot_process.stdout: if bot_process.stdout:
@@ -222,9 +220,9 @@ async def run_monitor_mode(interval: int):
if output: if output:
print(f"\nBot 输出:\n{output}") print(f"\nBot 输出:\n{output}")
return 1 return 1
print(f"✅ Bot 已启动 (PID: {bot_process.pid})\n") print(f"✅ Bot 已启动 (PID: {bot_process.pid})\n")
# 启动输出读取线程 # 启动输出读取线程
def read_bot_output(): def read_bot_output():
if bot_process.stdout: if bot_process.stdout:
@@ -233,15 +231,15 @@ async def run_monitor_mode(interval: int):
print(f"[Bot] {line}", end="") print(f"[Bot] {line}", end="")
except Exception: except Exception:
pass pass
output_thread = threading.Thread(target=read_bot_output, daemon=True) output_thread = threading.Thread(target=read_bot_output, daemon=True)
output_thread.start() output_thread.start()
try: try:
await monitor_bot_process(bot_process, interval) await monitor_bot_process(bot_process, interval)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断") print("\n\n⚠️ 用户中断")
if bot_process.poll() is None: if bot_process.poll() is None:
print("\n正在停止 Bot...") print("\n正在停止 Bot...")
bot_process.terminate() bot_process.terminate()
@@ -251,9 +249,9 @@ async def run_monitor_mode(interval: int):
print("⚠️ 强制终止 Bot...") print("⚠️ 强制终止 Bot...")
bot_process.kill() bot_process.kill()
bot_process.wait() bot_process.wait()
print("✅ Bot 已停止") print("✅ Bot 已停止")
return 0 return 0
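Shutdown follows the usual terminate-then-kill sequence: ask the child to exit, wait with a bound, and force-kill on timeout. A sketch of that sequence (the 10-second bound is an assumption; the real value sits outside the shown hunks):

    import subprocess

    def stop_process(proc: subprocess.Popen, timeout: float = 10.0) -> None:
        if proc.poll() is not None:
            return  # already exited
        proc.terminate()
        try:
            proc.wait(timeout=timeout)
        except subprocess.TimeoutExpired:
            proc.kill()  # escalate if the process ignores the terminate request
            proc.wait()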
@@ -263,8 +261,8 @@ async def run_monitor_mode(interval: int):
class ObjectMemoryProfiler: class ObjectMemoryProfiler:
"""对象级内存分析器""" """对象级内存分析器"""
def __init__(self, interval: int = 10, output_file: Optional[str] = None, object_limit: int = 20): def __init__(self, interval: int = 10, output_file: str | None = None, object_limit: int = 20):
self.interval = interval self.interval = interval
self.output_file = output_file self.output_file = output_file
self.object_limit = object_limit self.object_limit = object_limit
@@ -273,23 +271,23 @@ class ObjectMemoryProfiler:
if PYMPLER_AVAILABLE: if PYMPLER_AVAILABLE:
self.tracker = tracker.SummaryTracker() self.tracker = tracker.SummaryTracker()
self.iteration = 0 self.iteration = 0
def get_object_stats(self) -> Dict: def get_object_stats(self) -> dict:
"""获取当前进程的对象统计(所有线程)""" """获取当前进程的对象统计(所有线程)"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
return {} return {}
try: try:
gc.collect() gc.collect()
all_objects = muppy.get_objects() all_objects = muppy.get_objects()
sum_data = summary.summarize(all_objects) sum_data = summary.summarize(all_objects)
# 按总大小第3个元素降序排序 # 按总大小第3个元素降序排序
sorted_sum_data = sorted(sum_data, key=lambda x: x[2], reverse=True) sorted_sum_data = sorted(sum_data, key=lambda x: x[2], reverse=True)
# 按模块统计内存 # 按模块统计内存
module_stats = self._get_module_stats(all_objects) module_stats = self._get_module_stats(all_objects)
threads = threading.enumerate() threads = threading.enumerate()
thread_info = [ thread_info = [
{ {
@@ -299,13 +297,13 @@ class ObjectMemoryProfiler:
} }
for t in threads for t in threads
] ]
gc_stats = { gc_stats = {
"collections": gc.get_count(), "collections": gc.get_count(),
"garbage": len(gc.garbage), "garbage": len(gc.garbage),
"tracked": len(gc.get_objects()), "tracked": len(gc.get_objects()),
} }
return { return {
"summary": sorted_sum_data[:self.object_limit], "summary": sorted_sum_data[:self.object_limit],
"module_stats": module_stats, "module_stats": module_stats,
@@ -316,52 +314,52 @@ class ObjectMemoryProfiler:
except Exception as e: except Exception as e:
print(f"❌ 获取对象统计失败: {e}") print(f"❌ 获取对象统计失败: {e}")
return {} return {}
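get_object_stats() above leans on pympler's muppy/summary pair. The sampling step in isolation (helper name is illustrative; requires pip install pympler):

    import gc

    from pympler import muppy, summary

    def top_object_types(limit: int = 20) -> list:
        gc.collect()  # drop collectable garbage so the snapshot reflects live objects
        all_objects = muppy.get_objects()
        rows = summary.summarize(all_objects)  # rows of (type_name, count, total_size_bytes)
        return sorted(rows, key=lambda row: row[2], reverse=True)[:limit]

    for type_name, count, size in top_object_types(5):
        print(f"{type_name:<50} {count:>10,} {size / 1024 / 1024:>10.2f} MB")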
def _get_module_stats(self, all_objects: list) -> Dict: def _get_module_stats(self, all_objects: list) -> dict:
"""统计各模块的内存占用""" """统计各模块的内存占用"""
module_mem = defaultdict(lambda: {"count": 0, "size": 0}) module_mem = defaultdict(lambda: {"count": 0, "size": 0})
for obj in all_objects: for obj in all_objects:
try: try:
# 获取对象所属模块 # 获取对象所属模块
obj_type = type(obj) obj_type = type(obj)
module_name = obj_type.__module__ module_name = obj_type.__module__
if module_name: if module_name:
# 获取顶级模块名(例如 src.chat.xxx -> src # 获取顶级模块名(例如 src.chat.xxx -> src
top_module = module_name.split('.')[0] top_module = module_name.split(".")[0]
obj_size = sys.getsizeof(obj) obj_size = sys.getsizeof(obj)
module_mem[top_module]["count"] += 1 module_mem[top_module]["count"] += 1
module_mem[top_module]["size"] += obj_size module_mem[top_module]["size"] += obj_size
except Exception: except Exception:
# 忽略无法获取大小的对象 # 忽略无法获取大小的对象
continue continue
# 转换为列表并按大小排序 # 转换为列表并按大小排序
sorted_modules = sorted( sorted_modules = sorted(
[(mod, stats["count"], stats["size"]) [(mod, stats["count"], stats["size"])
for mod, stats in module_mem.items()], for mod, stats in module_mem.items()],
key=lambda x: x[2], key=lambda x: x[2],
reverse=True reverse=True
) )
return { return {
"top_modules": sorted_modules[:20], # 前20个模块 "top_modules": sorted_modules[:20], # 前20个模块
"total_modules": len(module_mem) "total_modules": len(module_mem)
} }
def print_stats(self, stats: Dict, iteration: int): def print_stats(self, stats: dict, iteration: int):
"""打印统计信息""" """打印统计信息"""
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}") print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}")
print("=" * 80) print("=" * 80)
if "summary" in stats: if "summary" in stats:
print(f"\n📦 对象统计 (前 {self.object_limit} 个类型):\n") print(f"\n📦 对象统计 (前 {self.object_limit} 个类型):\n")
print(f"{'类型':<50} {'数量':>12} {'总大小':>15}") print(f"{'类型':<50} {'数量':>12} {'总大小':>15}")
print("-" * 80) print("-" * 80)
for obj_type, obj_count, obj_size in stats["summary"]: for obj_type, obj_count, obj_size in stats["summary"]:
if obj_size >= 1024 * 1024 * 1024: if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB" size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -371,14 +369,14 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB" size_str = f"{obj_size / 1024:.2f} KB"
else: else:
size_str = f"{obj_size} B" size_str = f"{obj_size} B"
print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}") print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
print(f"\n📚 模块内存占用 (前 20 个模块):\n") print("\n📚 模块内存占用 (前 20 个模块):\n")
print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}") print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}")
print("-" * 80) print("-" * 80)
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]: for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
if obj_size >= 1024 * 1024 * 1024: if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB" size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -388,46 +386,46 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB" size_str = f"{obj_size / 1024:.2f} KB"
else: else:
size_str = f"{obj_size} B" size_str = f"{obj_size} B"
print(f"{module_name:<40} {obj_count:>12,} {size_str:>15}") print(f"{module_name:<40} {obj_count:>12,} {size_str:>15}")
print(f"\n 总模块数: {stats['module_stats']['total_modules']}") print(f"\n 总模块数: {stats['module_stats']['total_modules']}")
if "threads" in stats: if "threads" in stats:
print(f"\n🧵 线程信息 ({len(stats['threads'])} 个):") print(f"\n🧵 线程信息 ({len(stats['threads'])} 个):")
for idx, t in enumerate(stats["threads"], 1): for idx, t in enumerate(stats["threads"], 1):
status = "" if t["alive"] else "" status = "" if t["alive"] else ""
daemon = "(守护)" if t["daemon"] else "" daemon = "(守护)" if t["daemon"] else ""
print(f" [{idx}] {status} {t['name']} {daemon}") print(f" [{idx}] {status} {t['name']} {daemon}")
if "gc_stats" in stats: if "gc_stats" in stats:
gc_stats = stats["gc_stats"] gc_stats = stats["gc_stats"]
print(f"\n🗑️ 垃圾回收:") print("\n🗑️ 垃圾回收:")
print(f" 代 0: {gc_stats['collections'][0]:,}") print(f" 代 0: {gc_stats['collections'][0]:,}")
print(f" 代 1: {gc_stats['collections'][1]:,}") print(f" 代 1: {gc_stats['collections'][1]:,}")
print(f" 代 2: {gc_stats['collections'][2]:,}") print(f" 代 2: {gc_stats['collections'][2]:,}")
print(f" 追踪对象: {gc_stats['tracked']:,}") print(f" 追踪对象: {gc_stats['tracked']:,}")
if "total_objects" in stats: if "total_objects" in stats:
print(f"\n📊 总对象数: {stats['total_objects']:,}") print(f"\n📊 总对象数: {stats['total_objects']:,}")
print("=" * 80 + "\n") print("=" * 80 + "\n")
def print_diff(self): def print_diff(self):
"""打印对象变化""" """打印对象变化"""
if not PYMPLER_AVAILABLE or not self.tracker: if not PYMPLER_AVAILABLE or not self.tracker:
return return
print("\n📈 对象变化分析:") print("\n📈 对象变化分析:")
print("-" * 80) print("-" * 80)
self.tracker.print_diff() self.tracker.print_diff()
print("-" * 80) print("-" * 80)
def save_to_file(self, stats: Dict): def save_to_file(self, stats: dict):
"""保存统计信息到文件""" """保存统计信息到文件"""
if not self.output_file: if not self.output_file:
return return
try: try:
# 保存文本 # 保存文本
with open(self.output_file, "a", encoding="utf-8") as f: with open(self.output_file, "a", encoding="utf-8") as f:
@@ -435,91 +433,91 @@ class ObjectMemoryProfiler:
f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"迭代: #{self.iteration}\n") f.write(f"迭代: #{self.iteration}\n")
f.write(f"{'=' * 80}\n\n") f.write(f"{'=' * 80}\n\n")
if "summary" in stats: if "summary" in stats:
f.write("对象统计:\n") f.write("对象统计:\n")
for obj_type, obj_count, obj_size in stats["summary"]: for obj_type, obj_count, obj_size in stats["summary"]:
f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n") f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
f.write("\n模块统计 (前 20 个):\n") f.write("\n模块统计 (前 20 个):\n")
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]: for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n") f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n")
f.write(f"\n总对象数: {stats.get('total_objects', 0):,}\n") f.write(f"\n总对象数: {stats.get('total_objects', 0):,}\n")
f.write(f"线程数: {len(stats.get('threads', []))}\n") f.write(f"线程数: {len(stats.get('threads', []))}\n")
# 保存 JSONL # 保存 JSONL
jsonl_path = str(self.output_file) + ".jsonl" jsonl_path = str(self.output_file) + ".jsonl"
record = { record = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"iteration": self.iteration, "iteration": self.iteration,
"total_objects": stats.get("total_objects", 0), "total_objects": stats.get("total_objects", 0),
"threads": stats.get("threads", []), "threads": stats.get("threads", []),
"gc_stats": stats.get("gc_stats", {}), "gc_stats": stats.get("gc_stats", {}),
"summary": [ "summary": [
{"type": t, "count": c, "size": s} {"type": t, "count": c, "size": s}
for (t, c, s) in stats.get("summary", []) for (t, c, s) in stats.get("summary", [])
], ],
"module_stats": stats.get("module_stats", {}), "module_stats": stats.get("module_stats", {}),
} }
with open(jsonl_path, "a", encoding="utf-8") as jf: with open(jsonl_path, "a", encoding="utf-8") as jf:
jf.write(json.dumps(record, ensure_ascii=False) + "\n") jf.write(json.dumps(record, ensure_ascii=False) + "\n")
if self.iteration == 1: if self.iteration == 1:
print(f"💾 数据保存到: {self.output_file}") print(f"💾 数据保存到: {self.output_file}")
print(f"💾 结构化数据: {jsonl_path}") print(f"💾 结构化数据: {jsonl_path}")
except Exception as e: except Exception as e:
print(f"⚠️ 保存文件失败: {e}") print(f"⚠️ 保存文件失败: {e}")
def start_monitoring(self): def start_monitoring(self):
"""启动监控线程""" """启动监控线程"""
self.running = True self.running = True
def monitor_loop(): def monitor_loop():
print(f"🚀 对象分析器已启动") print("🚀 对象分析器已启动")
print(f" 监控间隔: {self.interval}") print(f" 监控间隔: {self.interval}")
print(f" 对象类型限制: {self.object_limit}") print(f" 对象类型限制: {self.object_limit}")
print(f" 输出文件: {self.output_file or ''}") print(f" 输出文件: {self.output_file or ''}")
print() print()
while self.running: while self.running:
try: try:
self.iteration += 1 self.iteration += 1
stats = self.get_object_stats() stats = self.get_object_stats()
self.print_stats(stats, self.iteration) self.print_stats(stats, self.iteration)
if self.iteration % 3 == 0 and self.tracker: if self.iteration % 3 == 0 and self.tracker:
self.print_diff() self.print_diff()
if self.output_file: if self.output_file:
self.save_to_file(stats) self.save_to_file(stats)
time.sleep(self.interval) time.sleep(self.interval)
except Exception as e: except Exception as e:
print(f"❌ 监控出错: {e}") print(f"❌ 监控出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
monitor_thread = threading.Thread(target=monitor_loop, daemon=True) monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
monitor_thread.start() monitor_thread.start()
print(f"✓ 监控线程已启动\n") print("✓ 监控线程已启动\n")
def stop(self): def stop(self):
"""停止监控""" """停止监控"""
self.running = False self.running = False
def run_objects_mode(interval: int, output: Optional[str], object_limit: int): def run_objects_mode(interval: int, output: str | None, object_limit: int):
"""对象分析模式主函数""" """对象分析模式主函数"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
print("❌ pympler 未安装,无法使用对象分析模式") print("❌ pympler 未安装,无法使用对象分析模式")
print(" 安装: pip install pympler") print(" 安装: pip install pympler")
return 1 return 1
print("=" * 80) print("=" * 80)
print("🔬 对象分析模式") print("🔬 对象分析模式")
print("=" * 80) print("=" * 80)
@@ -529,38 +527,38 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
print(" 3. 显示对象变化diff") print(" 3. 显示对象变化diff")
print(" 4. 保存 JSONL 数据用于可视化") print(" 4. 保存 JSONL 数据用于可视化")
print("=" * 80 + "\n") print("=" * 80 + "\n")
# 添加项目根目录到 Python 路径 # 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path: if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
print(f"✓ 已添加项目根目录到 Python 路径: {project_root}\n") print(f"✓ 已添加项目根目录到 Python 路径: {project_root}\n")
profiler = ObjectMemoryProfiler( profiler = ObjectMemoryProfiler(
interval=interval, interval=interval,
output_file=output, output_file=output,
object_limit=object_limit object_limit=object_limit
) )
profiler.start_monitoring() profiler.start_monitoring()
print("🤖 正在启动 Bot...\n") print("🤖 正在启动 Bot...\n")
try: try:
import bot import bot
if hasattr(bot, 'main_async'): if hasattr(bot, "main_async"):
asyncio.run(bot.main_async()) asyncio.run(bot.main_async())
elif hasattr(bot, 'main'): elif hasattr(bot, "main"):
bot.main() bot.main()
else: else:
print("⚠️ bot.py 未找到 main_async() 或 main() 函数") print("⚠️ bot.py 未找到 main_async() 或 main() 函数")
print(" Bot 模块已导入,监控线程在后台运行") print(" Bot 模块已导入,监控线程在后台运行")
print(" 按 Ctrl+C 停止\n") print(" 按 Ctrl+C 停止\n")
while profiler.running: while profiler.running:
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断") print("\n\n⚠️ 用户中断")
except Exception as e: except Exception as e:
@@ -569,7 +567,7 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
traceback.print_exc() traceback.print_exc()
finally: finally:
profiler.stop() profiler.stop()
return 0 return 0
@@ -577,10 +575,10 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
# 可视化模式 # 可视化模式
# ============================================================================ # ============================================================================
def load_jsonl(path: Path) -> List[Dict]: def load_jsonl(path: Path) -> list[dict]:
"""加载 JSONL 文件""" """加载 JSONL 文件"""
snapshots = [] snapshots = []
with open(path, "r", encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: if not line:
@@ -592,7 +590,7 @@ def load_jsonl(path: Path) -> List[Dict]:
return snapshots return snapshots
def aggregate_top_types(snapshots: List[Dict], top_n: int = 10): def aggregate_top_types(snapshots: list[dict], top_n: int = 10):
"""聚合前 N 个对象类型的时间序列""" """聚合前 N 个对象类型的时间序列"""
type_max = defaultdict(int) type_max = defaultdict(int)
for snap in snapshots: for snap in snapshots:
@@ -600,37 +598,37 @@ def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
t = item.get("type") t = item.get("type")
s = int(item.get("size", 0)) s = int(item.get("size", 0))
type_max[t] = max(type_max[t], s) type_max[t] = max(type_max[t], s)
top_types = sorted(type_max.items(), key=lambda kv: kv[1], reverse=True)[:top_n] top_types = sorted(type_max.items(), key=lambda kv: kv[1], reverse=True)[:top_n]
top_names = [t for t, _ in top_types] top_names = [t for t, _ in top_types]
times = [] times = []
series = {t: [] for t in top_names} series = {t: [] for t in top_names}
for snap in snapshots: for snap in snapshots:
ts = snap.get("timestamp") ts = snap.get("timestamp")
try: try:
times.append(datetime.strptime(ts, "%Y-%m-%d %H:%M:%S")) times.append(datetime.strptime(ts, "%Y-%m-%d %H:%M:%S"))
except Exception: except Exception:
times.append(None) times.append(None)
summary = {item.get("type"): int(item.get("size", 0)) summary = {item.get("type"): int(item.get("size", 0))
for item in snap.get("summary", [])} for item in snap.get("summary", [])}
for t in top_names: for t in top_names:
series[t].append(summary.get(t, 0) / 1024.0 / 1024.0) series[t].append(summary.get(t, 0) / 1024.0 / 1024.0)
return times, series return times, series
def plot_series(times: List, series: Dict, output: Path, top_n: int): def plot_series(times: list, series: dict, output: Path, top_n: int):
"""绘制时间序列图""" """绘制时间序列图"""
plt.figure(figsize=(14, 8)) plt.figure(figsize=(14, 8))
for name, values in series.items(): for name, values in series.items():
if all(v == 0 for v in values): if all(v == 0 for v in values):
continue continue
plt.plot(times, values, marker="o", label=name, linewidth=2) plt.plot(times, values, marker="o", label=name, linewidth=2)
plt.xlabel("时间", fontsize=12) plt.xlabel("时间", fontsize=12)
plt.ylabel("内存 (MB)", fontsize=12) plt.ylabel("内存 (MB)", fontsize=12)
plt.title(f"对象类型随时间的内存占用 (前 {top_n} 类型)", fontsize=14) plt.title(f"对象类型随时间的内存占用 (前 {top_n} 类型)", fontsize=14)
@@ -647,31 +645,31 @@ def run_visualize_mode(input_file: str, output_file: str, top: int):
print("❌ matplotlib 未安装,无法使用可视化模式") print("❌ matplotlib 未安装,无法使用可视化模式")
print(" 安装: pip install matplotlib") print(" 安装: pip install matplotlib")
return 1 return 1
print("=" * 80) print("=" * 80)
print("📊 可视化模式") print("📊 可视化模式")
print("=" * 80) print("=" * 80)
path = Path(input_file) path = Path(input_file)
if not path.exists(): if not path.exists():
print(f"❌ 找不到输入文件: {path}") print(f"❌ 找不到输入文件: {path}")
return 1 return 1
print(f"📂 读取数据: {path}") print(f"📂 读取数据: {path}")
snaps = load_jsonl(path) snaps = load_jsonl(path)
if not snaps: if not snaps:
print("❌ 未读取到任何快照数据") print("❌ 未读取到任何快照数据")
return 1 return 1
print(f"✓ 读取 {len(snaps)} 个快照") print(f"✓ 读取 {len(snaps)} 个快照")
times, series = aggregate_top_types(snaps, top_n=top) times, series = aggregate_top_types(snaps, top_n=top)
print(f"✓ 提取前 {top} 个对象类型") print(f"✓ 提取前 {top} 个对象类型")
output_path = Path(output_file) output_path = Path(output_file)
plot_series(times, series, output_path, top) plot_series(times, series, output_path, top)
return 0 return 0
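run_visualize_mode() is essentially the composition of the three helpers above, spelled out here as a usage sketch (assumes the JSONL file produced by --objects mode exists and matplotlib is installed):

    from pathlib import Path

    # load_jsonl / aggregate_top_types / plot_series are the functions defined above
    snapshots = load_jsonl(Path("memory_data.txt.jsonl"))
    times, series = aggregate_top_types(snapshots, top_n=10)
    plot_series(times, series, Path("memory_analysis_plot.png"), top_n=10)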
@@ -693,10 +691,10 @@ def main():
使用示例: 使用示例:
# 进程监控(启动 bot 并监控) # 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10 python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计) # 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表 # 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 --output plot.png python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 --output plot.png
@@ -705,26 +703,26 @@ def main():
- 可视化模式需要: pip install matplotlib - 可视化模式需要: pip install matplotlib
""", """,
) )
# 模式选择 # 模式选择
mode_group = parser.add_mutually_exclusive_group(required=True) mode_group = parser.add_mutually_exclusive_group(required=True)
mode_group.add_argument("--monitor", "-m", action="store_true", mode_group.add_argument("--monitor", "-m", action="store_true",
help="进程监控模式(外部监控 bot 进程)") help="进程监控模式(外部监控 bot 进程)")
mode_group.add_argument("--objects", "-o", action="store_true", mode_group.add_argument("--objects", "-o", action="store_true",
help="对象分析模式(内部统计所有对象)") help="对象分析模式(内部统计所有对象)")
mode_group.add_argument("--visualize", "-v", action="store_true", mode_group.add_argument("--visualize", "-v", action="store_true",
help="可视化模式(绘制 JSONL 数据)") help="可视化模式(绘制 JSONL 数据)")
# 通用参数 # 通用参数
parser.add_argument("--interval", "-i", type=int, default=10, parser.add_argument("--interval", "-i", type=int, default=10,
help="监控间隔(秒),默认 10") help="监控间隔(秒),默认 10")
# 对象分析参数 # 对象分析参数
parser.add_argument("--output", type=str, parser.add_argument("--output", type=str,
help="输出文件路径(对象分析模式)") help="输出文件路径(对象分析模式)")
parser.add_argument("--object-limit", "-l", type=int, default=20, parser.add_argument("--object-limit", "-l", type=int, default=20,
help="对象类型显示数量,默认 20") help="对象类型显示数量,默认 20")
# 可视化参数 # 可视化参数
parser.add_argument("--input", type=str, parser.add_argument("--input", type=str,
help="输入 JSONL 文件(可视化模式)") help="输入 JSONL 文件(可视化模式)")
@@ -732,24 +730,24 @@ def main():
help="展示前 N 个类型(可视化模式),默认 10") help="展示前 N 个类型(可视化模式),默认 10")
parser.add_argument("--plot-output", type=str, default="memory_analysis_plot.png", parser.add_argument("--plot-output", type=str, default="memory_analysis_plot.png",
help="图表输出文件,默认 memory_analysis_plot.png") help="图表输出文件,默认 memory_analysis_plot.png")
args = parser.parse_args() args = parser.parse_args()
# 根据模式执行 # 根据模式执行
if args.monitor: if args.monitor:
return asyncio.run(run_monitor_mode(args.interval)) return asyncio.run(run_monitor_mode(args.interval))
elif args.objects: elif args.objects:
if not args.output: if not args.output:
print("⚠️ 建议使用 --output 指定输出文件以保存数据") print("⚠️ 建议使用 --output 指定输出文件以保存数据")
return run_objects_mode(args.interval, args.output, args.object_limit) return run_objects_mode(args.interval, args.output, args.object_limit)
elif args.visualize: elif args.visualize:
if not args.input: if not args.input:
print("❌ 可视化模式需要 --input 参数指定 JSONL 文件") print("❌ 可视化模式需要 --input 参数指定 JSONL 文件")
return 1 return 1
return run_visualize_mode(args.input, args.plot_output, args.top) return run_visualize_mode(args.input, args.plot_output, args.top)
return 0 return 0

View File

@@ -0,0 +1,361 @@
"""
记忆图可视化 - API 路由模块
提供 Web API 用于可视化记忆图数据
"""
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
import orjson
from fastapi import APIRouter, HTTPException, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.templating import Jinja2Templates
# 调整项目根目录的计算方式
project_root = Path(__file__).parent.parent.parent
data_dir = project_root / "data" / "memory_graph"
# 缓存
graph_data_cache = None
current_data_file = None
# FastAPI 路由
router = APIRouter()
# Jinja2 模板引擎
templates = Jinja2Templates(directory=str(Path(__file__).parent / "templates"))
def find_available_data_files() -> List[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
return files
possible_files = ["graph_store.json", "memory_graph.json", "graph_data.json"]
for filename in possible_files:
file_path = data_dir / filename
if file_path.exists():
files.append(file_path)
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
backups_dir = data_dir / "backups"
if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files:
files.append(backup_file)
backup_dir = data_dir.parent / "backup"
if backup_dir.exists():
for pattern in ["**/graph_*.json", "**/memory_*.json"]:
for backup_file in backup_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data_from_file(file_path: Optional[Path] = None) -> Dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
if file_path and file_path != current_data_file:
graph_data_cache = None
current_data_file = file_path
if graph_data_cache:
return graph_data_cache
try:
graph_file = current_data_file
if not graph_file:
available_files = find_available_data_files()
if not available_files:
return {"error": "未找到数据文件", "nodes": [], "edges": [], "stats": {}}
graph_file = available_files[0]
current_data_file = graph_file
if not graph_file.exists():
return {"error": f"文件不存在: {graph_file}", "nodes": [], "edges": [], "stats": {}}
with open(graph_file, "r", encoding="utf-8") as f:
data = orjson.loads(f.read())
nodes = data.get("nodes", [])
edges = data.get("edges", [])
metadata = data.get("metadata", {})
nodes_dict = {
node["id"]: {
**node,
"label": node.get("content", ""),
"group": node.get("node_type", ""),
"title": f"{node.get('node_type', '')}: {node.get('content', '')}",
}
for node in nodes
if node.get("id")
}
edges_list = [
{
**edge,
"from": edge.get("source", edge.get("source_id")),
"to": edge.get("target", edge.get("target_id")),
"label": edge.get("relation", ""),
"arrows": "to",
}
for edge in edges
]
stats = metadata.get("statistics", {})
total_memories = stats.get("total_memories", 0)
graph_data_cache = {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": [],
"stats": {
"total_nodes": len(nodes_dict),
"total_edges": len(edges_list),
"total_memories": total_memories,
},
"current_file": str(graph_file),
"file_size": graph_file.stat().st_size,
"file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
return graph_data_cache
except Exception as e:
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"加载图数据失败: {e}")
@router.get("/", response_class=HTMLResponse)
async def index(request: Request):
"""主页面"""
return templates.TemplateResponse("visualizer.html", {"request": request})
def _format_graph_data_from_manager(memory_manager) -> Dict[str, Any]:
"""从 MemoryManager 提取并格式化图数据"""
if not memory_manager.graph_store:
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
all_memories = memory_manager.graph_store.get_all_memories()
nodes_dict = {}
edges_list = []
memory_info = []
for memory in all_memories:
memory_info.append(
{
"id": memory.id,
"type": memory.memory_type.value,
"importance": memory.importance,
"text": memory.to_text(),
}
)
for node in memory.nodes:
if node.id not in nodes_dict:
nodes_dict[node.id] = {
"id": node.id,
"label": node.content,
"type": node.node_type.value,
"group": node.node_type.name,
"title": f"{node.node_type.value}: {node.content}",
}
for edge in memory.edges:
edges_list.append( # noqa: PERF401
{
"id": edge.id,
"from": edge.source_id,
"to": edge.target_id,
"label": edge.relation,
"arrows": "to",
"memory_id": memory.id,
}
)
stats = memory_manager.get_statistics()
return {
"nodes": list(nodes_dict.values()),
"edges": edges_list,
"memories": memory_info,
"stats": {
"total_nodes": stats.get("total_nodes", 0),
"total_edges": stats.get("total_edges", 0),
"total_memories": stats.get("total_memories", 0),
},
"current_file": "memory_manager (实时数据)",
}
@router.get("/api/graph/full")
async def get_full_graph():
"""获取完整记忆图数据"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
data = {}
if memory_manager and memory_manager._initialized:
data = _format_graph_data_from_manager(memory_manager)
else:
# 如果内存管理器不可用,则从文件加载
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "data": data})
except Exception as e:
import traceback
traceback.print_exc()
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/files")
async def list_files_api():
"""列出所有可用的数据文件"""
try:
files = find_available_data_files()
file_list = []
for f in files:
stat = f.stat()
file_list.append(
{
"path": str(f),
"name": f.name,
"size": stat.st_size,
"size_kb": round(stat.st_size / 1024, 2),
"modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
"modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
"is_current": str(f) == str(current_data_file) if current_data_file else False,
}
)
return JSONResponse(
content={
"success": True,
"files": file_list,
"count": len(file_list),
"current_file": str(current_data_file) if current_data_file else None,
}
)
except Exception as e:
# 增加日志记录
# logger.error(f"列出数据文件失败: {e}", exc_info=True)
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.post("/select_file")
async def select_file(request: Request):
"""选择要加载的数据文件"""
global graph_data_cache, current_data_file
try:
data = await request.json()
file_path = data.get("file_path")
if not file_path:
raise HTTPException(status_code=400, detail="未提供文件路径")
file_to_load = Path(file_path)
if not file_to_load.exists():
raise HTTPException(status_code=404, detail=f"文件不存在: {file_path}")
graph_data_cache = None
current_data_file = file_to_load
graph_data = load_graph_data_from_file(file_to_load)
return JSONResponse(
content={
"success": True,
"message": f"已切换到文件: {file_to_load.name}",
"stats": graph_data.get("stats", {}),
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/reload")
async def reload_data():
"""重新加载数据"""
global graph_data_cache
graph_data_cache = None
data = load_graph_data_from_file()
return JSONResponse(content={"success": True, "message": "数据已重新加载", "stats": data.get("stats", {})})
@router.get("/api/search")
async def search_memories(q: str, limit: int = 50):
"""搜索记忆"""
try:
from src.memory_graph.manager_singleton import get_memory_manager
memory_manager = get_memory_manager()
results = []
if memory_manager and memory_manager._initialized and memory_manager.graph_store:
# 从 memory_manager 搜索
all_memories = memory_manager.graph_store.get_all_memories()
for memory in all_memories:
if q.lower() in memory.to_text().lower():
results.append(
{
"id": memory.id,
"type": memory.memory_type.value,
"importance": memory.importance,
"text": memory.to_text(),
}
)
else:
# 从文件加载的数据中搜索 (降级方案)
data = load_graph_data_from_file()
for memory in data.get("memories", []):
if q.lower() in memory.get("text", "").lower():
results.append(memory)
return JSONResponse(
content={
"success": True,
"data": {
"results": results[:limit],
"count": len(results),
},
}
)
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
@router.get("/api/stats")
async def get_statistics():
"""获取统计信息"""
try:
data = load_graph_data_from_file()
node_types = {}
memory_types = {}
for node in data["nodes"]:
node_type = node.get("type", "Unknown")
node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data.get("memories", []):
mem_type = memory.get("type", "Unknown")
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get("stats", {})
stats["node_types"] = node_types
stats["memory_types"] = memory_types
return JSONResponse(content={"success": True, "data": stats})
except Exception as e:
return JSONResponse(content={"success": False, "error": str(e)}, status_code=500)
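The template changes later in this commit point every fetch at /visualizer/api/..., which implies the router above is included under a /visualizer prefix. A minimal wiring sketch (the real mount point lives outside this diff):

    from fastapi import APIRouter, FastAPI

    router = APIRouter()  # stands in for the visualizer router defined above

    app = FastAPI()
    app.include_router(router, prefix="/visualizer")
    # e.g. GET /visualizer/api/graph/full now resolves to get_full_graph()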

View File

@@ -4,10 +4,10 @@ from typing import Any, Literal
from fastapi import APIRouter, HTTPException, Query from fastapi import APIRouter, HTTPException, Query
from src.common.database.compatibility import db_get from src.chat.utils.statistic import (
from src.common.database.core.models import LLMUsage StatisticOutputTask,
)
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.config import model_config
logger = get_logger("LLM统计API") logger = get_logger("LLM统计API")
@@ -37,108 +37,6 @@ COST_BY_USER = "costs_by_user"
COST_BY_MODEL = "costs_by_model" COST_BY_MODEL = "costs_by_model"
COST_BY_MODULE = "costs_by_module" COST_BY_MODULE = "costs_by_module"
async def _collect_stats_in_period(start_time: datetime, end_time: datetime) -> dict[str, Any]:
"""在指定时间段内收集LLM使用统计信息"""
records = await db_get(
model_class=LLMUsage,
filters={"timestamp": {"$gte": start_time, "$lt": end_time}},
)
if not records:
return {}
# 创建一个从 model_identifier 到 name 的映射
model_identifier_to_name_map = {model.model_identifier: model.name for model in model_config.models}
stats: dict[str, Any] = {
TOTAL_REQ_CNT: 0,
TOTAL_COST: 0.0,
REQ_CNT_BY_TYPE: defaultdict(int),
REQ_CNT_BY_USER: defaultdict(int),
REQ_CNT_BY_MODEL: defaultdict(int),
REQ_CNT_BY_MODULE: defaultdict(int),
IN_TOK_BY_TYPE: defaultdict(int),
IN_TOK_BY_USER: defaultdict(int),
IN_TOK_BY_MODEL: defaultdict(int),
IN_TOK_BY_MODULE: defaultdict(int),
OUT_TOK_BY_TYPE: defaultdict(int),
OUT_TOK_BY_USER: defaultdict(int),
OUT_TOK_BY_MODEL: defaultdict(int),
OUT_TOK_BY_MODULE: defaultdict(int),
TOTAL_TOK_BY_TYPE: defaultdict(int),
TOTAL_TOK_BY_USER: defaultdict(int),
TOTAL_TOK_BY_MODEL: defaultdict(int),
TOTAL_TOK_BY_MODULE: defaultdict(int),
COST_BY_TYPE: defaultdict(float),
COST_BY_USER: defaultdict(float),
COST_BY_MODEL: defaultdict(float),
COST_BY_MODULE: defaultdict(float),
}
for record in records:
if not isinstance(record, dict):
continue
stats[TOTAL_REQ_CNT] += 1
request_type = record.get("request_type") or "unknown"
user_id = record.get("user_id") or "unknown"
# 从数据库获取的是真实模型名 (model_identifier)
real_model_name = record.get("model_name") or "unknown"
module_name = request_type.split(".")[0] if "." in request_type else request_type
# 尝试通过真实模型名找到配置文件中的模型名
config_model_name = model_identifier_to_name_map.get(real_model_name, real_model_name)
prompt_tokens = record.get("prompt_tokens") or 0
completion_tokens = record.get("completion_tokens") or 0
total_tokens = prompt_tokens + completion_tokens
cost = 0.0
try:
# 使用配置文件中的模型名来获取模型信息
model_info = model_config.get_model_info(config_model_name)
if model_info:
input_cost = (prompt_tokens / 1000000) * model_info.price_in
output_cost = (completion_tokens / 1000000) * model_info.price_out
cost = round(input_cost + output_cost, 6)
except KeyError as e:
logger.info(str(e))
logger.warning(f"模型 '{config_model_name}' (真实名称: '{real_model_name}') 在配置中未找到,成本计算将使用默认值 0.0")
stats[TOTAL_COST] += cost
# 按类型统计
stats[REQ_CNT_BY_TYPE][request_type] += 1
stats[IN_TOK_BY_TYPE][request_type] += prompt_tokens
stats[OUT_TOK_BY_TYPE][request_type] += completion_tokens
stats[TOTAL_TOK_BY_TYPE][request_type] += total_tokens
stats[COST_BY_TYPE][request_type] += cost
# 按用户统计
stats[REQ_CNT_BY_USER][user_id] += 1
stats[IN_TOK_BY_USER][user_id] += prompt_tokens
stats[OUT_TOK_BY_USER][user_id] += completion_tokens
stats[TOTAL_TOK_BY_USER][user_id] += total_tokens
stats[COST_BY_USER][user_id] += cost
# 按模型统计 (使用配置文件中的名称)
stats[REQ_CNT_BY_MODEL][config_model_name] += 1
stats[IN_TOK_BY_MODEL][config_model_name] += prompt_tokens
stats[OUT_TOK_BY_MODEL][config_model_name] += completion_tokens
stats[TOTAL_TOK_BY_MODEL][config_model_name] += total_tokens
stats[COST_BY_MODEL][config_model_name] += cost
# 按模块统计
stats[REQ_CNT_BY_MODULE][module_name] += 1
stats[IN_TOK_BY_MODULE][module_name] += prompt_tokens
stats[OUT_TOK_BY_MODULE][module_name] += completion_tokens
stats[TOTAL_TOK_BY_MODULE][module_name] += total_tokens
stats[COST_BY_MODULE][module_name] += cost
return stats
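The removed helper priced each request per million tokens before the endpoint was switched to StatisticOutputTask. The arithmetic as a worked sketch (the prices are made-up examples, not real model pricing):

    def request_cost(prompt_tokens: int, completion_tokens: int,
                     price_in: float, price_out: float) -> float:
        # price_in / price_out are cost per 1,000,000 tokens
        input_cost = (prompt_tokens / 1_000_000) * price_in
        output_cost = (completion_tokens / 1_000_000) * price_out
        return round(input_cost + output_cost, 6)

    # 1,200 prompt tokens + 300 completion tokens at 2.0 / 8.0 per 1M tokens:
    assert abs(request_cost(1200, 300, 2.0, 8.0) - 0.0048) < 1e-9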
@router.get("/llm/stats") @router.get("/llm/stats")
async def get_llm_stats( async def get_llm_stats(
period_type: Literal[ period_type: Literal[
@@ -179,7 +77,8 @@ async def get_llm_stats(
if start_time is None: if start_time is None:
raise HTTPException(status_code=400, detail="无法确定查询的起始时间") raise HTTPException(status_code=400, detail="无法确定查询的起始时间")
period_stats = await _collect_stats_in_period(start_time, end_time) stats_data = await StatisticOutputTask._collect_model_request_for_period([("custom", start_time)])
period_stats = stats_data.get("custom", {})
if not period_stats: if not period_stats:
return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}} return {"period": {"start": start_time.isoformat(), "end": end_time.isoformat()}, "data": {}}

View File

@@ -658,7 +658,7 @@
try { try {
document.getElementById('loading').style.display = 'block'; document.getElementById('loading').style.display = 'block';
const response = await fetch('/api/graph/full'); const response = await fetch('/visualizer/api/graph/full');
const result = await response.json(); const result = await response.json();
if (result.success) { if (result.success) {
@@ -748,7 +748,7 @@
} }
try { try {
const response = await fetch(`/api/search?q=${encodeURIComponent(query)}&limit=50`); const response = await fetch(`/visualizer/api/search?q=${encodeURIComponent(query)}&limit=50`);
const result = await response.json(); const result = await response.json();
if (result.success) { if (result.success) {
@@ -1041,7 +1041,7 @@
// 文件选择功能 // 文件选择功能
async function loadFileList() { async function loadFileList() {
try { try {
const response = await fetch('/api/files'); const response = await fetch('/visualizer/api/files');
const result = await response.json(); const result = await response.json();
if (result.success) { if (result.success) {
@@ -1130,7 +1130,7 @@
document.getElementById('loading').style.display = 'block'; document.getElementById('loading').style.display = 'block';
closeFileSelector(); closeFileSelector();
const response = await fetch('/api/select_file', { const response = await fetch('/visualizer/api/select_file', {
method: 'POST', method: 'POST',
headers: { headers: {
'Content-Type': 'application/json' 'Content-Type': 'application/json'

View File

@@ -680,9 +680,9 @@ class EmojiManager:
try: try:
# 🔧 使用 QueryBuilder 以启用数据库缓存 # 🔧 使用 QueryBuilder 以启用数据库缓存
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
logger.debug("[数据库] 开始加载所有表情包记录 ...") logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = await QueryBuilder(Emoji).all() emoji_instances = await QueryBuilder(Emoji).all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances) emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -802,7 +802,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存) # 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存)
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first() emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
if emoji_record and emoji_record.description: if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
@@ -966,7 +966,7 @@ class EmojiManager:
existing_description = None existing_description = None
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first() existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first()
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description
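The pattern this commit standardizes on is QueryBuilder lookups so the database cache layer is engaged. A sketch using only the calls visible above (.filter(...).first()); the Emoji import path is an assumption:

    from src.common.database.api.query import QueryBuilder
    from src.common.database.core.models import Emoji  # assumed import path

    async def get_cached_description(emoji_hash: str) -> str | None:
        record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
        if record and record.description:
            return record.description
        return None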

View File

@@ -1,5 +1,4 @@
import os import os
import random
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@@ -135,20 +134,20 @@ class ExpressionLearner:
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int: async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
""" """
清理过期的表达方式 清理过期的表达方式
Args: Args:
expiration_days: 过期天数,超过此天数未激活的表达方式将被删除(不指定则从配置读取) expiration_days: 过期天数,超过此天数未激活的表达方式将被删除(不指定则从配置读取)
Returns: Returns:
int: 删除的表达方式数量 int: 删除的表达方式数量
""" """
# 从配置读取过期天数 # 从配置读取过期天数
if expiration_days is None: if expiration_days is None:
expiration_days = global_config.expression.expiration_days expiration_days = global_config.expression.expiration_days
current_time = time.time() current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600) expiration_threshold = current_time - (expiration_days * 24 * 3600)
try: try:
deleted_count = 0 deleted_count = 0
async with get_db_session() as session: async with get_db_session() as session:
@@ -160,15 +159,15 @@ class ExpressionLearner:
) )
) )
expired_expressions = list(query.scalars()) expired_expressions = list(query.scalars())
if expired_expressions: if expired_expressions:
for expr in expired_expressions: for expr in expired_expressions:
await session.delete(expr) await session.delete(expr)
deleted_count += 1 deleted_count += 1
await session.commit() await session.commit()
logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)") logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)")
# 清除缓存 # 清除缓存
from src.common.database.optimization.cache_manager import get_cache from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key from src.common.database.utils.decorators import generate_cache_key
@@ -176,7 +175,7 @@ class ExpressionLearner:
await cache.delete(generate_cache_key("chat_expressions", self.chat_id)) await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else: else:
logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)") logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)")
return deleted_count return deleted_count
except Exception as e: except Exception as e:
logger.error(f"清理过期表达方式失败: {e}") logger.error(f"清理过期表达方式失败: {e}")
@@ -460,7 +459,7 @@ class ExpressionLearner:
) )
) )
same_situation_expr = query_same_situation.scalar() same_situation_expr = query_same_situation.scalar()
# 情况2相同 chat_id + type + style相同表达不同情景 # 情况2相同 chat_id + type + style相同表达不同情景
query_same_style = await session.execute( query_same_style = await session.execute(
select(Expression).where( select(Expression).where(
@@ -470,7 +469,7 @@ class ExpressionLearner:
) )
) )
same_style_expr = query_same_style.scalar() same_style_expr = query_same_style.scalar()
# 情况3完全相同相同情景+相同表达) # 情况3完全相同相同情景+相同表达)
query_exact_match = await session.execute( query_exact_match = await session.execute(
select(Expression).where( select(Expression).where(
@@ -481,7 +480,7 @@ class ExpressionLearner:
) )
) )
exact_match_expr = query_exact_match.scalar() exact_match_expr = query_exact_match.scalar()
# 优先处理完全匹配的情况 # 优先处理完全匹配的情况
if exact_match_expr: if exact_match_expr:
# 完全相同增加count更新时间 # 完全相同增加count更新时间
View File
@@ -72,21 +72,21 @@ class ExpressorModel:
是否删除成功 是否删除成功
""" """
removed = False removed = False
if cid in self._candidates: if cid in self._candidates:
del self._candidates[cid] del self._candidates[cid]
removed = True removed = True
if cid in self._situations: if cid in self._situations:
del self._situations[cid] del self._situations[cid]
# 从nb模型中删除 # 从nb模型中删除
if cid in self.nb.cls_counts: if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid] del self.nb.cls_counts[cid]
if cid in self.nb.token_counts: if cid in self.nb.token_counts:
del self.nb.token_counts[cid] del self.nb.token_counts[cid]
return removed return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]: def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:
View File
@@ -72,7 +72,7 @@ class StyleLearner:
# 检查是否需要清理 # 检查是否需要清理
current_count = len(self.style_to_id) current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold) cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger: if current_count >= cleanup_trigger:
if current_count >= self.max_styles: if current_count >= self.max_styles:
# 已经达到最大限制,必须清理 # 已经达到最大限制,必须清理
@@ -109,7 +109,7 @@ class StyleLearner:
def _cleanup_styles(self): def _cleanup_styles(self):
""" """
清理低价值的风格,为新风格腾出空间 清理低价值的风格,为新风格腾出空间
清理策略: 清理策略:
1. 综合考虑使用次数和最后使用时间 1. 综合考虑使用次数和最后使用时间
2. 删除得分最低的风格 2. 删除得分最低的风格
@@ -118,34 +118,34 @@ class StyleLearner:
try: try:
current_time = time.time() current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio)) cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# 计算每个风格的价值分数 # 计算每个风格的价值分数
style_scores = [] style_scores = []
for style_id in self.style_to_id.values(): for style_id in self.style_to_id.values():
# 使用次数 # 使用次数
usage_count = self.learning_stats["style_counts"].get(style_id, 0) usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# 最后使用时间(越近越好) # 最后使用时间(越近越好)
last_used = self.learning_stats["style_last_used"].get(style_id, 0) last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf') time_since_used = current_time - last_used if last_used > 0 else float("inf")
# 综合分数:使用次数越多越好,距离上次使用时间越短越好 # 综合分数:使用次数越多越好,距离上次使用时间越短越好
# 使用对数来平滑使用次数的影响 # 使用对数来平滑使用次数的影响
import math import math
usage_score = math.log1p(usage_count) # log(1 + count) usage_score = math.log1p(usage_count) # log(1 + count)
# 时间分数:转换为天数,使用指数衰减 # 时间分数:转换为天数,使用指数衰减
days_unused = time_since_used / 86400 # 转换为天 days_unused = time_since_used / 86400 # 转换为天
time_score = math.exp(-days_unused / 30) # 30天衰减因子 time_score = math.exp(-days_unused / 30) # 30天衰减因子
# 综合分数80%使用频率 + 20%时间新鲜度 # 综合分数80%使用频率 + 20%时间新鲜度
total_score = 0.8 * usage_score + 0.2 * time_score total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused)) style_scores.append((style_id, total_score, usage_count, days_unused))
# 按分数排序,分数低的先删除 # 按分数排序,分数低的先删除
style_scores.sort(key=lambda x: x[1]) style_scores.sort(key=lambda x: x[1])
# 删除分数最低的风格 # 删除分数最低的风格
deleted_styles = [] deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]: for style_id, score, usage, days in style_scores[:cleanup_count]:
@@ -156,27 +156,27 @@ class StyleLearner:
del self.id_to_style[style_id] del self.id_to_style[style_id]
if style_id in self.id_to_situation: if style_id in self.id_to_situation:
del self.id_to_situation[style_id] del self.id_to_situation[style_id]
# 从统计中删除 # 从统计中删除
if style_id in self.learning_stats["style_counts"]: if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id] del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]: if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id] del self.learning_stats["style_last_used"][style_id]
# 从expressor模型中删除 # 从expressor模型中删除
self.expressor.remove_candidate(style_id) self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}")) deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info( logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格," f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格" f"剩余 {len(self.style_to_id)} 个风格"
) )
# 记录前5个被删除的风格用于调试 # 记录前5个被删除的风格用于调试
if deleted_styles: if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}") logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
except Exception as e: except Exception as e:
logger.error(f"清理风格失败: {e}", exc_info=True) logger.error(f"清理风格失败: {e}", exc_info=True)
@@ -303,10 +303,10 @@ class StyleLearner:
def cleanup_old_styles(self, ratio: float | None = None) -> int: def cleanup_old_styles(self, ratio: float | None = None) -> int:
""" """
手动清理旧风格 手动清理旧风格
Args: Args:
ratio: 清理比例如果为None则使用默认的cleanup_ratio ratio: 清理比例如果为None则使用默认的cleanup_ratio
Returns: Returns:
清理的风格数量 清理的风格数量
""" """
@@ -318,7 +318,7 @@ class StyleLearner:
self.cleanup_ratio = old_cleanup_ratio self.cleanup_ratio = old_cleanup_ratio
else: else:
self._cleanup_styles() self._cleanup_styles()
new_count = len(self.style_to_id) new_count = len(self.style_to_id)
cleaned = old_count - new_count cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格") logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
@@ -357,11 +357,11 @@ class StyleLearner:
import pickle import pickle
meta_path = os.path.join(save_dir, "meta.pkl") meta_path = os.path.join(save_dir, "meta.pkl")
# 确保 learning_stats 包含所有必要字段 # 确保 learning_stats 包含所有必要字段
if "style_last_used" not in self.learning_stats: if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {} self.learning_stats["style_last_used"] = {}
meta_data = { meta_data = {
"style_to_id": self.style_to_id, "style_to_id": self.style_to_id,
"id_to_style": self.id_to_style, "id_to_style": self.id_to_style,
@@ -416,7 +416,7 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"] self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"] self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"] self.learning_stats = meta_data["learning_stats"]
# 确保旧数据兼容:如果没有 style_last_used 字段,添加它 # 确保旧数据兼容:如果没有 style_last_used 字段,添加它
if "style_last_used" not in self.learning_stats: if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {} self.learning_stats["style_last_used"] = {}
@@ -526,10 +526,10 @@ class StyleLearnerManager:
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]: def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
""" """
对所有学习器清理旧风格 对所有学习器清理旧风格
Args: Args:
ratio: 清理比例 ratio: 清理比例
Returns: Returns:
{chat_id: 清理数量} {chat_id: 清理数量}
""" """
@@ -538,7 +538,7 @@ class StyleLearnerManager:
cleaned = learner.cleanup_old_styles(ratio) cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0: if cleaned > 0:
cleanup_results[chat_id] = cleaned cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values()) total_cleaned = sum(cleanup_results.values())
logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格") logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
return cleanup_results return cleanup_results
View File
@@ -8,7 +8,6 @@ from datetime import datetime
from typing import Any from typing import Any
import numpy as np import numpy as np
import orjson
from sqlalchemy import select from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension from src.common.config_helpers import resolve_embedding_dimension
@@ -124,7 +123,7 @@ class BotInterestManager:
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()] tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
tags_str = "\n".join(tags_info) tags_str = "\n".join(tags_info)
logger.info(f"当前兴趣标签:\n{tags_str}") logger.info(f"当前兴趣标签:\n{tags_str}")
# 为加载的标签生成embedding数据库不存储embedding启动时动态生成 # 为加载的标签生成embedding数据库不存储embedding启动时动态生成
logger.info("🧠 为加载的标签生成embedding向量...") logger.info("🧠 为加载的标签生成embedding向量...")
await self._generate_embeddings_for_tags(loaded_interests) await self._generate_embeddings_for_tags(loaded_interests)
@@ -326,13 +325,13 @@ class BotInterestManager:
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding") raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags) total_tags = len(interests.interest_tags)
# 尝试从文件加载缓存 # 尝试从文件加载缓存
file_cache = await self._load_embedding_cache_from_file(interests.personality_id) file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
if file_cache: if file_cache:
logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存") logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存")
self.embedding_cache.update(file_cache) self.embedding_cache.update(file_cache)
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...") logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
memory_cached_count = 0 memory_cached_count = 0
@@ -477,14 +476,14 @@ class BotInterestManager:
self, message_text: str, keywords: list[str] | None = None self, message_text: str, keywords: list[str] | None = None
) -> InterestMatchResult: ) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略) """计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题 核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题
原问题: 原问题:
- 消息: "今天天气不错" (完整句子) - 消息: "今天天气不错" (完整句子)
- 标签: "蹭人治愈" (2-4字短语) - 标签: "蹭人治愈" (2-4字短语)
- 结果: 误匹配,因为短标签的 embedding 过于抽象 - 结果: 误匹配,因为短标签的 embedding 过于抽象
解决方案: 解决方案:
- 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容" - 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容"
- 现在是: 句子 vs 句子,匹配更准确 - 现在是: 句子 vs 句子,匹配更准确
@@ -527,18 +526,18 @@ class BotInterestManager:
if tag.embedding: if tag.embedding:
# 🔧 优化:获取扩展标签的 embedding带缓存 # 🔧 优化:获取扩展标签的 embedding带缓存
expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name) expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
if expanded_embedding: if expanded_embedding:
# 使用扩展标签的 embedding 进行匹配 # 使用扩展标签的 embedding 进行匹配
similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding) similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
# 同时计算原始标签的相似度作为参考 # 同时计算原始标签的相似度作为参考
original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
# 混合策略扩展标签权重更高70%原始标签作为补充30% # 混合策略扩展标签权重更高70%原始标签作为补充30%
# 这样可以兼顾准确性(扩展)和灵活性(原始) # 这样可以兼顾准确性(扩展)和灵活性(原始)
final_similarity = similarity * 0.7 + original_similarity * 0.3 final_similarity = similarity * 0.7 + original_similarity * 0.3
logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}") logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}")
else: else:
# 如果扩展 embedding 获取失败,使用原始 embedding # 如果扩展 embedding 获取失败,使用原始 embedding
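The matching above blends cosine similarity against the expanded tag (70%) with similarity against the raw tag (30%). A minimal numpy sketch of that blend; the weights come from the hunk, the helper names are illustrative:

import numpy as np

def cosine(a: list[float], b: list[float]) -> float:
    va, vb = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    denom = np.linalg.norm(va) * np.linalg.norm(vb)
    return float(va @ vb / denom) if denom else 0.0

def blended_similarity(message_emb, expanded_emb, original_emb) -> float:
    # 70% weight on the expanded (sentence-level) tag, 30% on the raw short tag
    return 0.7 * cosine(message_emb, expanded_emb) + 0.3 * cosine(message_emb, original_emb)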
@@ -603,27 +602,27 @@ class BotInterestManager:
logger.debug( logger.debug(
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}" f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
) )
# 如果有新生成的扩展embedding保存到缓存文件 # 如果有新生成的扩展embedding保存到缓存文件
if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated: if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
await self._save_embedding_cache_to_file(self.current_interests.personality_id) await self._save_embedding_cache_to_file(self.current_interests.personality_id)
self._new_expanded_embeddings_generated = False self._new_expanded_embeddings_generated = False
logger.debug("💾 已保存新生成的扩展embedding到缓存文件") logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
return result return result
async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None: async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
"""获取扩展标签的 embedding带缓存 """获取扩展标签的 embedding带缓存
优先使用缓存,如果没有则生成并缓存 优先使用缓存,如果没有则生成并缓存
""" """
# 检查缓存 # 检查缓存
if tag_name in self.expanded_embedding_cache: if tag_name in self.expanded_embedding_cache:
return self.expanded_embedding_cache[tag_name] return self.expanded_embedding_cache[tag_name]
# 扩展标签 # 扩展标签
expanded_tag = self._expand_tag_for_matching(tag_name) expanded_tag = self._expand_tag_for_matching(tag_name)
# 生成 embedding # 生成 embedding
try: try:
embedding = await self._get_embedding(expanded_tag) embedding = await self._get_embedding(expanded_tag)
@@ -636,19 +635,19 @@ class BotInterestManager:
return embedding return embedding
except Exception as e: except Exception as e:
logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}") logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}")
return None return None
def _expand_tag_for_matching(self, tag_name: str) -> str: def _expand_tag_for_matching(self, tag_name: str) -> str:
"""将短标签扩展为完整的描述性句子 """将短标签扩展为完整的描述性句子
这是解决"标签太短导致误匹配"的核心方法 这是解决"标签太短导致误匹配"的核心方法
策略: 策略:
1. 优先使用 LLM 生成的 expanded 字段(最准确) 1. 优先使用 LLM 生成的 expanded 字段(最准确)
2. 如果没有,使用基于规则的回退方案 2. 如果没有,使用基于规则的回退方案
3. 最后使用通用模板 3. 最后使用通用模板
示例: 示例:
- "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" - "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
- "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" - "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
@@ -656,7 +655,7 @@ class BotInterestManager:
# 使用缓存 # 使用缓存
if tag_name in self.expanded_tag_cache: if tag_name in self.expanded_tag_cache:
return self.expanded_tag_cache[tag_name] return self.expanded_tag_cache[tag_name]
# 🎯 优先策略:使用 LLM 生成的 expanded 字段 # 🎯 优先策略:使用 LLM 生成的 expanded 字段
if self.current_interests: if self.current_interests:
for tag in self.current_interests.interest_tags: for tag in self.current_interests.interest_tags:
@@ -664,66 +663,66 @@ class BotInterestManager:
logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...") logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
self.expanded_tag_cache[tag_name] = tag.expanded self.expanded_tag_cache[tag_name] = tag.expanded
return tag.expanded return tag.expanded
# 🔧 回退策略基于规则的扩展用于兼容旧数据或LLM未生成扩展的情况 # 🔧 回退策略基于规则的扩展用于兼容旧数据或LLM未生成扩展的情况
logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案") logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案")
tag_lower = tag_name.lower() tag_lower = tag_name.lower()
# 技术编程类标签(具体化描述) # 技术编程类标签(具体化描述)
if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']): if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
if 'python' in tag_lower: if "python" in tag_lower:
return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
elif '算法' in tag_lower: elif "算法" in tag_lower:
return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化" return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
elif '代码' in tag_lower or '被窝' in tag_lower: elif "代码" in tag_lower or "被窝" in tag_lower:
return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧" return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
else: else:
return f"讨论编程开发、软件技术、代码编写、技术实现" return "讨论编程开发、软件技术、代码编写、技术实现"
# 情感表达类标签(具体化为真实对话场景) # 情感表达类标签(具体化为真实对话场景)
elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '', '卖萌']): elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "", "卖萌"]):
return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
# 游戏娱乐类标签(具体游戏场景) # 游戏娱乐类标签(具体游戏场景)
elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '', '']): elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "", ""]):
return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得" return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
# 动漫影视类标签(具体观看行为) # 动漫影视类标签(具体观看行为)
elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']): elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
# 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西" # 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西"
if '' in tag_lower or '新番' in tag_lower: if "" in tag_lower or "新番" in tag_lower:
return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色" return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
else: else:
return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验" return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
# 社交平台类标签(具体平台行为) # 社交平台类标签(具体平台行为)
elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']): elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
if '吃瓜' in tag_lower: if "吃瓜" in tag_lower:
return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题" return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
else: else:
return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活" return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
# 生活日常类标签(具体萌宠场景) # 生活日常类标签(具体萌宠场景)
elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']): elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得" return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
# 状态心情类标签(具体情绪状态) # 状态心情类标签(具体情绪状态)
elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']): elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
if '社恐' in tag_lower: if "社恐" in tag_lower:
return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情" return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
elif '深夜' in tag_lower: elif "深夜" in tag_lower:
return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话" return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
else: else:
return f"表达当前心情状态、个人感受、生活状态" return "表达当前心情状态、个人感受、生活状态"
# 物品装备类标签(具体使用场景) # 物品装备类标签(具体使用场景)
elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']): elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测" return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
# 互动关系类标签 # 互动关系类标签
elif any(word in tag_lower for word in ['拾风', '互怼', '互动']): elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
return f"聊天互动、开玩笑、友好互怼、日常对话交流" return "聊天互动、开玩笑、友好互怼、日常对话交流"
# 默认:尽量具体化 # 默认:尽量具体化
else: else:
return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题" return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题"
@@ -1011,56 +1010,58 @@ class BotInterestManager:
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None: async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
"""从文件加载embedding缓存""" """从文件加载embedding缓存"""
try: try:
import orjson
from pathlib import Path from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json" cache_file = cache_dir / f"{personality_id}_embeddings.json"
if not cache_file.exists(): if not cache_file.exists():
logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}") logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}")
return None return None
# 读取缓存文件 # 读取缓存文件
with open(cache_file, "rb") as f: with open(cache_file, "rb") as f:
cache_data = orjson.loads(f.read()) cache_data = orjson.loads(f.read())
# 验证缓存版本和embedding模型 # 验证缓存版本和embedding模型
cache_version = cache_data.get("version", 1) cache_version = cache_data.get("version", 1)
cache_embedding_model = cache_data.get("embedding_model", "") cache_embedding_model = cache_data.get("embedding_model", "")
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else "" current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else ""
if cache_embedding_model != current_embedding_model: if cache_embedding_model != current_embedding_model:
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存") logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存")
return None return None
embeddings = cache_data.get("embeddings", {}) embeddings = cache_data.get("embeddings", {})
# 同时加载扩展标签的embedding缓存 # 同时加载扩展标签的embedding缓存
expanded_embeddings = cache_data.get("expanded_embeddings", {}) expanded_embeddings = cache_data.get("expanded_embeddings", {})
if expanded_embeddings: if expanded_embeddings:
self.expanded_embedding_cache.update(expanded_embeddings) self.expanded_embedding_cache.update(expanded_embeddings)
logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存") logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存")
logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})") logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
return embeddings return embeddings
except Exception as e: except Exception as e:
logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}") logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}")
return None return None
async def _save_embedding_cache_to_file(self, personality_id: str): async def _save_embedding_cache_to_file(self, personality_id: str):
"""保存embedding缓存到文件包括扩展标签的embedding""" """保存embedding缓存到文件包括扩展标签的embedding"""
try: try:
import orjson
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json" cache_file = cache_dir / f"{personality_id}_embeddings.json"
# 准备缓存数据 # 准备缓存数据
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else "" current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else ""
cache_data = { cache_data = {
@@ -1071,13 +1072,13 @@ class BotInterestManager:
"embeddings": self.embedding_cache, "embeddings": self.embedding_cache,
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding "expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
} }
# 写入文件 # 写入文件
with open(cache_file, "wb") as f: with open(cache_file, "wb") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2)) f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2))
logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}") logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
except Exception as e: except Exception as e:
logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}") logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}")
View File
@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
__all__ = [ __all__ = [
"MessageManager", "MessageManager",
"SingleStreamContextManager",
"SchedulerDispatcher", "SchedulerDispatcher",
"SingleStreamContextManager",
"message_manager", "message_manager",
"scheduler_dispatcher", "scheduler_dispatcher",
] ]
View File
@@ -73,7 +73,7 @@ class SingleStreamContextManager:
cache_enabled = global_config.chat.enable_message_cache cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled: if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用") logger.debug("消息缓存系统已在配置中禁用")
except Exception as e: except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}") logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False use_cache_system = False
@@ -129,13 +129,13 @@ class SingleStreamContextManager:
await self._calculate_message_interest(message) await self._calculate_message_interest(message)
self.total_messages += 1 self.total_messages += 1
self.last_access_time = time.time() self.last_access_time = time.time()
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True return True
# 不应该到达这里,但为了类型检查添加返回值 # 不应该到达这里,但为了类型检查添加返回值
return True return True
except Exception as e: except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False
View File
@@ -4,13 +4,11 @@
""" """
import asyncio import asyncio
import random
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
@@ -77,7 +75,7 @@ class MessageManager:
# 启动基于 scheduler 的消息分发器 # 启动基于 scheduler 的消息分发器
await scheduler_dispatcher.start() await scheduler_dispatcher.start()
scheduler_dispatcher.set_chatter_manager(self.chatter_manager) scheduler_dispatcher.set_chatter_manager(self.chatter_manager)
# 保留旧的流循环管理器(暂时)以便平滑过渡 # 保留旧的流循环管理器(暂时)以便平滑过渡
# TODO: 在确认新机制稳定后移除 # TODO: 在确认新机制稳定后移除
# await stream_loop_manager.start() # await stream_loop_manager.start()
@@ -108,7 +106,7 @@ class MessageManager:
# 停止基于 scheduler 的消息分发器 # 停止基于 scheduler 的消息分发器
await scheduler_dispatcher.stop() await scheduler_dispatcher.stop()
# 停止旧的流循环管理器(如果启用) # 停止旧的流循环管理器(如果启用)
# await stream_loop_manager.stop() # await stream_loop_manager.stop()
@@ -116,7 +114,7 @@ class MessageManager:
async def add_message(self, stream_id: str, message: DatabaseMessages): async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流 """添加消息到指定聊天流
新的流程: 新的流程:
1. 检查 notice 消息 1. 检查 notice 消息
2. 将消息添加到上下文(缓存) 2. 将消息添加到上下文(缓存)
@@ -149,10 +147,10 @@ class MessageManager:
if not chat_stream: if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return return
# 将消息添加到上下文 # 将消息添加到上下文
await chat_stream.context_manager.add_message(message) await chat_stream.context_manager.add_message(message)
# 通知 scheduler_dispatcher 处理消息接收事件 # 通知 scheduler_dispatcher 处理消息接收事件
# dispatcher 会检查是否需要打断、创建或更新 schedule # dispatcher 会检查是否需要打断、创建或更新 schedule
await scheduler_dispatcher.on_message_received(stream_id) await scheduler_dispatcher.on_message_received(stream_id)
View File
@@ -20,7 +20,7 @@ logger = get_logger("scheduler_dispatcher")
class SchedulerDispatcher: class SchedulerDispatcher:
"""基于 scheduler 的消息分发器 """基于 scheduler 的消息分发器
工作流程: 工作流程:
1. 接收消息时,将消息添加到聊天流上下文 1. 接收消息时,将消息添加到聊天流上下文
2. 检查是否有活跃的 schedule如果没有则创建 2. 检查是否有活跃的 schedule如果没有则创建
@@ -32,10 +32,17 @@ class SchedulerDispatcher:
def __init__(self): def __init__(self):
# 追踪每个流的 schedule_id # 追踪每个流的 schedule_id
self.stream_schedules: dict[str, str] = {} # stream_id -> schedule_id self.stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
# 用于保护 schedule 创建/删除的锁,避免竞态条件
self.schedule_locks: dict[str, asyncio.Lock] = {} # stream_id -> Lock
# Chatter 管理器 # Chatter 管理器
self.chatter_manager: ChatterManager | None = None self.chatter_manager: ChatterManager | None = None
# 统计信息 # 统计信息
self.stats = { self.stats = {
"total_schedules_created": 0, "total_schedules_created": 0,
@@ -45,9 +52,9 @@ class SchedulerDispatcher:
"total_failures": 0, "total_failures": 0,
"start_time": time.time(), "start_time": time.time(),
} }
self.is_running = False self.is_running = False
logger.info("基于 Scheduler 的消息分发器初始化完成") logger.info("基于 Scheduler 的消息分发器初始化完成")
async def start(self) -> None: async def start(self) -> None:
@@ -55,7 +62,7 @@ class SchedulerDispatcher:
if self.is_running: if self.is_running:
logger.warning("分发器已在运行") logger.warning("分发器已在运行")
return return
self.is_running = True self.is_running = True
logger.info("基于 Scheduler 的消息分发器已启动") logger.info("基于 Scheduler 的消息分发器已启动")
@@ -63,9 +70,9 @@ class SchedulerDispatcher:
"""停止分发器""" """停止分发器"""
if not self.is_running: if not self.is_running:
return return
self.is_running = False self.is_running = False
# 取消所有活跃的 schedule # 取消所有活跃的 schedule
schedule_ids = list(self.stream_schedules.values()) schedule_ids = list(self.stream_schedules.values())
for schedule_id in schedule_ids: for schedule_id in schedule_ids:
@@ -73,7 +80,7 @@ class SchedulerDispatcher:
await unified_scheduler.remove_schedule(schedule_id) await unified_scheduler.remove_schedule(schedule_id)
except Exception as e: except Exception as e:
logger.error(f"移除 schedule {schedule_id} 失败: {e}") logger.error(f"移除 schedule {schedule_id} 失败: {e}")
self.stream_schedules.clear() self.stream_schedules.clear()
logger.info("基于 Scheduler 的消息分发器已停止") logger.info("基于 Scheduler 的消息分发器已停止")
@@ -81,43 +88,52 @@ class SchedulerDispatcher:
"""设置 Chatter 管理器""" """设置 Chatter 管理器"""
self.chatter_manager = chatter_manager self.chatter_manager = chatter_manager
logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}") logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}")
def _get_schedule_lock(self, stream_id: str) -> asyncio.Lock:
"""获取流的 schedule 锁"""
if stream_id not in self.schedule_locks:
self.schedule_locks[stream_id] = asyncio.Lock()
return self.schedule_locks[stream_id]
async def on_message_received(self, stream_id: str) -> None: async def on_message_received(self, stream_id: str) -> None:
"""消息接收时的处理逻辑 """消息接收时的处理逻辑
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
""" """
if not self.is_running: if not self.is_running:
logger.warning("分发器未运行,忽略消息") logger.warning("分发器未运行,忽略消息")
return return
try: try:
# 1. 获取流上下文 # 1. 获取流上下文
context = await self._get_stream_context(stream_id) context = await self._get_stream_context(stream_id)
if not context: if not context:
logger.warning(f"无法获取流上下文: {stream_id}") logger.warning(f"无法获取流上下文: {stream_id}")
return return
# 2. 检查是否有活跃的 schedule # 2. 检查是否有活跃的 schedule
has_active_schedule = stream_id in self.stream_schedules has_active_schedule = stream_id in self.stream_schedules
if not has_active_schedule: if not has_active_schedule:
# 4. 创建新的 schedule在锁内避免重复创建 # 4. 创建新的 schedule在锁内避免重复创建
await self._create_schedule(stream_id, context) await self._create_schedule(stream_id, context)
return return
# 3. 检查打断判定 # 3. 检查打断判定
if has_active_schedule: if has_active_schedule:
should_interrupt = await self._check_interruption(stream_id, context) should_interrupt = await self._check_interruption(stream_id, context)
if should_interrupt: if should_interrupt:
# 移除旧 schedule 并创建新的 # 移除旧 schedule 并创建新的
await self._cancel_and_recreate_schedule(stream_id, context) await self._cancel_and_recreate_schedule(stream_id, context)
logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule") logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule")
else: else:
logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...") logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...")
except Exception as e: except Exception as e:
logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True) logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True)
@@ -135,18 +151,18 @@ class SchedulerDispatcher:
async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool: async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool:
"""检查是否应该打断当前处理 """检查是否应该打断当前处理
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
bool: 是否应该打断 bool: 是否应该打断
""" """
# 检查是否启用打断 # 检查是否启用打断
if not global_config.chat.interruption_enabled: if not global_config.chat.interruption_enabled:
return False return False
# 检查是否正在回复,以及是否允许在回复时打断 # 检查是否正在回复,以及是否允许在回复时打断
if context.is_replying: if context.is_replying:
if not global_config.chat.allow_reply_interruption: if not global_config.chat.allow_reply_interruption:
@@ -154,49 +170,49 @@ class SchedulerDispatcher:
return False return False
else: else:
logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断") logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断")
# 只有当 Chatter 真正在处理时才检查打断 # 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing: if not context.is_chatter_processing:
logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断") logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断")
return False return False
# 检查最后一条消息 # 检查最后一条消息
last_message = context.get_last_message() last_message = context.get_last_message()
if not last_message: if not last_message:
return False return False
# 检查是否为表情包消息 # 检查是否为表情包消息
if last_message.is_picid or last_message.is_emoji: if last_message.is_picid or last_message.is_emoji:
logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查") logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查")
return False return False
# 检查触发用户ID # 检查触发用户ID
triggering_user_id = context.triggering_user_id triggering_user_id = context.triggering_user_id
if triggering_user_id and last_message.user_info.user_id != triggering_user_id: if triggering_user_id and last_message.user_info.user_id != triggering_user_id:
logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查") logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查")
return False return False
# 检查是否已达到最大打断次数 # 检查是否已达到最大打断次数
if context.interruption_count >= global_config.chat.interruption_max_limit: if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.debug( logger.debug(
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}" f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}"
) )
return False return False
# 计算打断概率 # 计算打断概率
interruption_probability = context.calculate_interruption_probability( interruption_probability = context.calculate_interruption_probability(
global_config.chat.interruption_max_limit global_config.chat.interruption_max_limit
) )
# 根据概率决定是否打断 # 根据概率决定是否打断
import random import random
if random.random() < interruption_probability: if random.random() < interruption_probability:
logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}") logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
# 增加打断计数 # 增加打断计数
await context.increment_interruption_count() await context.increment_interruption_count()
self.stats["total_interruptions"] += 1 self.stats["total_interruptions"] += 1
# 检查是否已达到最大次数 # 检查是否已达到最大次数
if context.interruption_count >= global_config.chat.interruption_max_limit: if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.warning( logger.warning(
@@ -206,7 +222,7 @@ class SchedulerDispatcher:
logger.info( logger.info(
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}" f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}"
) )
return True return True
else: else:
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}") logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
@@ -214,7 +230,7 @@ class SchedulerDispatcher:
async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None: async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None:
"""取消旧的 schedule 并创建新的(打断模式,使用极短延迟) """取消旧的 schedule 并创建新的(打断模式,使用极短延迟)
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
@@ -235,13 +251,13 @@ class SchedulerDispatcher:
) )
# 移除失败,不创建新 schedule避免重复 # 移除失败,不创建新 schedule避免重复
return return
# 创建新的 schedule使用即时处理模式极短延迟 # 创建新的 schedule使用即时处理模式极短延迟
await self._create_schedule(stream_id, context, immediate_mode=True) await self._create_schedule(stream_id, context, immediate_mode=True)
async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None: async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None:
"""为聊天流创建新的 schedule """为聊天流创建新的 schedule
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
@@ -257,7 +273,7 @@ class SchedulerDispatcher:
) )
await unified_scheduler.remove_schedule(old_schedule_id) await unified_scheduler.remove_schedule(old_schedule_id)
del self.stream_schedules[stream_id] del self.stream_schedules[stream_id]
# 如果是即时处理模式打断时使用固定的1秒延迟立即重新处理 # 如果是即时处理模式打断时使用固定的1秒延迟立即重新处理
if immediate_mode: if immediate_mode:
delay = 1.0 # 硬编码1秒延迟确保打断后能快速重新处理 delay = 1.0 # 硬编码1秒延迟确保打断后能快速重新处理
@@ -268,10 +284,10 @@ class SchedulerDispatcher:
else: else:
# 常规模式:计算初始延迟 # 常规模式:计算初始延迟
delay = await self._calculate_initial_delay(stream_id, context) delay = await self._calculate_initial_delay(stream_id, context)
# 获取未读消息数量用于日志 # 获取未读消息数量用于日志
unread_count = len(context.unread_messages) if context.unread_messages else 0 unread_count = len(context.unread_messages) if context.unread_messages else 0
# 创建 schedule # 创建 schedule
schedule_id = await unified_scheduler.create_schedule( schedule_id = await unified_scheduler.create_schedule(
callback=self._on_schedule_triggered, callback=self._on_schedule_triggered,
@@ -281,41 +297,41 @@ class SchedulerDispatcher:
task_name=f"dispatch_{stream_id[:8]}", task_name=f"dispatch_{stream_id[:8]}",
callback_args=(stream_id,), callback_args=(stream_id,),
) )
# 追踪 schedule # 追踪 schedule
self.stream_schedules[stream_id] = schedule_id self.stream_schedules[stream_id] = schedule_id
self.stats["total_schedules_created"] += 1 self.stats["total_schedules_created"] += 1
mode_indicator = "⚡打断" if immediate_mode else "📅常规" mode_indicator = "⚡打断" if immediate_mode else "📅常规"
logger.info( logger.info(
f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., " f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., "
f"延迟={delay:.3f}s, 未读={unread_count}, " f"延迟={delay:.3f}s, 未读={unread_count}, "
f"ID={schedule_id[:8]}..." f"ID={schedule_id[:8]}..."
) )
except Exception as e: except Exception as e:
logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True) logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True)
async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float: async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float:
"""计算初始延迟时间 """计算初始延迟时间
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
float: 延迟时间(秒) float: 延迟时间(秒)
""" """
# 基础间隔 # 基础间隔
base_interval = getattr(global_config.chat, "distribution_interval", 5.0) base_interval = getattr(global_config.chat, "distribution_interval", 5.0)
# 检查是否有未读消息 # 检查是否有未读消息
unread_count = len(context.unread_messages) if context.unread_messages else 0 unread_count = len(context.unread_messages) if context.unread_messages else 0
# 强制分发阈值 # 强制分发阈值
force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20) force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20)
# 如果未读消息过多,使用最小间隔 # 如果未读消息过多,使用最小间隔
if force_dispatch_threshold and unread_count > force_dispatch_threshold: if force_dispatch_threshold and unread_count > force_dispatch_threshold:
min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
@@ -325,24 +341,24 @@ class SchedulerDispatcher:
f"使用最小间隔={min_interval}s" f"使用最小间隔={min_interval}s"
) )
return min_interval return min_interval
# 尝试使用能量管理器计算间隔 # 尝试使用能量管理器计算间隔
try: try:
# 更新能量值 # 更新能量值
await self._update_stream_energy(stream_id, context) await self._update_stream_energy(stream_id, context)
# 获取当前 focus_energy # 获取当前 focus_energy
focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0] focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0]
# 使用能量管理器计算间隔 # 使用能量管理器计算间隔
interval = energy_manager.get_distribution_interval(focus_energy) interval = energy_manager.get_distribution_interval(focus_energy)
logger.info( logger.info(
f"📊 动态间隔计算: 流={stream_id[:8]}..., " f"📊 动态间隔计算: 流={stream_id[:8]}..., "
f"能量={focus_energy:.3f}, 间隔={interval:.2f}s" f"能量={focus_energy:.3f}, 间隔={interval:.2f}s"
) )
return interval return interval
except Exception as e: except Exception as e:
logger.info( logger.info(
f"📊 使用默认间隔: 流={stream_id[:8]}..., " f"📊 使用默认间隔: 流={stream_id[:8]}..., "
@@ -352,96 +368,96 @@ class SchedulerDispatcher:
async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None: async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None:
"""更新流的能量值 """更新流的能量值
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
""" """
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
# 获取聊天流 # 获取聊天流
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream: if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return return
# 合并未读消息和历史消息 # 合并未读消息和历史消息
all_messages = [] all_messages = []
# 添加历史消息 # 添加历史消息
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size) history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages) all_messages.extend(history_messages)
# 添加未读消息 # 添加未读消息
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages) all_messages.extend(unread_messages)
# 按时间排序并限制数量 # 按时间排序并限制数量
all_messages.sort(key=lambda m: m.time) all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:] messages = all_messages[-global_config.chat.max_context_size:]
# 获取用户ID # 获取用户ID
user_id = context.triggering_user_id user_id = context.triggering_user_id
# 使用能量管理器计算并缓存能量值 # 使用能量管理器计算并缓存能量值
energy = await energy_manager.calculate_focus_energy( energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id, stream_id=stream_id,
messages=messages, messages=messages,
user_id=user_id user_id=user_id
) )
# 同步更新到 ChatStream # 同步更新到 ChatStream
chat_stream._focus_energy = energy chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}") logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e: except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False) logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
async def _on_schedule_triggered(self, stream_id: str) -> None: async def _on_schedule_triggered(self, stream_id: str) -> None:
"""schedule 触发时的回调 """schedule 触发时的回调
Args: Args:
stream_id: 流ID stream_id: 流ID
""" """
try: try:
old_schedule_id = self.stream_schedules.get(stream_id) old_schedule_id = self.stream_schedules.get(stream_id)
logger.info( logger.info(
f"⏰ Schedule 触发: 流={stream_id[:8]}..., " f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., " f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., "
f"开始处理消息" f"开始处理消息"
) )
# 获取流上下文 # 获取流上下文
context = await self._get_stream_context(stream_id) context = await self._get_stream_context(stream_id)
if not context: if not context:
logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}") logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}")
return return
# 检查是否有未读消息 # 检查是否有未读消息
if not context.unread_messages: if not context.unread_messages:
logger.debug(f"{stream_id} 没有未读消息,跳过处理") logger.debug(f"{stream_id} 没有未读消息,跳过处理")
return return
# 激活 chatter 处理(不需要锁,允许并发处理) # 激活 chatter 处理(不需要锁,允许并发处理)
success = await self._process_stream(stream_id, context) success = await self._process_stream(stream_id, context)
# 更新统计 # 更新统计
self.stats["total_process_cycles"] += 1 self.stats["total_process_cycles"] += 1
if not success: if not success:
self.stats["total_failures"] += 1 self.stats["total_failures"] += 1
self.stream_schedules.pop(stream_id, None) self.stream_schedules.pop(stream_id, None)
# 检查缓存中是否有待处理的消息 # 检查缓存中是否有待处理的消息
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
has_cached = message_manager.has_cached_messages(stream_id) has_cached = message_manager.has_cached_messages(stream_id)
if has_cached: if has_cached:
# 有缓存消息,立即创建新 schedule 继续处理 # 有缓存消息,立即创建新 schedule 继续处理
logger.info( logger.info(
@@ -455,60 +471,60 @@ class SchedulerDispatcher:
f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., " f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., "
f"等待新消息到达" f"等待新消息到达"
) )
except Exception as e: except Exception as e:
logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True) logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True)
async def _process_stream(self, stream_id: str, context: StreamContext) -> bool: async def _process_stream(self, stream_id: str, context: StreamContext) -> bool:
"""处理流消息 """处理流消息
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
bool: 是否处理成功 bool: 是否处理成功
""" """
if not self.chatter_manager: if not self.chatter_manager:
logger.warning(f"Chatter 管理器未设置: {stream_id}") logger.warning(f"Chatter 管理器未设置: {stream_id}")
return False return False
# 设置处理状态 # 设置处理状态
self._set_stream_processing_status(stream_id, True) self._set_stream_processing_status(stream_id, True)
try: try:
start_time = time.time() start_time = time.time()
# 设置触发用户ID # 设置触发用户ID
last_message = context.get_last_message() last_message = context.get_last_message()
if last_message: if last_message:
context.triggering_user_id = last_message.user_info.user_id context.triggering_user_id = last_message.user_info.user_id
# 创建异步任务刷新能量(不阻塞主流程) # 创建异步任务刷新能量(不阻塞主流程)
energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id)) energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id))
# 设置 Chatter 正在处理的标志 # 设置 Chatter 正在处理的标志
context.is_chatter_processing = True context.is_chatter_processing = True
logger.debug(f"设置 Chatter 处理标志: {stream_id}") logger.debug(f"设置 Chatter 处理标志: {stream_id}")
try: try:
# 调用 chatter_manager 处理流上下文 # 调用 chatter_manager 处理流上下文
results = await self.chatter_manager.process_stream_context(stream_id, context) results = await self.chatter_manager.process_stream_context(stream_id, context)
success = results.get("success", False) success = results.get("success", False)
if success: if success:
process_time = time.time() - start_time process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)") logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
else: else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success return success
finally: finally:
# 清除 Chatter 处理标志 # 清除 Chatter 处理标志
context.is_chatter_processing = False context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}") logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# 等待能量刷新任务完成 # 等待能量刷新任务完成
try: try:
await asyncio.wait_for(energy_task, timeout=5.0) await asyncio.wait_for(energy_task, timeout=5.0)
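Above, the energy refresh runs as a background task started before the chatter call and is awaited afterwards with a 5-second cap, so a slow refresh cannot stall the reply. A sketch of that pattern:

import asyncio

async def process_with_background_refresh(refresh, handle):
    task = asyncio.create_task(refresh())   # start the refresh without blocking the reply path
    try:
        return await handle()               # main chatter processing
    finally:
        try:
            await asyncio.wait_for(task, timeout=5.0)
        except asyncio.TimeoutError:
            pass                            # refresh is best-effort; never fail the reply for it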
@@ -516,11 +532,11 @@ class SchedulerDispatcher:
logger.warning(f"等待能量刷新超时: {stream_id}") logger.warning(f"等待能量刷新超时: {stream_id}")
except Exception as e: except Exception as e:
logger.debug(f"能量刷新任务异常: {e}") logger.debug(f"能量刷新任务异常: {e}")
except Exception as e: except Exception as e:
logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True) logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True)
return False return False
finally: finally:
# 设置处理状态为未处理 # 设置处理状态为未处理
self._set_stream_processing_status(stream_id, False) self._set_stream_processing_status(stream_id, False)
@@ -529,11 +545,11 @@ class SchedulerDispatcher:
"""设置流的处理状态""" """设置流的处理状态"""
try: try:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
if message_manager.is_running: if message_manager.is_running:
message_manager.set_stream_processing_status(stream_id, is_processing) message_manager.set_stream_processing_status(stream_id, is_processing)
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}") logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except ImportError: except ImportError:
logger.debug("MessageManager 不可用,跳过状态设置") logger.debug("MessageManager 不可用,跳过状态设置")
except Exception as e: except Exception as e:
@@ -547,7 +563,7 @@ class SchedulerDispatcher:
if not chat_stream: if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}") logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return return
await chat_stream.context_manager.refresh_focus_energy_from_history() await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量") logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e: except Exception as e:
View File
@@ -367,7 +367,7 @@ class ChatBot:
message_segment = message_data.get("message_segment") message_segment = message_data.get("message_segment")
if message_segment and isinstance(message_segment, dict): if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response": if message_segment.get("type") == "adapter_response":
logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理") logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data")) await self._handle_adapter_response_from_dict(message_segment.get("data"))
return return
View File
@@ -205,7 +205,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
return result return result
else: else:
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'") logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
return f"@{segment.data}" return f"@{segment.data}"
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}") logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"
View File
@@ -542,7 +542,7 @@ class DefaultReplyer:
all_memories = [] all_memories = []
try: try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized(): if is_initialized():
manager = get_memory_manager() manager = get_memory_manager()
if manager: if manager:
@@ -552,12 +552,12 @@ class DefaultReplyer:
sender_name = "" sender_name = ""
if user_info_obj: if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "") sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# 获取参与者信息 # 获取参与者信息
participants = [] participants = []
try: try:
# 尝试从聊天流中获取参与者信息 # 尝试从聊天流中获取参与者信息
if hasattr(stream, 'chat_history_manager'): if hasattr(stream, "chat_history_manager"):
history_manager = stream.chat_history_manager history_manager = stream.chat_history_manager
# 获取最近的参与者列表 # 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history( recent_records = history_manager.get_memory_chat_history(
@@ -586,16 +586,16 @@ class DefaultReplyer:
formatted_history = "" formatted_history = ""
if chat_history: if chat_history:
# 移除过长的历史记录,只保留最近部分 # 移除过长的历史记录,只保留最近部分
lines = chat_history.strip().split('\n') lines = chat_history.strip().split("\n")
recent_lines = lines[-10:] if len(lines) > 10 else lines recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines) formatted_history = "\n".join(recent_lines)
query_context = { query_context = {
"chat_history": formatted_history, "chat_history": formatted_history,
"sender": sender_name, "sender": sender_name,
"participants": participants, "participants": participants,
} }
# 使用记忆管理器的智能检索(多查询策略) # 使用记忆管理器的智能检索(多查询策略)
memories = await manager.search_memories( memories = await manager.search_memories(
query=target, query=target,
@@ -605,23 +605,23 @@ class DefaultReplyer:
use_multi_query=True, use_multi_query=True,
context=query_context, context=query_context,
) )
if memories: if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆") logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# 使用新的格式化工具构建完整的记忆描述 # 使用新的格式化工具构建完整的记忆描述
from src.memory_graph.utils.memory_formatter import ( from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt, format_memory_for_prompt,
get_memory_type_label, get_memory_type_label,
) )
for memory in memories: for memory in memories:
# 使用格式化工具生成完整的主谓宾描述 # 使用格式化工具生成完整的主谓宾描述
content = format_memory_for_prompt(memory, include_metadata=False) content = format_memory_for_prompt(memory, include_metadata=False)
# 获取记忆类型 # 获取记忆类型
mem_type = memory.memory_type.value if memory.memory_type else "未知" mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content: if content:
all_memories.append({ all_memories.append({
"content": content, "content": content,
@@ -636,7 +636,7 @@ class DefaultReplyer:
except Exception as e: except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}") logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = [] all_memories = []
# 构建记忆字符串,使用方括号格式 # 构建记忆字符串,使用方括号格式
memory_str = "" memory_str = ""
has_any_memory = False has_any_memory = False
@@ -725,7 +725,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") tool_result.get("type", "tool_result")
# 不进行截断,让工具自己处理结果长度 # 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}") current_results_parts.append(f"- **{tool_name}**: {content}")
@@ -744,7 +744,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}") logger.error(f"工具信息获取失败: {e}")
return "" return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]: def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具""" """解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
@@ -1897,7 +1897,7 @@ class DefaultReplyer:
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None): async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
""" """
[已废弃] 异步存储聊天记忆从build_memory_block迁移而来 [已废弃] 异步存储聊天记忆从build_memory_block迁移而来
此函数已被记忆图系统的工具调用方式替代。 此函数已被记忆图系统的工具调用方式替代。
记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。 记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。
@@ -1906,14 +1906,13 @@ class DefaultReplyer:
reply_message: 回复的原始消息 reply_message: 回复的原始消息
""" """
return # 已禁用,保留函数签名以防其他地方有引用 return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行 # 以下代码已废弃,不再执行
try: try:
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return return
# 使用统一记忆系统存储记忆 # 使用统一记忆系统存储记忆
from src.chat.memory_system import get_memory_system
stream = self.chat_stream stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None) user_info_obj = getattr(stream, "user_info", None)
@@ -2036,7 +2035,7 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size), limit=int(global_config.chat.max_context_size),
) )
chat_history = await build_readable_messages( await build_readable_messages(
message_list_before_short, message_list_before_short,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,

View File

@@ -400,7 +400,7 @@ class Prompt:
# 初始化预构建参数字典 # 初始化预构建参数字典
pre_built_params = {} pre_built_params = {}
try: try:
# --- 步骤 1: 准备构建任务 --- # --- 步骤 1: 准备构建任务 ---
tasks = [] tasks = []

View File

@@ -87,20 +87,18 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
) )
processed_text = message.processed_plain_text or "" processed_text = message.processed_plain_text or ""
# 1. 判断是否为私聊(强提及) # 1. 判断是否为私聊(强提及)
group_info = getattr(message, "group_info", None) group_info = getattr(message, "group_info", None)
if not group_info or not getattr(group_info, "group_id", None): if not group_info or not getattr(group_info, "group_id", None):
is_private = True
mention_type = 2 mention_type = 2
logger.debug("检测到私聊消息 - 强提及") logger.debug("检测到私聊消息 - 强提及")
# 2. 判断是否被@(强提及) # 2. 判断是否被@(强提及)
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text): if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
is_at = True
mention_type = 2 mention_type = 2
logger.debug("检测到@提及 - 强提及") logger.debug("检测到@提及 - 强提及")
# 3. 判断是否被回复(强提及) # 3. 判断是否被回复(强提及)
if re.match( if re.match(
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text
@@ -108,10 +106,9 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:", rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
processed_text, processed_text,
): ):
is_replied = True
mention_type = 2 mention_type = 2
logger.debug("检测到回复消息 - 强提及") logger.debug("检测到回复消息 - 强提及")
# 4. 判断文本中是否提及bot名字或别名弱提及 # 4. 判断文本中是否提及bot名字或别名弱提及
if mention_type == 0: # 只有在没有强提及时才检查弱提及 if mention_type == 0: # 只有在没有强提及时才检查弱提及
# 移除@和回复标记后再检查 # 移除@和回复标记后再检查
@@ -119,21 +116,19 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content) message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content) message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content) message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
# 检查bot主名字 # 检查bot主名字
if global_config.bot.nickname in message_content: if global_config.bot.nickname in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及") logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
# 如果主名字没匹配,再检查别名 # 如果主名字没匹配,再检查别名
elif nicknames: elif nicknames:
for alias_name in nicknames: for alias_name in nicknames:
if alias_name in message_content: if alias_name in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及") logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
break break
# 返回结果 # 返回结果
is_mentioned = mention_type > 0 is_mentioned = mention_type > 0
return is_mentioned, float(mention_type) return is_mentioned, float(mention_type)
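The return value encodes mention strength on a small scale: 0 = not mentioned, 1 = weak (name or alias appears in the text), 2 = strong (@, reply, or private chat). A minimal consumer might look like the sketch below; the bonus weights are illustrative assumptions, not values from the source.

from src.chat.utils.utils import is_mentioned_bot_in_message

def mention_bonus(message) -> float:
    is_mentioned, mention_level = is_mentioned_bot_in_message(message)
    if not is_mentioned:
        return 0.0
    # 2.0 = strong mention (@ / reply / private chat), 1.0 = weak text mention of name or alias
    return 0.6 if mention_level >= 2 else 0.3  # example weights only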

View File

@@ -368,13 +368,13 @@ class CacheManager:
if expired_keys: if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目") logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
def get_health_stats(self) -> dict[str, Any]: def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息""" """获取缓存健康统计信息"""
# 简化的健康统计,不包含内存监控(因为相关属性未定义) # 简化的健康统计,不包含内存监控(因为相关属性未定义)
return { return {
"l1_count": len(self.l1_kv_cache), "l1_count": len(self.l1_kv_cache),
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0, "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
"tool_stats": { "tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0), "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})), "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
@@ -397,7 +397,7 @@ class CacheManager:
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}") warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查向量索引大小 # 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0 vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
if isinstance(vector_count, int) and vector_count > 500: if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}") warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -66,7 +66,7 @@ class BatchStats:
last_batch_duration: float = 0.0 last_batch_duration: float = 0.0
last_batch_size: int = 0 last_batch_size: int = 0
congestion_score: float = 0.0 # 拥塞评分 (0-1) congestion_score: float = 0.0 # 拥塞评分 (0-1)
# 🔧 新增:缓存统计 # 🔧 新增:缓存统计
cache_size: int = 0 # 缓存条目数 cache_size: int = 0 # 缓存条目数
cache_memory_mb: float = 0.0 # 缓存内存占用MB cache_memory_mb: float = 0.0 # 缓存内存占用MB
@@ -539,8 +539,7 @@ class AdaptiveBatchScheduler:
def _set_cache(self, cache_key: str, result: Any) -> None: def _set_cache(self, cache_key: str, result: Any) -> None:
"""设置缓存(改进版,带大小限制和内存统计)""" """设置缓存(改进版,带大小限制和内存统计)"""
import sys
# 🔧 检查缓存大小限制 # 🔧 检查缓存大小限制
if len(self._result_cache) >= self._cache_max_size: if len(self._result_cache) >= self._cache_max_size:
# 首先清理过期条目 # 首先清理过期条目
@@ -549,18 +548,18 @@ class AdaptiveBatchScheduler:
k for k, (_, ts) in self._result_cache.items() k for k, (_, ts) in self._result_cache.items()
if current_time - ts >= self.cache_ttl if current_time - ts >= self.cache_ttl
] ]
for k in expired_keys: for k in expired_keys:
# 更新内存统计 # 更新内存统计
if k in self._cache_size_map: if k in self._cache_size_map:
self._cache_memory_estimate -= self._cache_size_map[k] self._cache_memory_estimate -= self._cache_size_map[k]
del self._cache_size_map[k] del self._cache_size_map[k]
del self._result_cache[k] del self._result_cache[k]
# 如果还是太大清理最老的条目LRU # 如果还是太大清理最老的条目LRU
if len(self._result_cache) >= self._cache_max_size: if len(self._result_cache) >= self._cache_max_size:
oldest_key = min( oldest_key = min(
self._result_cache.keys(), self._result_cache.keys(),
key=lambda k: self._result_cache[k][1] key=lambda k: self._result_cache[k][1]
) )
# 更新内存统计 # 更新内存统计
@@ -569,7 +568,7 @@ class AdaptiveBatchScheduler:
del self._cache_size_map[oldest_key] del self._cache_size_map[oldest_key]
del self._result_cache[oldest_key] del self._result_cache[oldest_key]
logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}") logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}")
# 🔧 使用准确的内存估算方法 # 🔧 使用准确的内存估算方法
try: try:
total_size = estimate_size_smart(cache_key) + estimate_size_smart(result) total_size = estimate_size_smart(cache_key) + estimate_size_smart(result)
@@ -580,7 +579,7 @@ class AdaptiveBatchScheduler:
# 使用默认值 # 使用默认值
self._cache_size_map[cache_key] = 1024 self._cache_size_map[cache_key] = 1024
self._cache_memory_estimate += 1024 self._cache_memory_estimate += 1024
self._result_cache[cache_key] = (result, time.time()) self._result_cache[cache_key] = (result, time.time())
async def get_stats(self) -> BatchStats: async def get_stats(self) -> BatchStats:
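The eviction order above is: drop TTL-expired entries first, then evict the oldest entry if the cache is still full, keeping the byte estimate in sync. A standalone sketch of that policy (attribute names simplified, everything else assumed):

import time
from typing import Any, Callable

def set_with_eviction(cache: dict[str, tuple[Any, float]], sizes: dict[str, int],
                      key: str, value: Any, *, max_size: int, ttl: float,
                      estimate: Callable[[Any], int]) -> int:
    """Insert key/value, evicting expired then oldest entries; return the memory delta in bytes."""
    delta = 0
    now = time.time()
    if len(cache) >= max_size:
        for k in [k for k, (_, ts) in cache.items() if now - ts >= ttl]:
            delta -= sizes.pop(k, 0)   # 1) expire by TTL
            del cache[k]
        if len(cache) >= max_size:
            oldest = min(cache, key=lambda k: cache[k][1])
            delta -= sizes.pop(oldest, 0)   # 2) fall back to evicting the oldest write
            del cache[oldest]
    size = estimate(key) + estimate(value)
    sizes[key] = size
    cache[key] = (value, time.time())
    return delta + size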

View File

@@ -171,7 +171,7 @@ class LRUCache(Generic[T]):
) )
else: else:
adjusted_created_at = now adjusted_created_at = now
entry = CacheEntry( entry = CacheEntry(
value=value, value=value,
created_at=adjusted_created_at, created_at=adjusted_created_at,
@@ -345,7 +345,7 @@ class MultiLevelCache:
# 估算数据大小(如果未提供) # 估算数据大小(如果未提供)
if size is None: if size is None:
size = estimate_size_smart(value) size = estimate_size_smart(value)
# 检查单个条目大小是否超过限制 # 检查单个条目大小是否超过限制
if size > self.max_item_size_bytes: if size > self.max_item_size_bytes:
logger.warning( logger.warning(
@@ -354,7 +354,7 @@ class MultiLevelCache:
f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB" f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB"
) )
return return
# 根据TTL决定写入哪个缓存层 # 根据TTL决定写入哪个缓存层
if ttl is not None: if ttl is not None:
# 有自定义TTL根据TTL大小决定写入层级 # 有自定义TTL根据TTL大小决定写入层级
@@ -394,37 +394,37 @@ class MultiLevelCache:
"""获取所有缓存层的统计信息(修正版,避免重复计数)""" """获取所有缓存层的统计信息(修正版,避免重复计数)"""
l1_stats = await self.l1_cache.get_stats() l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats() l2_stats = await self.l2_cache.get_stats()
# 🔧 修复计算实际独占的内存避免L1和L2共享数据的重复计数 # 🔧 修复计算实际独占的内存避免L1和L2共享数据的重复计数
l1_keys = set(self.l1_cache._cache.keys()) l1_keys = set(self.l1_cache._cache.keys())
l2_keys = set(self.l2_cache._cache.keys()) l2_keys = set(self.l2_cache._cache.keys())
shared_keys = l1_keys & l2_keys shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys l1_only_keys = l1_keys - l2_keys
l2_only_keys = l2_keys - l1_keys l2_only_keys = l2_keys - l1_keys
# 计算实际总内存(避免重复计数) # 计算实际总内存(避免重复计数)
# L1独占内存 # L1独占内存
l1_only_size = sum( l1_only_size = sum(
self.l1_cache._cache[k].size self.l1_cache._cache[k].size
for k in l1_only_keys for k in l1_only_keys
if k in self.l1_cache._cache if k in self.l1_cache._cache
) )
# L2独占内存 # L2独占内存
l2_only_size = sum( l2_only_size = sum(
self.l2_cache._cache[k].size self.l2_cache._cache[k].size
for k in l2_only_keys for k in l2_only_keys
if k in self.l2_cache._cache if k in self.l2_cache._cache
) )
# 共享内存只计算一次使用L1的数据 # 共享内存只计算一次使用L1的数据
shared_size = sum( shared_size = sum(
self.l1_cache._cache[k].size self.l1_cache._cache[k].size
for k in shared_keys for k in shared_keys
if k in self.l1_cache._cache if k in self.l1_cache._cache
) )
actual_total_size = l1_only_size + l2_only_size + shared_size actual_total_size = l1_only_size + l2_only_size + shared_size
return { return {
"l1": l1_stats, "l1": l1_stats,
"l2": l2_stats, "l2": l2_stats,
@@ -442,7 +442,7 @@ class MultiLevelCache:
"""检查并强制清理超出内存限制的缓存""" """检查并强制清理超出内存限制的缓存"""
stats = await self.get_stats() stats = await self.get_stats()
total_size = stats["l1"].total_size + stats["l2"].total_size total_size = stats["l1"].total_size + stats["l2"].total_size
if total_size > self.max_memory_bytes: if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024) memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024) max_mb = self.max_memory_bytes / (1024 * 1024)
@@ -452,14 +452,14 @@ class MultiLevelCache:
) )
# 优先清理L2缓存温数据 # 优先清理L2缓存温数据
await self.l2_cache.clear() await self.l2_cache.clear()
# 如果清理L2后仍超限清理L1 # 如果清理L2后仍超限清理L1
stats_after_l2 = await self.get_stats() stats_after_l2 = await self.get_stats()
total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size
if total_after_l2 > self.max_memory_bytes: if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存") logger.warning("清理L2后仍超限继续清理L1缓存")
await self.l1_cache.clear() await self.l1_cache.clear()
logger.info("缓存强制清理完成") logger.info("缓存强制清理完成")
async def start_cleanup_task(self, interval: float = 60) -> None: async def start_cleanup_task(self, interval: float = 60) -> None:
@@ -476,10 +476,10 @@ class MultiLevelCache:
while not self._is_closing: while not self._is_closing:
try: try:
await asyncio.sleep(interval) await asyncio.sleep(interval)
if self._is_closing: if self._is_closing:
break break
stats = await self.get_stats() stats = await self.get_stats()
l1_stats = stats["l1"] l1_stats = stats["l1"]
l2_stats = stats["l2"] l2_stats = stats["l2"]
@@ -493,13 +493,13 @@ class MultiLevelCache:
f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB " f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB "
f"(去重节省{stats['dedup_savings_mb']:.2f}MB)" f"(去重节省{stats['dedup_savings_mb']:.2f}MB)"
) )
# 🔧 清理过期条目 # 🔧 清理过期条目
await self._clean_expired_entries() await self._clean_expired_entries()
# 检查内存限制 # 检查内存限制
await self.check_memory_limit() await self.check_memory_limit()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
@@ -511,7 +511,7 @@ class MultiLevelCache:
async def stop_cleanup_task(self) -> None: async def stop_cleanup_task(self) -> None:
"""停止清理任务""" """停止清理任务"""
self._is_closing = True self._is_closing = True
if self._cleanup_task is not None: if self._cleanup_task is not None:
self._cleanup_task.cancel() self._cleanup_task.cancel()
try: try:
@@ -520,43 +520,43 @@ class MultiLevelCache:
pass pass
self._cleanup_task = None self._cleanup_task = None
logger.info("缓存清理任务已停止") logger.info("缓存清理任务已停止")
async def _clean_expired_entries(self) -> None: async def _clean_expired_entries(self) -> None:
"""清理过期的缓存条目""" """清理过期的缓存条目"""
try: try:
current_time = time.time() current_time = time.time()
# 清理 L1 过期条目 # 清理 L1 过期条目
async with self.l1_cache._lock: async with self.l1_cache._lock:
expired_keys = [ expired_keys = [
key for key, entry in self.l1_cache._cache.items() key for key, entry in self.l1_cache._cache.items()
if current_time - entry.created_at > self.l1_cache.ttl if current_time - entry.created_at > self.l1_cache.ttl
] ]
for key in expired_keys: for key in expired_keys:
entry = self.l1_cache._cache.pop(key, None) entry = self.l1_cache._cache.pop(key, None)
if entry: if entry:
self.l1_cache._stats.evictions += 1 self.l1_cache._stats.evictions += 1
self.l1_cache._stats.item_count -= 1 self.l1_cache._stats.item_count -= 1
self.l1_cache._stats.total_size -= entry.size self.l1_cache._stats.total_size -= entry.size
# 清理 L2 过期条目 # 清理 L2 过期条目
async with self.l2_cache._lock: async with self.l2_cache._lock:
expired_keys = [ expired_keys = [
key for key, entry in self.l2_cache._cache.items() key for key, entry in self.l2_cache._cache.items()
if current_time - entry.created_at > self.l2_cache.ttl if current_time - entry.created_at > self.l2_cache.ttl
] ]
for key in expired_keys: for key in expired_keys:
entry = self.l2_cache._cache.pop(key, None) entry = self.l2_cache._cache.pop(key, None)
if entry: if entry:
self.l2_cache._stats.evictions += 1 self.l2_cache._stats.evictions += 1
self.l2_cache._stats.item_count -= 1 self.l2_cache._stats.item_count -= 1
self.l2_cache._stats.total_size -= entry.size self.l2_cache._stats.total_size -= entry.size
if expired_keys: if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目") logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目")
except Exception as e: except Exception as e:
logger.error(f"清理过期条目失败: {e}", exc_info=True) logger.error(f"清理过期条目失败: {e}", exc_info=True)
@@ -568,7 +568,7 @@ _cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache: async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例) """获取全局缓存实例(单例)
从配置文件读取缓存参数,如果配置未加载则使用默认值 从配置文件读取缓存参数,如果配置未加载则使用默认值
如果配置中禁用了缓存返回一个最小化的缓存实例容量为1 如果配置中禁用了缓存返回一个最小化的缓存实例容量为1
""" """
@@ -580,9 +580,9 @@ async def get_cache() -> MultiLevelCache:
# 尝试从配置读取参数 # 尝试从配置读取参数
try: try:
from src.config.config import global_config from src.config.config import global_config
db_config = global_config.database db_config = global_config.database
# 检查是否启用缓存 # 检查是否启用缓存
if not db_config.enable_database_cache: if not db_config.enable_database_cache:
logger.info("数据库缓存已禁用,使用最小化缓存实例") logger.info("数据库缓存已禁用,使用最小化缓存实例")
@@ -594,7 +594,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb=1, max_memory_mb=1,
) )
return _global_cache return _global_cache
l1_max_size = db_config.cache_l1_max_size l1_max_size = db_config.cache_l1_max_size
l1_ttl = db_config.cache_l1_ttl l1_ttl = db_config.cache_l1_ttl
l2_max_size = db_config.cache_l2_max_size l2_max_size = db_config.cache_l2_max_size
@@ -602,7 +602,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = db_config.cache_max_memory_mb max_memory_mb = db_config.cache_max_memory_mb
max_item_size_mb = db_config.cache_max_item_size_mb max_item_size_mb = db_config.cache_max_item_size_mb
cleanup_interval = db_config.cache_cleanup_interval cleanup_interval = db_config.cache_cleanup_interval
logger.info( logger.info(
f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), " f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), "
f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), " f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), "
@@ -618,7 +618,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = 100 max_memory_mb = 100
max_item_size_mb = 1 max_item_size_mb = 1
cleanup_interval = 60 cleanup_interval = 60
_global_cache = MultiLevelCache( _global_cache = MultiLevelCache(
l1_max_size=l1_max_size, l1_max_size=l1_max_size,
l1_ttl=l1_ttl, l1_ttl=l1_ttl,

View File

@@ -4,73 +4,74 @@
提供比 sys.getsizeof() 更准确的内存占用估算方法 提供比 sys.getsizeof() 更准确的内存占用估算方法
""" """
import sys
import pickle import pickle
import sys
from typing import Any from typing import Any
import numpy as np import numpy as np
def get_accurate_size(obj: Any, seen: set | None = None) -> int: def get_accurate_size(obj: Any, seen: set | None = None) -> int:
""" """
准确估算对象的内存大小(递归计算所有引用对象) 准确估算对象的内存大小(递归计算所有引用对象)
比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。 比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
seen: 已访问对象的集合(用于避免循环引用) seen: 已访问对象的集合(用于避免循环引用)
Returns: Returns:
估算的字节数 估算的字节数
""" """
if seen is None: if seen is None:
seen = set() seen = set()
obj_id = id(obj) obj_id = id(obj)
if obj_id in seen: if obj_id in seen:
return 0 return 0
seen.add(obj_id) seen.add(obj_id)
size = sys.getsizeof(obj) size = sys.getsizeof(obj)
# NumPy 数组特殊处理 # NumPy 数组特殊处理
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
size += obj.nbytes size += obj.nbytes
return size return size
# 字典:递归计算所有键值对 # 字典:递归计算所有键值对
if isinstance(obj, dict): if isinstance(obj, dict):
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen) size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen)
for k, v in obj.items()) for k, v in obj.items())
# 列表、元组、集合:递归计算所有元素 # 列表、元组、集合:递归计算所有元素
elif isinstance(obj, (list, tuple, set, frozenset)): elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
# 有 __dict__ 的对象:递归计算属性 # 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, '__dict__'): elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen) size += get_accurate_size(obj.__dict__, seen)
# 其他可迭代对象 # 其他可迭代对象
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try: try:
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
except: except:
pass pass
return size return size
def get_pickle_size(obj: Any) -> int: def get_pickle_size(obj: Any) -> int:
""" """
使用 pickle 序列化大小作为参考 使用 pickle 序列化大小作为参考
通常比 sys.getsizeof() 更接近实际内存占用, 通常比 sys.getsizeof() 更接近实际内存占用,
但可能略小于真实内存占用(不包括 Python 对象开销) 但可能略小于真实内存占用(不包括 Python 对象开销)
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
Returns: Returns:
pickle 序列化后的字节数,失败返回 0 pickle 序列化后的字节数,失败返回 0
""" """
@@ -83,17 +84,17 @@ def get_pickle_size(obj: Any) -> int:
def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int: def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int:
""" """
智能估算对象大小(平衡准确性和性能) 智能估算对象大小(平衡准确性和性能)
使用深度受限的递归估算+采样策略,平衡准确性和性能: 使用深度受限的递归估算+采样策略,平衡准确性和性能:
- 深度5层足以覆盖99%的缓存数据结构 - 深度5层足以覆盖99%的缓存数据结构
- 对大型容器(>100项进行采样估算 - 对大型容器(>100项进行采样估算
- 性能开销约60倍于sys.getsizeof但准确度提升1000+倍 - 性能开销约60倍于sys.getsizeof但准确度提升1000+倍
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构 max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构
sample_large: 对大型容器是否采样默认True提升性能 sample_large: 对大型容器是否采样默认True提升性能
Returns: Returns:
估算的字节数 估算的字节数
""" """
@@ -105,24 +106,24 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 检查深度限制 # 检查深度限制
if depth <= 0: if depth <= 0:
return sys.getsizeof(obj) return sys.getsizeof(obj)
# 检查循环引用 # 检查循环引用
obj_id = id(obj) obj_id = id(obj)
if obj_id in seen: if obj_id in seen:
return 0 return 0
seen.add(obj_id) seen.add(obj_id)
# 基本大小 # 基本大小
size = sys.getsizeof(obj) size = sys.getsizeof(obj)
# 简单类型直接返回 # 简单类型直接返回
if isinstance(obj, (int, float, bool, type(None), str, bytes, bytearray)): if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
return size return size
# NumPy 数组特殊处理 # NumPy 数组特殊处理
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
return size + obj.nbytes return size + obj.nbytes
# 字典递归 # 字典递归
if isinstance(obj, dict): if isinstance(obj, dict):
items = list(obj.items()) items = list(obj.items())
@@ -130,7 +131,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 大字典采样前50 + 中间50 + 最后50 # 大字典采样前50 + 中间50 + 最后50
sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:] sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:]
sampled_size = sum( sampled_size = sum(
_estimate_recursive(k, depth - 1, seen, sample_large) + _estimate_recursive(k, depth - 1, seen, sample_large) +
_estimate_recursive(v, depth - 1, seen, sample_large) _estimate_recursive(v, depth - 1, seen, sample_large)
for k, v in sample_items for k, v in sample_items
) )
@@ -142,9 +143,9 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
size += _estimate_recursive(k, depth - 1, seen, sample_large) size += _estimate_recursive(k, depth - 1, seen, sample_large)
size += _estimate_recursive(v, depth - 1, seen, sample_large) size += _estimate_recursive(v, depth - 1, seen, sample_large)
return size return size
# 列表、元组、集合递归 # 列表、元组、集合递归
if isinstance(obj, (list, tuple, set, frozenset)): if isinstance(obj, list | tuple | set | frozenset):
items = list(obj) items = list(obj)
if sample_large and len(items) > 100: if sample_large and len(items) > 100:
# 大容器采样前50 + 中间50 + 最后50 # 大容器采样前50 + 中间50 + 最后50
@@ -160,21 +161,21 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
for item in items: for item in items:
size += _estimate_recursive(item, depth - 1, seen, sample_large) size += _estimate_recursive(item, depth - 1, seen, sample_large)
return size return size
# 有 __dict__ 的对象 # 有 __dict__ 的对象
if hasattr(obj, '__dict__'): if hasattr(obj, "__dict__"):
size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large) size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
return size return size
def format_size(size_bytes: int) -> str: def format_size(size_bytes: int) -> str:
""" """
格式化字节数为人类可读的格式 格式化字节数为人类可读的格式
Args: Args:
size_bytes: 字节数 size_bytes: 字节数
Returns: Returns:
格式化后的字符串,如 "1.23 MB" 格式化后的字符串,如 "1.23 MB"
""" """

View File

@@ -2,7 +2,8 @@ import os
import socket import socket
from fastapi import APIRouter, FastAPI from fastapi import APIRouter, FastAPI
from fastapi.middleware.cors import CORSMiddleware # 新增导入 from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from rich.traceback import install from rich.traceback import install
from uvicorn import Config from uvicorn import Config
from uvicorn import Server as UvicornServer from uvicorn import Server as UvicornServer

View File

@@ -2,7 +2,6 @@ import os
import shutil import shutil
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Optional
import tomlkit import tomlkit
from pydantic import Field from pydantic import Field
@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置") notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置") emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置") expression: ExpressionConfig = Field(..., description="表达配置")
memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置") memory: MemoryConfig | None = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置") mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置") reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置") chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
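Several hunks in this commit apply the same Python 3.10+ modernizations: Optional[X] becomes X | None and isinstance(x, (A, B)) becomes isinstance(x, A | B). Both spellings are equivalent at runtime on 3.10+, as the small check below illustrates.

from typing import Optional

def old_style(value: Optional[dict]) -> bool:
    return isinstance(value, (dict, list))        # tuple form, pre-3.10 style

def new_style(value: dict | None) -> bool:
    return isinstance(value, dict | list)         # PEP 604 union form, 3.10+

assert old_style({"k": 1}) == new_style({"k": 1})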

View File

@@ -401,16 +401,16 @@ class MemoryConfig(ValidatedConfigBase):
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡") memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流") memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列") memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
# === 记忆图系统配置 (Memory Graph System) === # === 记忆图系统配置 (Memory Graph System) ===
# 新一代记忆系统的配置项 # 新一代记忆系统的配置项
enable: bool = Field(default=True, description="启用记忆图系统") enable: bool = Field(default=True, description="启用记忆图系统")
data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录") data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录")
# 向量存储配置 # 向量存储配置
vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称") vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称")
vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径") vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径")
# 检索配置 # 检索配置
search_top_k: int = Field(default=10, description="默认检索返回数量") search_top_k: int = Field(default=10, description="默认检索返回数量")
search_min_importance: float = Field(default=0.3, description="最小重要性阈值") search_min_importance: float = Field(default=0.3, description="最小重要性阈值")
@@ -418,13 +418,13 @@ class MemoryConfig(ValidatedConfigBase):
search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3") search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3")
search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)") search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)")
enable_query_optimization: bool = Field(default=True, description="启用查询优化") enable_query_optimization: bool = Field(default=True, description="启用查询优化")
# 检索权重配置 (记忆图系统) # 检索权重配置 (记忆图系统)
search_vector_weight: float = Field(default=0.4, description="向量相似度权重") search_vector_weight: float = Field(default=0.4, description="向量相似度权重")
search_graph_distance_weight: float = Field(default=0.2, description="图距离权重") search_graph_distance_weight: float = Field(default=0.2, description="图距离权重")
search_importance_weight: float = Field(default=0.2, description="重要性权重") search_importance_weight: float = Field(default=0.2, description="重要性权重")
search_recency_weight: float = Field(default=0.2, description="时效性权重") search_recency_weight: float = Field(default=0.2, description="时效性权重")
# 记忆整合配置 # 记忆整合配置
consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合") consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合")
consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)") consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)")
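The four retrieval weights above default to 0.4 / 0.2 / 0.2 / 0.2. The actual score combination lives elsewhere in the memory-graph code; the sketch below only illustrates the weighted sum implied by the field names and should be read as an assumption, not the project's implementation.

def combined_memory_score(vector_sim: float, graph_dist: float, importance: float, recency: float,
                          *, w_vec: float = 0.4, w_graph: float = 0.2,
                          w_imp: float = 0.2, w_rec: float = 0.2) -> float:
    """Weighted relevance score, assuming each component is already normalized to [0, 1]."""
    return w_vec * vector_sim + w_graph * graph_dist + w_imp * importance + w_rec * recency

# Example: 0.4*0.9 + 0.2*0.5 + 0.2*0.7 + 0.2*0.6 = 0.72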
@@ -442,21 +442,21 @@ class MemoryConfig(ValidatedConfigBase):
consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值") consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值")
consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数") consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数")
consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度") consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度")
# 遗忘配置 (记忆图系统) # 遗忘配置 (记忆图系统)
forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘") forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘")
forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值") forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值")
forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性") forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性")
# 激活配置 # 激活配置
activation_decay_rate: float = Field(default=0.9, description="激活度衰减率") activation_decay_rate: float = Field(default=0.9, description="激活度衰减率")
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度") activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度") activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 性能配置 # 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数") max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量") max_related_memories: int = Field(default=5, description="相关记忆最大数量")
# 节点去重合并配置 # 节点去重合并配置
node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值") node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值")
node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配") node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配")
@@ -637,6 +637,7 @@ class WebSearchConfig(ValidatedConfigBase):
exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制") exa_api_keys: list[str] = Field(default_factory=lambda: [], description="exa API密钥列表支持轮询机制")
searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表") searxng_instances: list[str] = Field(default_factory=list, description="SearXNG 实例 URL 列表")
searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表") searxng_api_keys: list[str] = Field(default_factory=list, description="SearXNG 实例 API 密钥列表")
serper_api_keys: list[str] = Field(default_factory=list, description="serper API 密钥列表")
enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎") enabled_engines: list[str] = Field(default_factory=lambda: ["ddg"], description="启用的搜索引擎")
search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略") search_strategy: Literal["fallback", "single", "parallel"] = Field(default="single", description="搜索策略")

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
model_name = model_info.name model_name = model_info.name
retry_interval = api_provider.retry_interval retry_interval = api_provider.retry_interval
if isinstance(e, (NetworkConnectionError, ReqAbortException)): if isinstance(e, NetworkConnectionError | ReqAbortException):
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name) return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException): elif isinstance(e, RespNotOkException):
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)

View File

@@ -423,15 +423,16 @@ MoFox_Bot(第三方修改版)
# 注册API路由 # 注册API路由
try: try:
from src.api.memory_visualizer_router import router as visualizer_router
from src.api.message_router import router as message_router from src.api.message_router import router as message_router
from src.api.statistic_router import router as llm_statistic_router from src.api.statistic_router import router as llm_statistic_router
self.server.register_router(message_router, prefix="/api") self.server.register_router(message_router, prefix="/api")
self.server.register_router(llm_statistic_router, prefix="/api") self.server.register_router(llm_statistic_router, prefix="/api")
self.server.register_router(visualizer_router, prefix="/visualizer")
logger.info("API路由注册成功") logger.info("API路由注册成功")
except Exception as e: except Exception as e:
logger.error(f"注册API路由失败: {e}") logger.error(f"注册API路由失败: {e}")
# 初始化统一调度器 # 初始化统一调度器
try: try:
from src.schedule.unified_scheduler import initialize_scheduler from src.schedule.unified_scheduler import initialize_scheduler

View File

@@ -100,10 +100,10 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串 # 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items(): for key, value in node.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
import orjson import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value metadata[key] = value
else: else:
metadata[key] = str(value) metadata[key] = str(value)
@@ -149,9 +149,9 @@ class VectorStore:
"created_at": n.created_at.isoformat(), "created_at": n.created_at.isoformat(),
} }
for key, value in n.metadata.items(): for key, value in n.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value # type: ignore metadata[key] = value # type: ignore
else: else:
metadata[key] = str(value) metadata[key] = str(value)
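Both hunks in this file apply the same normalization rule: vector-store metadata must be scalar, so lists and dicts are serialized to JSON strings and anything else is stringified. A compact sketch of that rule (the helper name is hypothetical):

import orjson
from typing import Any

def normalize_metadata(raw: dict[str, Any]) -> dict[str, Any]:
    """Coerce metadata values into scalars accepted by the vector store."""
    out: dict[str, Any] = {}
    for key, value in raw.items():
        if isinstance(value, list | dict):
            out[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        elif isinstance(value, str | int | float | bool) or value is None:
            out[key] = value
        else:
            out[key] = str(value)
    return out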

View File

@@ -4,8 +4,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import numpy as np import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -72,7 +70,7 @@ class EmbeddingGenerator:
logger.warning(f"⚠️ Embedding API 初始化失败: {e}") logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
self._api_available = False self._api_available = False
async def generate(self, text: str) -> np.ndarray | None: async def generate(self, text: str) -> np.ndarray | None:
""" """
生成单个文本的嵌入向量 生成单个文本的嵌入向量
@@ -130,7 +128,7 @@ class EmbeddingGenerator:
logger.debug(f"API 嵌入生成失败: {e}") logger.debug(f"API 嵌入生成失败: {e}")
return None return None
def _get_dimension(self) -> int: def _get_dimension(self) -> int:
"""获取嵌入维度""" """获取嵌入维度"""
# 优先使用 API 维度 # 优先使用 API 维度

View File

@@ -7,11 +7,12 @@
""" """
import atexit import atexit
import orjson
import os import os
import threading import threading
from typing import Any, ClassVar from typing import Any, ClassVar
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
# 获取日志记录器 # 获取日志记录器
@@ -125,7 +126,7 @@ class PluginStorage:
try: try:
with open(self.file_path, "w", encoding="utf-8") as f: with open(self.file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')) f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
self._dirty = False # 保存后重置标志 self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。") logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e: except Exception as e:

View File

@@ -5,12 +5,12 @@ MCP Client Manager
""" """
import asyncio import asyncio
import orjson
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import mcp.types import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
""" """
import time import time
from typing import Any, Optional from dataclasses import dataclass, field
from dataclasses import dataclass, asdict, field from typing import Any
import orjson import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history") logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录""" """工具调用记录"""
tool_name: str tool_name: str
args: dict[str, Any] args: dict[str, Any]
result: Optional[dict[str, Any]] = None result: dict[str, Any] | None = None
status: str = "success" # success, error, pending status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒) execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存 cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览 result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息 error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
content = self.result.get("content", "") content = self.result.get("content", "")
if isinstance(content, str): if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "") self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)): elif isinstance(content, list | dict):
try: try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..." self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
except Exception: except Exception:
self.result_preview = str(content)[:500] + "..." self.result_preview = str(content)[:500] + "..."
else: else:
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}") logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果 """从缓存或历史记录中获取结果
Args: Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
return None return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any], async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None, execution_time: float | None = None,
tool_file_path: Optional[str] = None, tool_file_path: str | None = None,
ttl: Optional[int] = None) -> None: ttl: int | None = None) -> None:
"""缓存工具调用结果 """缓存工具调用结果
Args: Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
except Exception as e: except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}") logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]: async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
"""获取最近的历史记录 """获取最近的历史记录
Args: Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
self._history.clear() self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除") logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""在内存历史记录中搜索缓存 """在内存历史记录中搜索缓存
Args: Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py") return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]: def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
"""提取语义查询参数 """提取语义查询参数
Args: Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return "" return ""
try: try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8') args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length: if len(args_str) > max_length:
args_str = args_str[:max_length] + "..." args_str = args_str[:max_length] + "..."
return args_str return args_str
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
""" """
if chat_id in _stream_managers: if chat_id in _stream_managers:
del _stream_managers[chat_id] del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器") logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -1,5 +1,6 @@
import inspect import inspect
import time import time
from dataclasses import asdict
from typing import Any from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
from dataclasses import asdict
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -140,7 +140,7 @@ class ToolExecutor:
# 构建工具调用历史文本 # 构建工具调用历史文本
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True) tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息 # 获取人设信息
personality_core = global_config.personality.personality_core personality_core = global_config.personality.personality_core
personality_side = global_config.personality.personality_side personality_side = global_config.personality.personality_side
@@ -197,7 +197,7 @@ class ToolExecutor:
return tool_definitions return tool_definitions
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]: async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用 """执行工具调用
@@ -338,9 +338,8 @@ class ToolExecutor:
if tool_instance and result and tool_instance.enable_cache: if tool_instance and result and tool_instance.enable_cache:
try: try:
tool_file_path = inspect.getfile(tool_instance.__class__) tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key: if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key) function_args.get(tool_instance.semantic_cache_query_key)
await self.history_manager.cache_result( await self.history_manager.cache_result(
tool_name=tool_call.func_name, tool_name=tool_call.func_name,

View File

@@ -122,7 +122,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
+ relationship_score * self.score_weights["relationship"] + relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"] + mentioned_score * self.score_weights["mentioned"]
) )
# 限制总分上限为1.0,确保分数在合理范围内 # 限制总分上限为1.0,确保分数在合理范围内
total_score = min(raw_total_score, 1.0) total_score = min(raw_total_score, 1.0)
@@ -131,7 +131,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"{relationship_score:.3f}*{self.score_weights['relationship']} + " f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}" f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
) )
if raw_total_score > 1.0: if raw_total_score > 1.0:
logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}") logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}")
@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return 0.0 return 0.0
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数") logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分 return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e: except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}") logger.warning(f"智能兴趣匹配失败: {e}")
@@ -251,19 +251,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float: def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
"""计算提及分 - 区分强提及和弱提及 """计算提及分 - 区分强提及和弱提及
强提及(被@、被回复、私聊): 使用 strong_mention_interest_score 强提及(被@、被回复、私聊): 使用 strong_mention_interest_score
弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score 弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score
""" """
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
# 使用统一的提及检测函数 # 使用统一的提及检测函数
is_mentioned, mention_type = is_mentioned_bot_in_message(message) is_mentioned, mention_type = is_mentioned_bot_in_message(message)
if not is_mentioned: if not is_mentioned:
logger.debug("[提及分计算] 未提及机器人返回0.0") logger.debug("[提及分计算] 未提及机器人返回0.0")
return 0.0 return 0.0
# mention_type: 0=未提及, 1=弱提及, 2=强提及 # mention_type: 0=未提及, 1=弱提及, 2=强提及
if mention_type >= 2: if mention_type >= 2:
# 强提及:被@、被回复、私聊 # 强提及:被@、被回复、私聊
@@ -281,22 +281,22 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]: def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]:
"""应用阈值调整(包括连续不回复和回复后降低机制) """应用阈值调整(包括连续不回复和回复后降低机制)
Returns: Returns:
tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值) tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值)
""" """
# 基础阈值 # 基础阈值
base_reply_threshold = self.reply_threshold base_reply_threshold = self.reply_threshold
base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
total_reduction = 0.0 total_reduction = 0.0
# 1. 连续不回复的阈值降低 # 1. 连续不回复的阈值降低
if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count: if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count:
no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply
total_reduction += no_reply_reduction total_reduction += no_reply_reduction
logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})") logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})")
# 2. 回复后的阈值降低使bot更容易连续对话 # 2. 回复后的阈值降低使bot更容易连续对话
if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0: if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0:
# 计算衰减后的降低值 # 计算衰减后的降低值
@@ -309,16 +309,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} " f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} "
f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})" f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})"
) )
# 应用总降低量 # 应用总降低量
adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction) adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction)
adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction) adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction)
return adjusted_reply_threshold, adjusted_action_threshold return adjusted_reply_threshold, adjusted_action_threshold
def _apply_no_reply_boost(self, base_score: float) -> float: def _apply_no_reply_boost(self, base_score: float) -> float:
"""【已弃用】应用连续不回复的概率提升 """【已弃用】应用连续不回复的概率提升
注意:此方法已被 _apply_no_reply_threshold_adjustment 替代 注意:此方法已被 _apply_no_reply_threshold_adjustment 替代
保留用于向后兼容 保留用于向后兼容
""" """
@@ -388,7 +388,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.no_reply_count = 0 self.no_reply_count = 0
else: else:
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count) self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
def on_reply_sent(self): def on_reply_sent(self):
"""当机器人发送回复后调用,激活回复后阈值降低机制""" """当机器人发送回复后调用,激活回复后阈值降低机制"""
if self.enable_post_reply_boost: if self.enable_post_reply_boost:
@@ -399,16 +399,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
) )
# 同时重置不回复计数 # 同时重置不回复计数
self.no_reply_count = 0 self.no_reply_count = 0
def on_message_processed(self, replied: bool): def on_message_processed(self, replied: bool):
"""消息处理完成后调用,更新各种计数器 """消息处理完成后调用,更新各种计数器
Args: Args:
replied: 是否回复了此消息 replied: 是否回复了此消息
""" """
# 更新不回复计数 # 更新不回复计数
self.update_no_reply_count(replied) self.update_no_reply_count(replied)
# 如果已回复,激活回复后降低机制 # 如果已回复,激活回复后降低机制
if replied: if replied:
self.on_reply_sent() self.on_reply_sent()
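The threshold adjustment above subtracts two reductions from the base thresholds and clamps at zero: a per-miss reduction for consecutive non-replies and a decaying post-reply reduction. A simplified standalone sketch (the cap on the no-reply count from the source is omitted; the example numbers are illustrative):

def adjusted_thresholds(base_reply: float, base_action: float, *,
                        no_reply_count: int, boost_per_no_reply: float,
                        post_reply_remaining: int, post_reply_reduction: float,
                        decay_factor: float) -> tuple[float, float]:
    """Lower both thresholds by the accumulated reductions, never below zero."""
    reduction = 0.0
    if no_reply_count > 0:
        reduction += no_reply_count * boost_per_no_reply          # consecutive-miss reduction
    if post_reply_remaining > 0:
        reduction += post_reply_reduction * decay_factor          # decaying post-reply reduction
    return max(0.0, base_reply - reduction), max(0.0, base_action - reduction)

# Example: bases 0.6 / 0.4, two missed replies at 0.05 each, no post-reply boost -> (0.5, 0.3)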

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
包含计划生成、过滤、执行等规划相关功能 包含计划生成、过滤、执行等规划相关功能
""" """
from . import planner_prompts
from .plan_executor import ChatterPlanExecutor from .plan_executor import ChatterPlanExecutor
from .plan_filter import ChatterPlanFilter from .plan_filter import ChatterPlanFilter
from .plan_generator import ChatterPlanGenerator from .plan_generator import ChatterPlanGenerator
from .planner import ChatterActionPlanner from .planner import ChatterActionPlanner
from . import planner_prompts
__all__ = ["ChatterActionPlanner", "planner_prompts", "ChatterPlanGenerator", "ChatterPlanFilter", "ChatterPlanExecutor"] __all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统 # 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id, build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
) )
from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.prompt import global_prompt_manager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
@@ -646,7 +644,7 @@ class ChatterPlanFilter:
memory_manager = get_memory_manager() memory_manager = get_memory_manager()
if not memory_manager: if not memory_manager:
return "记忆系统未初始化。" return "记忆系统未初始化。"
# 将关键词转换为查询字符串 # 将关键词转换为查询字符串
query = " ".join(keywords) query = " ".join(keywords)
enhanced_memories = await memory_manager.search_memories( enhanced_memories = await memory_manager.search_memories(

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化 # 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts
logger = get_logger("planner") logger = get_logger("planner")
@@ -159,10 +158,10 @@ class ChatterActionPlanner:
action_data={}, action_data={},
action_message=None, action_message=None,
) )
# 更新连续不回复计数 # 更新连续不回复计数
await self._update_interest_calculator_state(replied=False) await self._update_interest_calculator_state(replied=False)
initial_plan = await self.generator.generate(chat_mode) initial_plan = await self.generator.generate(chat_mode)
filtered_plan = initial_plan filtered_plan = initial_plan
filtered_plan.decided_actions = [no_action] filtered_plan.decided_actions = [no_action]
@@ -270,7 +269,7 @@ class ChatterActionPlanner:
try: try:
# Normal模式开始时刷新缓存消息到未读列表 # Normal模式开始时刷新缓存消息到未读列表
await self._flush_cached_messages_to_unread(context) await self._flush_cached_messages_to_unread(context)
unread_messages = context.get_unread_messages() if context else [] unread_messages = context.get_unread_messages() if context else []
if not unread_messages: if not unread_messages:
@@ -347,7 +346,7 @@ class ChatterActionPlanner:
self._update_stats_from_execution_result(execution_result) self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成") logger.info("Normal模式: 执行reply动作完成")
# 更新兴趣计算器状态(回复成功,重置不回复计数) # 更新兴趣计算器状态(回复成功,重置不回复计数)
await self._update_interest_calculator_state(replied=True) await self._update_interest_calculator_state(replied=True)
@@ -465,7 +464,7 @@ class ChatterActionPlanner:
async def _update_interest_calculator_state(self, replied: bool) -> None: async def _update_interest_calculator_state(self, replied: bool) -> None:
"""更新兴趣计算器状态(连续不回复计数和回复后降低机制) """更新兴趣计算器状态(连续不回复计数和回复后降低机制)
Args: Args:
replied: 是否回复了消息 replied: 是否回复了消息
""" """
@@ -504,36 +503,36 @@ class ChatterActionPlanner:
async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list: async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list:
"""在planner开始时将缓存消息刷新到未读消息列表 """在planner开始时将缓存消息刷新到未读消息列表
此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。 此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。
Args: Args:
context: 流上下文 context: 流上下文
Returns: Returns:
list: 刷新的消息列表 list: 刷新的消息列表
""" """
if not context: if not context:
return [] return []
try: try:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
stream_id = context.stream_id stream_id = context.stream_id
if message_manager.is_running and message_manager.has_cached_messages(stream_id): if message_manager.is_running and message_manager.has_cached_messages(stream_id):
# 获取缓存消息 # 获取缓存消息
cached_messages = message_manager.flush_cached_messages(stream_id) cached_messages = message_manager.flush_cached_messages(stream_id)
if cached_messages: if cached_messages:
# 直接添加到上下文的未读消息列表 # 直接添加到上下文的未读消息列表
for message in cached_messages: for message in cached_messages:
context.unread_messages.append(message) context.unread_messages.append(message)
logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}") logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}")
return cached_messages return cached_messages
return [] return []
except ImportError: except ImportError:
logger.debug("MessageManager不可用跳过缓存刷新") logger.debug("MessageManager不可用跳过缓存刷新")
return [] return []
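The cache-flush step above is a small, self-contained pattern: drain whatever the message manager has buffered for the stream into the context's unread list before the planner builds its plan. A minimal sketch of that idea, using a hypothetical `CachedStream` stand-in rather than the project's real `StreamContext`/`message_manager` classes:

```python
# Hypothetical stand-in type; the real project uses StreamContext + message_manager.
from dataclasses import dataclass, field


@dataclass
class CachedStream:
    cached: list[str] = field(default_factory=list)
    unread: list[str] = field(default_factory=list)

    def flush_cached_to_unread(self) -> list[str]:
        """Move every cached message into the unread list and return what was moved."""
        moved, self.cached = self.cached, []
        self.unread.extend(moved)
        return moved


stream = CachedStream(cached=["msg-1", "msg-2"])
assert stream.flush_cached_to_unread() == ["msg-1", "msg-2"]
assert stream.unread == ["msg-1", "msg-2"] and stream.cached == []
```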

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler
__all__ = [ __all__ = [
"ProactiveThinkingReplyHandler",
"ProactiveThinkingMessageHandler", "ProactiveThinkingMessageHandler",
"execute_proactive_thinking", "ProactiveThinkingReplyHandler",
"ProactiveThinkingScheduler", "ProactiveThinkingScheduler",
"execute_proactive_thinking",
"proactive_thinking_scheduler", "proactive_thinking_scheduler",
] ]

View File

@@ -3,7 +3,6 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复 当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
""" """
import orjson
from datetime import datetime from datetime import datetime
from typing import Any, Literal from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api from src.plugin_system.apis import config_api, generator_api, llm_api
@@ -320,7 +319,7 @@ class ContentService:
- 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。 - 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。
2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。 2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。
**其他的禁止的内容以及说明** **其他的禁止的内容以及说明**
- 绝对禁止提及当下具体几点几分的时间戳。 - 绝对禁止提及当下具体几点几分的时间戳。
- 绝对禁止攻击性内容和过度的负面情绪。 - 绝对禁止攻击性内容和过度的负面情绪。

View File

@@ -136,10 +136,10 @@ class QZoneService:
logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}") logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
api_client = await self._get_api_client(qq_account, stream_id) api_client = await self._get_api_client(qq_account, stream_id)
if not api_client: if not api_client:
logger.error(f"[DEBUG] API客户端获取失败返回错误") logger.error("[DEBUG] API客户端获取失败返回错误")
return {"success": False, "message": "获取QZone API客户端失败"} return {"success": False, "message": "获取QZone API客户端失败"}
logger.info(f"[DEBUG] API客户端获取成功准备读取说说") logger.info("[DEBUG] API客户端获取成功准备读取说说")
num_to_read = self.get_config("read.read_number", 5) num_to_read = self.get_config("read.read_number", 5)
# 尝试执行如果Cookie失效则自动重试一次 # 尝试执行如果Cookie失效则自动重试一次
@@ -186,7 +186,7 @@ class QZoneService:
# 检查是否是Cookie失效-3000错误 # 检查是否是Cookie失效-3000错误
if "错误码: -3000" in error_msg and retry_count == 0: if "错误码: -3000" in error_msg and retry_count == 0:
logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...") logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")
# 删除Cookie缓存文件 # 删除Cookie缓存文件
cookie_file = self.cookie_service._get_cookie_file_path(qq_account) cookie_file = self.cookie_service._get_cookie_file_path(qq_account)
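The retry branch above boils down to an invalidate-and-retry pattern: on the cookie-expired error code (-3000), delete the cached cookie and try exactly once more, otherwise re-raise. A hedged sketch with `fetch_feeds` and `clear_cookie_cache` as assumed placeholders, not the plugin's real callables:

```python
async def read_feeds_with_retry(fetch_feeds, clear_cookie_cache, max_retries: int = 1):
    """Retry once after clearing the cached cookie when the expired-cookie error appears."""
    for attempt in range(max_retries + 1):
        try:
            return await fetch_feeds()
        except RuntimeError as e:
            # "错误码: -3000" marks an expired cookie in the service above
            if "错误码: -3000" in str(e) and attempt < max_retries:
                clear_cookie_cache()  # force a fresh login/cookie on the next attempt
                continue
            raise
```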
@@ -623,7 +623,7 @@ class QZoneService:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}") logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None return None
logger.info(f"[DEBUG] p_skey获取成功") logger.info("[DEBUG] p_skey获取成功")
gtk = self._generate_gtk(p_skey) gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o") uin = cookies.get("uin", "").lstrip("o")
@@ -1230,7 +1230,7 @@ class QZoneService:
logger.error(f"监控好友动态失败: {e}", exc_info=True) logger.error(f"监控好友动态失败: {e}", exc_info=True)
return [] return []
logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典") logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
return { return {
"publish": _publish, "publish": _publish,
"list_feeds": _list_feeds, "list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
负责记录和管理已回复过的评论ID避免重复回复 负责记录和管理已回复过的评论ID避免重复回复
""" """
import orjson
import time import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService") logger = get_logger("MaiZone.ReplyTrackerService")
@@ -117,8 +118,8 @@ class ReplyTrackerService:
temp_file = self.reply_record_file.with_suffix(".tmp") temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件 # 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f: with open(temp_file, "w", encoding="utf-8"):
orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8') orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8")
# 如果写入成功,重命名为正式文件 # 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功 if temp_file.stat().st_size > 0: # 确保写入成功
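The surrounding save routine uses the standard write-to-temp-then-rename recipe so a failed write can never truncate the real record file. A minimal sketch of that recipe (file name and payload are illustrative, not the plugin's actual data):

```python
from pathlib import Path

import orjson


def save_record_atomically(record_file: Path, data: dict) -> None:
    """Serialize to a temp file first, then atomically replace the real file."""
    temp_file = record_file.with_suffix(".tmp")
    payload = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
    temp_file.write_bytes(payload)        # write the serialized record to the temp file
    if temp_file.stat().st_size > 0:      # sanity check: something was actually written
        temp_file.replace(record_file)    # Path.replace is atomic on the same filesystem
    else:
        temp_file.unlink(missing_ok=True)


save_record_atomically(Path("replied_comments.json"), {"comment_12345": 1699999999})
```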

View File

@@ -1,7 +1,6 @@
import orjson import orjson
import random import random
import time import time
import random
import websockets as Server import websockets as Server
import uuid import uuid
from maim_message import ( from maim_message import (
@@ -205,7 +204,7 @@ class SendHandler:
# 发送响应回MoFox-Bot # 发送响应回MoFox-Bot
logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}") logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
await self.send_adapter_command_response(raw_message_base, response, request_id) await self.send_adapter_command_response(raw_message_base, response, request_id)
logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成") logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
if response.get("status") == "ok": if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功") logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
""" """
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import orjson
from typing import Any from typing import Any
import httpx import httpx
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
Google Search via Serper.dev API Google Search via Serper.dev API
""" """
import aiohttp
from typing import Any from typing import Any
import aiohttp
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
""" """
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
search_tasks.append(engine.answer_search(custom_args)) search_tasks.append(engine.answer_search(custom_args))
else: else:
search_tasks.append(engine.search(custom_args)) search_tasks.append(engine.search(custom_args))
@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索fallback策略") logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else:
@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索") logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else:
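All three branches above rely on the same capability probe: when answer mode is on and the engine object actually exposes `answer_search`, call it, otherwise fall back to the plain `search` method. A small sketch of that dispatch with dummy engines (class names are illustrative):

```python
import asyncio


class PlainEngine:
    async def search(self, args: dict) -> list[dict]:
        return [{"engine": "plain", "query": args.get("query")}]


class AnswerEngine(PlainEngine):
    async def answer_search(self, args: dict) -> list[dict]:
        return [{"engine": "answer", "query": args.get("query")}]


async def run_query(engine, args: dict, answer_mode: bool) -> list[dict]:
    # Prefer the richer answer mode only when the engine supports it
    if answer_mode and hasattr(engine, "answer_search"):
        return await engine.answer_search(args)
    return await engine.search(args)


results = asyncio.run(run_query(AnswerEngine(), {"query": "MoFox"}, answer_mode=True))
assert results[0]["engine"] == "answer"
```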

View File

@@ -266,13 +266,13 @@ class UnifiedScheduler:
name=f"execute_{task.task_name}" name=f"execute_{task.task_name}"
) )
execution_tasks.append(execution_task) execution_tasks.append(execution_task)
# 追踪正在执行的任务,以便在 remove_schedule 时可以取消 # 追踪正在执行的任务,以便在 remove_schedule 时可以取消
self._executing_tasks[task.schedule_id] = execution_task self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务) # 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务)
results = await asyncio.gather(*execution_tasks, return_exceptions=True) results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪 # 清理执行追踪
for task in tasks_to_trigger: for task in tasks_to_trigger:
self._executing_tasks.pop(task.schedule_id, None) self._executing_tasks.pop(task.schedule_id, None)
@@ -515,7 +515,7 @@ class UnifiedScheduler:
async def remove_schedule(self, schedule_id: str) -> bool: async def remove_schedule(self, schedule_id: str) -> bool:
"""移除调度任务 """移除调度任务
如果任务正在执行,会取消执行中的任务 如果任务正在执行,会取消执行中的任务
""" """
async with self._lock: async with self._lock:
@@ -524,7 +524,7 @@ class UnifiedScheduler:
return False return False
task = self._tasks[schedule_id] task = self._tasks[schedule_id]
# 检查是否有正在执行的任务 # 检查是否有正在执行的任务
executing_task = self._executing_tasks.get(schedule_id) executing_task = self._executing_tasks.get(schedule_id)
if executing_task and not executing_task.done(): if executing_task and not executing_task.done():
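The scheduler changes above combine two ideas: run the batch with `asyncio.gather(..., return_exceptions=True)` so one failing task cannot bring down the rest, and keep a handle on every running task so `remove_schedule` can cancel it mid-flight. A compact sketch of both (the class is an assumption, not the real `UnifiedScheduler`):

```python
import asyncio


class MiniScheduler:
    def __init__(self) -> None:
        self._executing: dict[str, asyncio.Task] = {}

    async def run_batch(self, jobs: dict) -> list:
        """Run all jobs concurrently; a single failure must not cancel the others."""
        tasks: dict[str, asyncio.Task] = {}
        for schedule_id, coro in jobs.items():
            task = asyncio.create_task(coro, name=f"execute_{schedule_id}")
            tasks[schedule_id] = task
            self._executing[schedule_id] = task      # track so it can be cancelled later
        results = await asyncio.gather(*tasks.values(), return_exceptions=True)
        for schedule_id in tasks:
            self._executing.pop(schedule_id, None)   # clean up the tracking map
        return results

    def remove_schedule(self, schedule_id: str) -> bool:
        """Cancel the task if it is still running."""
        task = self._executing.get(schedule_id)
        if task and not task.done():
            task.cancel()
            return True
        return False


async def _demo() -> None:
    async def ok():
        return "done"

    async def boom():
        raise ValueError("task failed")

    results = await MiniScheduler().run_batch({"a": ok(), "b": boom()})
    assert results[0] == "done" and isinstance(results[1], ValueError)


asyncio.run(_demo())
```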

View File

@@ -19,42 +19,42 @@ logger = get_logger(__name__)
def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None: def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None:
""" """
从 LLM 响应中提取并解析 JSON 从 LLM 响应中提取并解析 JSON
处理策略: 处理策略:
1. 清理 Markdown 代码块标记(```json 和 ``` 1. 清理 Markdown 代码块标记(```json 和 ```
2. 提取 JSON 对象或数组 2. 提取 JSON 对象或数组
3. 使用 json_repair 修复格式问题 3. 使用 json_repair 修复格式问题
4. 解析为 Python 对象 4. 解析为 Python 对象
Args: Args:
response: LLM 响应字符串 response: LLM 响应字符串
strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理 strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理
Returns: Returns:
解析后的 dict 或 list失败时返回 None 解析后的 dict 或 list失败时返回 None
Examples: Examples:
>>> extract_and_parse_json('```json\\n{"key": "value"}\\n```') >>> extract_and_parse_json('```json\\n{"key": "value"}\\n```')
{'key': 'value'} {'key': 'value'}
>>> extract_and_parse_json('Some text {"key": "value"} more text') >>> extract_and_parse_json('Some text {"key": "value"} more text')
{'key': 'value'} {'key': 'value'}
>>> extract_and_parse_json('[{"a": 1}, {"b": 2}]') >>> extract_and_parse_json('[{"a": 1}, {"b": 2}]')
[{'a': 1}, {'b': 2}] [{'a': 1}, {'b': 2}]
""" """
if not response: if not response:
logger.debug("空响应,无法解析 JSON") logger.debug("空响应,无法解析 JSON")
return None return None
try: try:
# 步骤 1: 清理响应 # 步骤 1: 清理响应
cleaned = _clean_llm_response(response) cleaned = _clean_llm_response(response)
if not cleaned: if not cleaned:
logger.warning("清理后的响应为空") logger.warning("清理后的响应为空")
return None return None
# 步骤 2: 尝试直接解析 # 步骤 2: 尝试直接解析
try: try:
result = orjson.loads(cleaned) result = orjson.loads(cleaned)
@@ -62,11 +62,11 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
return result return result
except Exception as direct_error: except Exception as direct_error:
logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}") logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}")
# 步骤 3: 使用 json_repair 修复并解析 # 步骤 3: 使用 json_repair 修复并解析
try: try:
repaired = repair_json(cleaned) repaired = repair_json(cleaned)
# repair_json 可能返回字符串或已解析的对象 # repair_json 可能返回字符串或已解析的对象
if isinstance(repaired, str): if isinstance(repaired, str):
result = orjson.loads(repaired) result = orjson.loads(repaired)
@@ -74,16 +74,16 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else: else:
result = repaired result = repaired
logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}") logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}")
return result return result
except Exception as repair_error: except Exception as repair_error:
logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}") logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}")
if strict: if strict:
logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}") logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}")
return None return None
# 最后的容错尝试:返回空字典或空列表 # 最后的容错尝试:返回空字典或空列表
if cleaned.strip().startswith("["): if cleaned.strip().startswith("["):
logger.warning("返回空列表作为容错") logger.warning("返回空列表作为容错")
@@ -91,7 +91,7 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else: else:
logger.warning("返回空字典作为容错") logger.warning("返回空字典作为容错")
return {} return {}
except Exception as e: except Exception as e:
logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}") logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}")
if strict: if strict:
@@ -102,37 +102,37 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
def _clean_llm_response(response: str) -> str: def _clean_llm_response(response: str) -> str:
""" """
清理 LLM 响应,提取 JSON 部分 清理 LLM 响应,提取 JSON 部分
处理步骤: 处理步骤:
1. 移除 Markdown 代码块标记(```json 和 ``` 1. 移除 Markdown 代码块标记(```json 和 ```
2. 提取第一个完整的 JSON 对象 {...} 或数组 [...] 2. 提取第一个完整的 JSON 对象 {...} 或数组 [...]
3. 清理多余的空格和换行 3. 清理多余的空格和换行
Args: Args:
response: 原始 LLM 响应 response: 原始 LLM 响应
Returns: Returns:
清理后的 JSON 字符串 清理后的 JSON 字符串
""" """
if not response: if not response:
return "" return ""
cleaned = response.strip() cleaned = response.strip()
# 移除 Markdown 代码块标记 # 移除 Markdown 代码块标记
# 匹配 ```json ... ``` 或 ``` ... ``` # 匹配 ```json ... ``` 或 ``` ... ```
code_block_patterns = [ code_block_patterns = [
r"```json\s*(.*?)```", # ```json ... ``` r"```json\s*(.*?)```", # ```json ... ```
r"```\s*(.*?)```", # ``` ... ``` r"```\s*(.*?)```", # ``` ... ```
] ]
for pattern in code_block_patterns: for pattern in code_block_patterns:
match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL) match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL)
if match: if match:
cleaned = match.group(1).strip() cleaned = match.group(1).strip()
logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}") logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}")
break break
# 提取 JSON 对象或数组 # 提取 JSON 对象或数组
# 优先查找对象 {...},其次查找数组 [...] # 优先查找对象 {...},其次查找数组 [...]
for start_char, end_char in [("{", "}"), ("[", "]")]: for start_char, end_char in [("{", "}"), ("[", "]")]:
@@ -143,7 +143,7 @@ def _clean_llm_response(response: str) -> str:
if extracted: if extracted:
logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}") logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}")
return extracted return extracted
# 如果没有找到明确的 JSON 结构,返回清理后的原始内容 # 如果没有找到明确的 JSON 结构,返回清理后的原始内容
logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容") logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容")
return cleaned return cleaned
@@ -152,39 +152,39 @@ def _clean_llm_response(response: str) -> str:
def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None: def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None:
""" """
从指定位置提取平衡的 JSON 结构 从指定位置提取平衡的 JSON 结构
使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符 使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符
Args: Args:
text: 源文本 text: 源文本
start_idx: 起始字符的索引 start_idx: 起始字符的索引
start_char: 起始字符({ 或 [ start_char: 起始字符({ 或 [
end_char: 结束字符(} 或 ] end_char: 结束字符(} 或 ]
Returns: Returns:
提取的 JSON 字符串,失败时返回 None 提取的 JSON 字符串,失败时返回 None
""" """
depth = 0 depth = 0
in_string = False in_string = False
escape_next = False escape_next = False
for i in range(start_idx, len(text)): for i in range(start_idx, len(text)):
char = text[i] char = text[i]
# 处理转义字符 # 处理转义字符
if escape_next: if escape_next:
escape_next = False escape_next = False
continue continue
if char == "\\": if char == "\\":
escape_next = True escape_next = True
continue continue
# 处理字符串 # 处理字符串
if char == '"': if char == '"':
in_string = not in_string in_string = not in_string
continue continue
# 只在非字符串内处理括号 # 只在非字符串内处理括号
if not in_string: if not in_string:
if char == start_char: if char == start_char:
@@ -194,7 +194,7 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
if depth == 0: if depth == 0:
# 找到匹配的结束符 # 找到匹配的结束符
return text[start_idx : i + 1].strip() return text[start_idx : i + 1].strip()
# 没有找到匹配的结束符 # 没有找到匹配的结束符
logger.debug(f"未找到匹配的 {end_char},深度: {depth}") logger.debug(f"未找到匹配的 {end_char},深度: {depth}")
return None return None
@@ -203,11 +203,11 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
def safe_parse_json(json_str: str, default: Any = None) -> Any: def safe_parse_json(json_str: str, default: Any = None) -> Any:
""" """
安全解析 JSON失败时返回默认值 安全解析 JSON失败时返回默认值
Args: Args:
json_str: JSON 字符串 json_str: JSON 字符串
default: 解析失败时返回的默认值 default: 解析失败时返回的默认值
Returns: Returns:
解析结果或默认值 解析结果或默认值
""" """
@@ -222,19 +222,19 @@ def safe_parse_json(json_str: str, default: Any = None) -> Any:
def extract_json_field(response: str, field_name: str, default: Any = None) -> Any: def extract_json_field(response: str, field_name: str, default: Any = None) -> Any:
""" """
从 LLM 响应中提取特定字段的值 从 LLM 响应中提取特定字段的值
Args: Args:
response: LLM 响应 response: LLM 响应
field_name: 字段名 field_name: 字段名
default: 字段不存在时的默认值 default: 字段不存在时的默认值
Returns: Returns:
字段值或默认值 字段值或默认值
""" """
parsed = extract_and_parse_json(response, strict=False) parsed = extract_and_parse_json(response, strict=False)
if isinstance(parsed, dict): if isinstance(parsed, dict):
return parsed.get(field_name, default) return parsed.get(field_name, default)
logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'") logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'")
return default return default
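Taken together, these helpers form a forgiving pipeline: strip Markdown fences, extract the first balanced `{...}` or `[...]`, let `json_repair` fix small format slips, then parse. A usage sketch; the import path is an assumption, adjust it to wherever this module actually lives:

```python
# Assumed import path, for illustration only.
from src.utils.json_utils import extract_and_parse_json, extract_json_field

llm_reply = 'Here is the plan: {"action": "reply", "confidence": 0.8,} hope that helps!'

parsed = extract_and_parse_json(llm_reply)
# surrounding prose dropped, trailing comma repaired -> {'action': 'reply', 'confidence': 0.8}

action = extract_json_field(llm_reply, "action", default="no_action")
# -> 'reply'
```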

View File

@@ -1,5 +1,5 @@
[inner] [inner]
version = "7.6.4" version = "7.6.5"
#----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读---- #----以下是给开发人员阅读的如果你只是部署了MoFox-Bot不需要阅读----
#如果你想要修改配置文件请递增version的值 #如果你想要修改配置文件请递增version的值
@@ -478,9 +478,10 @@ exa_api_keys = ["None"]# EXA API密钥列表支持轮询机制
metaso_api_keys = ["None"]# Metaso API密钥列表支持轮询机制 metaso_api_keys = ["None"]# Metaso API密钥列表支持轮询机制
searxng_instances = [] # SearXNG 实例 URL 列表 searxng_instances = [] # SearXNG 实例 URL 列表
searxng_api_keys = []# SearXNG 实例 API 密钥列表 searxng_api_keys = []# SearXNG 实例 API 密钥列表
serper_api_keys = []# serper API 密钥列表
# 搜索引擎配置 # 搜索引擎配置
enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg","bing", "metaso" enabled_engines = ["ddg"] # 启用的搜索引擎列表,可选: "exa", "tavily", "ddg","bing", "metaso","serper"
search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个) search_strategy = "single" # 搜索策略: "single"(使用第一个可用引擎), "parallel"(并行使用所有启用的引擎), "fallback"(按顺序尝试,失败则尝试下一个)
[cross_context] # 跨群聊/私聊上下文共享配置 [cross_context] # 跨群聊/私聊上下文共享配置

View File

@@ -1,108 +0,0 @@
# 🔄 更新日志 - 记忆图可视化工具
## v1.1 - 2025-11-06
### ✨ 新增功能
1. **📂 文件选择器**
- 自动搜索所有可用的记忆图数据文件
- 支持在Web界面中切换不同的数据文件
- 显示文件大小、修改时间等信息
- 高亮显示当前使用的文件
2. **🔍 智能文件搜索**
- 自动查找 `data/memory_graph/graph_store.json`
- 搜索所有备份文件 `graph_store_*.json`
- 搜索 `data/backup/` 目录下的历史数据
- 按修改时间排序,自动使用最新文件
3. **📊 增强的文件信息显示**
- 在侧边栏显示当前文件信息
- 包含文件名、大小、修改时间
- 实时更新,方便追踪
### 🔧 改进
- 更友好的错误提示
- 无数据文件时显示引导信息
- 优化用户体验
### 🎯 使用方法
```bash
# 启动可视化工具
python run_visualizer_simple.py
# 或直接运行
python tools/memory_visualizer/visualizer_simple.py
```
在Web界面中:
1. 点击侧边栏的 "选择文件" 按钮
2. 浏览所有可用的数据文件
3. 点击任意文件切换数据源
4. 图形会自动重新加载
### 📸 新界面预览
侧边栏新增:
```
┌─────────────────────────┐
│ 📂 数据文件 │
│ ┌──────────┬──────────┐ │
│ │ 选择文件 │ 刷新列表 │ │
│ └──────────┴──────────┘ │
│ ┌─────────────────────┐ │
│ │ 📄 graph_store.json │ │
│ │ 大小: 125 KB │ │
│ │ 修改: 2025-11-06 │ │
│ └─────────────────────┘ │
└─────────────────────────┘
```
文件选择对话框:
```
┌────────────────────────────────┐
│ 📂 选择数据文件 [×] │
├────────────────────────────────┤
│ ┌────────────────────────────┐ │
│ │ 📄 graph_store.json [当前] │ │
│ │ 125 KB | 2025-11-06 09:30 │ │
│ └────────────────────────────┘ │
│ ┌────────────────────────────┐ │
│ │ 📄 graph_store_backup.json │ │
│ │ 120 KB | 2025-11-05 18:00 │ │
│ └────────────────────────────┘ │
└────────────────────────────────┘
```
---
## v1.0 - 2025-11-06 (初始版本)
### 🎉 首次发布
- ✅ 基于Vis.js的交互式图形可视化
- ✅ 节点类型颜色分类
- ✅ 搜索和过滤功能
- ✅ 统计信息显示
- ✅ 节点详情查看
- ✅ 数据导出功能
- ✅ 独立版服务器(快速启动)
- ✅ 完整版服务器(实时数据)
---
## 🔮 计划中的功能 (v1.2+)
- [ ] 时间轴视图 - 查看记忆随时间的变化
- [ ] 3D可视化模式
- [ ] 记忆重要性热力图
- [ ] 关系强度可视化
- [ ] 导出为图片/PDF
- [ ] 记忆路径追踪
- [ ] 多文件对比视图
- [ ] 性能优化 - 支持更大规模图形
- [ ] 移动端适配
欢迎提出建议和需求! 🚀

View File

@@ -1,163 +0,0 @@
# 📁 可视化工具文件整理完成
## ✅ 整理结果
### 新的目录结构
```
tools/memory_visualizer/
├── visualizer.ps1 ⭐ 统一启动脚本(主入口)
├── visualizer_simple.py # 独立版服务器
├── visualizer_server.py # 完整版服务器
├── generate_sample_data.py # 测试数据生成器
├── test_visualizer.py # 测试脚本
├── run_visualizer.py # Python 运行脚本(独立版)
├── run_visualizer_simple.py # Python 运行脚本(简化版)
├── start_visualizer.bat # Windows 批处理启动脚本
├── start_visualizer.ps1 # PowerShell 启动脚本
├── start_visualizer.sh # Linux/Mac 启动脚本
├── requirements.txt # Python 依赖
├── templates/ # HTML 模板
│ └── visualizer.html # 可视化界面
├── docs/ # 文档目录
│ ├── VISUALIZER_README.md
│ ├── VISUALIZER_GUIDE.md
│ └── VISUALIZER_INSTALL_COMPLETE.md
├── README.md # 主说明文档
├── QUICKSTART.md # 快速开始指南
└── CHANGELOG.md # 更新日志
```
### 根目录保留文件
```
项目根目录/
├── visualizer.ps1 # 快捷启动脚本(指向 tools/memory_visualizer/visualizer.ps1
└── tools/memory_visualizer/ # 所有可视化工具文件
```
## 🚀 使用方法
### 推荐方式:使用统一启动脚本
```powershell
# 在项目根目录
.\visualizer.ps1
# 或在工具目录
cd tools\memory_visualizer
.\visualizer.ps1
```
### 命令行参数
```powershell
# 直接启动独立版(推荐)
.\visualizer.ps1 -Simple
# 启动完整版
.\visualizer.ps1 -Full
# 生成测试数据
.\visualizer.ps1 -Generate
# 运行测试
.\visualizer.ps1 -Test
```
## 📋 整理内容
### 已移动的文件
从项目根目录移动到 `tools/memory_visualizer/`
1. **脚本文件**
- `generate_sample_data.py`
- `run_visualizer.py`
- `run_visualizer_simple.py`
- `test_visualizer.py`
- `start_visualizer.bat`
- `start_visualizer.ps1`
- `start_visualizer.sh`
- `visualizer.ps1`
2. **文档文件**`docs/` 子目录
- `VISUALIZER_GUIDE.md`
- `VISUALIZER_INSTALL_COMPLETE.md`
- `VISUALIZER_README.md`
### 已创建的新文件
1. **统一启动脚本**
- `tools/memory_visualizer/visualizer.ps1` - 功能齐全的统一入口
2. **快捷脚本**
- `visualizer.ps1`(根目录)- 快捷方式,指向实际脚本
3. **更新的文档**
- `tools/memory_visualizer/README.md` - 更新为反映新结构
## 🎯 优势
### 整理前的问题
- ❌ 文件散落在根目录
- ❌ 多个启动脚本功能重复
- ❌ 文档分散不便管理
- ❌ 不清楚哪个是主入口
### 整理后的改进
- ✅ 所有文件集中在 `tools/memory_visualizer/`
- ✅ 单一统一的启动脚本 `visualizer.ps1`
- ✅ 文档集中在 `docs/` 子目录
- ✅ 清晰的主入口和快捷方式
- ✅ 更好的可维护性
## 📝 功能对比
### 旧的方式(整理前)
```powershell
# 需要记住多个脚本名称
.\start_visualizer.ps1
.\run_visualizer.py
.\run_visualizer_simple.py
.\generate_sample_data.py
```
### 新的方式(整理后)
```powershell
# 只需要一个统一的脚本
.\visualizer.ps1 # 交互式菜单
.\visualizer.ps1 -Simple # 启动独立版
.\visualizer.ps1 -Generate # 生成数据
.\visualizer.ps1 -Test # 运行测试
```
## 🔧 维护说明
### 添加新功能
1.`tools/memory_visualizer/` 目录下添加新文件
2. 如需启动选项,在 `visualizer.ps1` 中添加新参数
3. 更新 `README.md` 文档
### 更新文档
1. 主文档:`tools/memory_visualizer/README.md`
2. 详细文档:`tools/memory_visualizer/docs/`
## ✅ 测试结果
- ✅ 统一启动脚本正常工作
- ✅ 独立版服务器成功启动(端口 5001)
- ✅ 数据加载成功(725 节点,769 边)
- ✅ Web 界面正常访问
- ✅ 所有文件已整理到位
## 📚 相关文档
- [README](tools/memory_visualizer/README.md) - 主要说明文档
- [QUICKSTART](tools/memory_visualizer/QUICKSTART.md) - 快速开始指南
- [CHANGELOG](tools/memory_visualizer/CHANGELOG.md) - 更新日志
- [详细指南](tools/memory_visualizer/docs/VISUALIZER_GUIDE.md) - 完整使用指南
---
整理完成时间2025-11-06

View File

@@ -1,279 +0,0 @@
# 记忆图可视化工具 - 快速入门指南
## 🎯 方案选择
我为你创建了**两个版本**的可视化工具:
### 1⃣ 独立版 (推荐 ⭐)
- **文件**: `tools/memory_visualizer/visualizer_simple.py`
- **优点**:
- 直接读取存储文件,无需初始化完整系统
- 启动快速
- 占用资源少
- **适用**: 快速查看已有记忆数据
### 2⃣ 完整版
- **文件**: `tools/memory_visualizer/visualizer_server.py`
- **优点**:
- 实时数据
- 支持更多功能
- **缺点**:
- 需要完整初始化记忆管理器
- 启动较慢
## 🚀 快速开始
### 步骤 1: 安装依赖
**Windows (PowerShell):**
```powershell
# 依赖会自动检查和安装
.\start_visualizer.ps1
```
**Windows (CMD):**
```cmd
start_visualizer.bat
```
**Linux/Mac:**
```bash
chmod +x start_visualizer.sh
./start_visualizer.sh
```
**手动安装依赖:**
```bash
# 使用虚拟环境
.\.venv\Scripts\python.exe -m pip install flask flask-cors
# 或全局安装
pip install flask flask-cors
```
### 步骤 2: 确保有数据
如果还没有记忆数据,可以:
**选项A**: 运行Bot生成实际数据
```bash
python bot.py
# 与Bot交互一会儿,让它积累一些记忆
```
**选项B**: 生成测试数据 (如果测试脚本可用)
```bash
python test_visualizer.py
# 选择选项 1: 生成测试数据
```
### 步骤 3: 启动可视化服务器
**方式一: 使用启动脚本 (推荐 ⭐)**
Windows PowerShell:
```powershell
.\start_visualizer.ps1
```
Windows CMD:
```cmd
start_visualizer.bat
```
Linux/Mac:
```bash
./start_visualizer.sh
```
**方式二: 手动启动**
使用虚拟环境:
```bash
# Windows
.\.venv\Scripts\python.exe tools/memory_visualizer/visualizer_simple.py
# Linux/Mac
.venv/bin/python tools/memory_visualizer/visualizer_simple.py
```
或使用系统Python:
```bash
python tools/memory_visualizer/visualizer_simple.py
```
服务器将在 http://127.0.0.1:5001 启动
### 步骤 4: 打开浏览器
访问对应的地址,开始探索记忆图! 🎉
## 🎨 界面功能
### 左侧栏
1. **🔍 搜索框**
- 输入关键词搜索相关记忆
- 结果会在图中高亮显示
2. **📊 统计信息**
- 节点总数
- 边总数
- 记忆总数
- 图密度
3. **🎨 节点类型图例**
- 🔴 主体 (SUBJECT) - 记忆的主语
- 🔵 主题 (TOPIC) - 动作或状态
- 🟢 客体 (OBJECT) - 宾语
- 🟠 属性 (ATTRIBUTE) - 延伸属性
- 🟣 值 (VALUE) - 属性的具体值
4. **🔧 过滤器**
- 勾选/取消勾选来显示/隐藏特定类型的节点
- 实时更新图形
5. ** 节点信息**
- 点击任意节点查看详细信息
- 显示节点类型、内容、创建时间等
### 右侧主区域
1. **控制按钮**
- 🔄 刷新图形: 重新加载最新数据
- 📐 适应窗口: 自动调整图形大小
- 💾 导出数据: 下载JSON格式的图数据
2. **交互式图形**
- **拖动节点**: 点击并拖动单个节点
- **拖动画布**: 按住空白处拖动整个图形
- **缩放**: 使用鼠标滚轮放大/缩小
- **点击节点**: 查看详细信息
- **物理模拟**: 节点会自动排列,避免重叠
## 🎮 操作技巧
### 查看特定类型的节点
1. 在左侧过滤器中取消勾选不需要的类型
2. 图形会自动更新,只显示选中的类型
### 查找特定记忆
1. 在搜索框输入关键词(如: "小明", "吃饭")
2. 点击"搜索"按钮
3. 相关节点会被选中并自动聚焦
### 整理混乱的图形
1. 点击"适应窗口"按钮
2. 或者刷新页面重新初始化布局
### 导出数据进行分析
1. 点击"导出数据"按钮
2. JSON文件会自动下载
3. 可以用于进一步的数据分析或备份
## 🎯 示例场景
### 场景1: 了解记忆图整体结构
1. 启动可视化工具
2. 观察不同颜色的节点分布
3. 查看统计信息了解数量
4. 使用过滤器逐个类型查看
### 场景2: 追踪特定主题的记忆
1. 在搜索框输入主题关键词(如: "学习")
2. 点击搜索
3. 查看高亮的相关节点
4. 点击节点查看详情
### 场景3: 调试记忆系统
1. 创建一条新记忆
2. 刷新可视化页面
3. 查看新节点和边是否正确创建
4. 验证节点类型和关系
## 🐛 常见问题
### Q: 页面显示空白或没有数据?
**A**:
1. 检查是否有记忆数据: 查看 `data/memory_graph/` 目录
2. 确保记忆系统已启用: 检查 `config/bot_config.toml``[memory] enable = true`
3. 尝试生成一些测试数据
### Q: 节点太多,看不清楚?
**A**:
1. 使用过滤器只显示某些类型
2. 使用搜索功能定位特定节点
3. 调整浏览器窗口大小,点击"适应窗口"
### Q: 如何更新数据?
**A**:
- **独立版**: 点击"刷新图形"或访问 `/api/reload`
- **完整版**: 点击"刷新图形"会自动加载最新数据
### Q: 端口被占用怎么办?
**A**: 修改启动脚本中的端口号:
```python
run_server(host='127.0.0.1', port=5002, debug=True) # 改为其他端口
```
## 🎨 自定义配置
### 修改节点颜色
编辑 `templates/visualizer.html`,找到:
```javascript
const nodeColors = {
'SUBJECT': '#FF6B6B', // 改为你喜欢的颜色
'TOPIC': '#4ECDC4',
// ...
};
```
### 修改物理引擎参数
在同一文件中找到 `physics` 配置:
```javascript
physics: {
barnesHut: {
gravitationalConstant: -8000, // 调整引力
springLength: 150, // 调整弹簧长度
// ...
}
}
```
### 修改数据加载限制
编辑对应的服务器文件,修改 `get_all_memories()` 的limit参数。
## 📝 文件结构
```
tools/memory_visualizer/
├── README.md # 详细文档
├── requirements.txt # 依赖列表
├── visualizer_server.py # 完整版服务器
├── visualizer_simple.py # 独立版服务器 ⭐
└── templates/
└── visualizer.html # Web界面模板
run_visualizer.py # 快速启动脚本
test_visualizer.py # 测试和演示脚本
```
## 🚀 下一步
现在你可以:
1. ✅ 启动可视化工具查看现有数据
2. ✅ 与Bot交互生成更多记忆
3. ✅ 使用可视化工具验证记忆结构
4. ✅ 根据需要自定义样式和配置
祝你使用愉快! 🎉
---
如有问题,请查看 `tools/memory_visualizer/README.md` 获取更多帮助。

View File

@@ -1,201 +0,0 @@
# 🦊 记忆图可视化工具
一个交互式的 Web 可视化工具,用于查看和分析 MoFox Bot 的记忆图结构。
## 📁 目录结构
```
tools/memory_visualizer/
├── visualizer.ps1 # 统一启动脚本(主入口)⭐
├── visualizer_simple.py # 独立版服务器(推荐)
├── visualizer_server.py # 完整版服务器
├── generate_sample_data.py # 测试数据生成器
├── test_visualizer.py # 测试脚本
├── requirements.txt # Python 依赖
├── templates/ # HTML 模板
│ └── visualizer.html # 可视化界面
├── docs/ # 文档目录
│ ├── VISUALIZER_README.md
│ ├── VISUALIZER_GUIDE.md
│ └── VISUALIZER_INSTALL_COMPLETE.md
├── README.md # 本文件
├── QUICKSTART.md # 快速开始指南
└── CHANGELOG.md # 更新日志
```
## 🚀 快速开始
### 方式 1:交互式菜单(推荐)
```powershell
# 在项目根目录运行
.\visualizer.ps1
# 或在工具目录运行
cd tools\memory_visualizer
.\visualizer.ps1
```
### 方式 2:命令行参数
```powershell
# 启动独立版(推荐,快速)
.\visualizer.ps1 -Simple
# 启动完整版(需要 MemoryManager
.\visualizer.ps1 -Full
# 生成测试数据
.\visualizer.ps1 -Generate
# 运行测试
.\visualizer.ps1 -Test
# 查看帮助
.\visualizer.ps1 -Help
```
## 📊 两个版本的区别
### 独立版(Simple)- 推荐
-**快速启动**:直接读取数据文件,无需初始化 MemoryManager
-**轻量级**:只依赖 Flask 和 vis.js
-**稳定**:不依赖主系统运行状态
- 📌 **端口**5001
- 📁 **数据源**`data/memory_graph/*.json`
### 完整版(Full)
- 🔄 **实时数据**:使用 MemoryManager 获取最新数据
- 🔌 **集成**:与主系统深度集成
-**功能完整**:支持所有高级功能
- 📌 **端口**5000
- 📁 **数据源**MemoryManager
## ✨ 主要功能
1. **交互式图形可视化**
- 🎨 5 种节点类型(主体、主题、客体、属性、值)
- 🔗 完整路径高亮显示
- 🔍 点击节点查看连接关系
- 📐 自动布局和缩放
2. **高级筛选**
- ☑️ 按节点类型筛选
- 🔎 关键词搜索
- 📊 统计信息实时更新
3. **智能高亮**
- 💡 点击节点高亮所有连接路径(递归探索)
- 👻 无关节点变为半透明
- 🎯 自动聚焦到相关子图
4. **物理引擎优化**
- 🚀 智能布局算法
- ⏱️ 自动停止防止持续运行
- 🔄 筛选后自动重新布局
5. **数据管理**
- 📂 多文件选择器
- 💾 导出图形数据
- 🔄 实时刷新
## 🔧 依赖安装
脚本会自动检查并安装依赖,也可以手动安装:
```powershell
# 激活虚拟环境
.\.venv\Scripts\Activate.ps1
# 安装依赖
pip install -r tools/memory_visualizer/requirements.txt
```
**所需依赖:**
- Flask >= 2.3.0
- flask-cors >= 4.0.0
## 📖 使用说明
### 1. 查看记忆图
1. 启动服务器(推荐独立版)
2. 在浏览器打开 http://127.0.0.1:5001
3. 等待数据加载完成
### 2. 探索连接关系
1. **点击节点**:查看与该节点相关的所有连接路径
2. **点击空白处**:恢复所有节点显示
3. **使用筛选器**:按类型过滤节点
### 3. 搜索记忆
1. 在搜索框输入关键词
2. 点击搜索按钮
3. 相关节点会自动高亮
### 4. 查看统计
- 左侧面板显示实时统计信息
- 节点数、边数、记忆数
- 图密度等指标
## 🎨 节点颜色说明
- 🔴 **主体SUBJECT**:红色 (#FF6B6B)
- 🔵 **主题TOPIC**:青色 (#4ECDC4)
- 🟦 **客体OBJECT**:蓝色 (#45B7D1)
- 🟠 **属性ATTRIBUTE**:橙色 (#FFA07A)
- 🟢 **值VALUE**:绿色 (#98D8C8)
## 🐛 常见问题
### 问题 1:没有数据显示
**解决方案:**
1. 检查 `data/memory_graph/` 目录是否存在数据文件
2. 运行 `.\visualizer.ps1 -Generate` 生成测试数据
3. 确保 Bot 已经运行过并生成了记忆数据
### 问题 2:物理引擎一直运行
**解决方案:**
- 新版本已修复此问题
- 物理引擎会在稳定后自动停止(最多 5 秒)
### 问题 3:筛选后节点排版错乱
**解决方案:**
- 新版本已修复此问题
- 筛选后会自动重新布局
### 问题 4:无法查看完整连接路径
**解决方案:**
- 新版本使用 BFS 算法递归探索所有连接
- 点击节点即可查看完整路径
## 📝 开发说明
### 添加新功能
1. 编辑 `visualizer_simple.py``visualizer_server.py`
2. 修改 `templates/visualizer.html` 更新界面
3. 更新 `requirements.txt` 添加新依赖
4. 运行测试:`.\visualizer.ps1 -Test`
### 调试
```powershell
# 启动 Flask 调试模式
$env:FLASK_DEBUG = "1"
python tools/memory_visualizer/visualizer_simple.py
```
## 📚 相关文档
- [快速开始指南](QUICKSTART.md)
- [更新日志](CHANGELOG.md)
- [详细使用指南](docs/VISUALIZER_GUIDE.md)
## 🆘 获取帮助
遇到问题?
1. 查看 [常见问题](#常见问题)
2. 运行 `.\visualizer.ps1 -Help` 查看帮助
3. 查看项目文档目录
## 📄 许可证
与 MoFox Bot 主项目相同

View File

@@ -1,163 +0,0 @@
# 🦊 MoFox Bot 记忆图可视化工具
这是一个交互式的Web界面,用于可视化和探索MoFox Bot的记忆图结构。
## ✨ 功能特性
- **交互式图形可视化**: 使用Vis.js展示节点和边的关系
- **实时数据**: 直接从记忆管理器读取最新数据
- **节点类型分类**: 不同颜色区分不同类型的节点
- 🔴 主体 (SUBJECT)
- 🔵 主题 (TOPIC)
- 🟢 客体 (OBJECT)
- 🟠 属性 (ATTRIBUTE)
- 🟣 值 (VALUE)
- **搜索功能**: 快速查找相关记忆
- **过滤器**: 按节点类型过滤显示
- **统计信息**: 实时显示图的统计数据
- **节点详情**: 点击节点查看详细信息
- **自由缩放拖动**: 支持图形的交互式操作
- **数据导出**: 导出当前图形数据为JSON
## 🚀 快速开始
### 1. 安装依赖
```bash
pip install flask flask-cors
```
### 2. 启动服务器
在项目根目录运行:
```bash
python tools/memory_visualizer/visualizer_server.py
```
或者使用便捷脚本:
```bash
python run_visualizer.py
```
### 3. 打开浏览器
访问: http://127.0.0.1:5000
## 📊 界面说明
### 主界面布局
```
┌─────────────────────────────────────────────────┐
│ 侧边栏 │ 主内容区 │
│ - 搜索框 │ - 控制按钮 │
│ - 统计信息 │ - 图形显示 │
│ - 节点类型图例 │ │
│ - 过滤器 │ │
│ - 节点详情 │ │
└─────────────────────────────────────────────────┘
```
### 操作说明
- **🔍 搜索**: 在搜索框输入关键词,点击"搜索"按钮查找相关记忆
- **🔄 刷新图形**: 重新加载最新的记忆图数据
- **📐 适应窗口**: 自动调整图形大小以适应窗口
- **💾 导出数据**: 将当前图形数据导出为JSON文件
- **✅ 过滤器**: 勾选/取消勾选不同类型的节点来过滤显示
- **👆 点击节点**: 点击任意节点查看详细信息
- **🖱️ 拖动**: 按住鼠标拖动节点或整个图形
- **🔍 缩放**: 使用鼠标滚轮缩放图形
## 🔧 配置说明
### 修改服务器配置
`visualizer_server.py` 的最后:
```python
if __name__ == '__main__':
run_server(
host='127.0.0.1', # 监听地址
port=5000, # 端口号
debug=True # 调试模式
)
```
### API端点
- `GET /` - 主页面
- `GET /api/graph/full` - 获取完整记忆图数据
- `GET /api/memory/<memory_id>` - 获取特定记忆详情
- `GET /api/search?q=<query>&limit=<n>` - 搜索记忆
- `GET /api/stats` - 获取统计信息
## 📝 技术栈
- **后端**: Flask (Python Web框架)
- **前端**:
- Vis.js (图形可视化库)
- 原生JavaScript
- CSS3 (渐变、动画、响应式布局)
- **数据**: 直接从MoFox Bot记忆管理器读取
## 🐛 故障排除
### 问题: 无法启动服务器
**原因**: 记忆系统未启用或配置错误
**解决**: 检查 `config/bot_config.toml` 确保:
```toml
[memory]
enable = true
data_dir = "data/memory_graph"
```
### 问题: 图形显示空白
**原因**: 没有记忆数据
**解决**:
1. 先运行Bot让其生成一些记忆
2. 或者运行测试脚本生成测试数据
### 问题: 节点太多,图形混乱
**解决**:
1. 使用过滤器只显示某些类型的节点
2. 使用搜索功能定位特定记忆
3. 调整物理引擎参数(在visualizer.html中)
## 🎨 自定义样式
修改 `templates/visualizer.html` 中的样式定义:
```javascript
const nodeColors = {
'SUBJECT': '#FF6B6B', // 主体颜色
'TOPIC': '#4ECDC4', // 主题颜色
'OBJECT': '#45B7D1', // 客体颜色
'ATTRIBUTE': '#FFA07A', // 属性颜色
'VALUE': '#98D8C8' // 值颜色
};
```
## 📈 性能优化
对于大型图形(>1000节点):
1. **禁用物理引擎**: 在stabilization完成后自动禁用
2. **限制显示节点**: 使用过滤器或搜索
3. **分页加载**: 修改API使用分页
## 🤝 贡献
欢迎提交Issue和Pull Request!
## 📄 许可
与MoFox Bot主项目相同的许可证

View File

@@ -1,210 +0,0 @@
# ✅ 记忆图可视化工具 - 安装完成
## 🎉 恭喜!可视化工具已成功创建!
---
## 📦 已创建的文件
```
Bot/
├── visualizer.ps1 ⭐⭐⭐ # 统一启动脚本 (推荐使用)
├── start_visualizer.ps1 # 独立版快速启动
├── start_visualizer.bat # CMD版启动脚本
├── generate_sample_data.py # 示例数据生成器
├── VISUALIZER_README.md ⭐ # 快速参考指南
├── VISUALIZER_GUIDE.md # 完整使用指南
└── tools/memory_visualizer/
├── visualizer_simple.py ⭐ # 独立版服务器 (推荐)
├── visualizer_server.py # 完整版服务器
├── README.md # 详细文档
├── QUICKSTART.md # 快速入门
├── CHANGELOG.md # 更新日志
└── templates/
└── visualizer.html ⭐ # 精美Web界面
```
---
## 🚀 立即开始 (3秒)
### 方法 1: 使用统一启动脚本 (最简单 ⭐⭐⭐)
```powershell
.\visualizer.ps1
```
然后按提示选择:
- **1** = 独立版 (推荐,快速)
- **2** = 完整版 (实时数据)
- **3** = 生成示例数据
### 方法 2: 直接启动
```powershell
# 如果还没有数据,先生成
.\.venv\Scripts\python.exe generate_sample_data.py
# 启动可视化
.\start_visualizer.ps1
# 打开浏览器
# http://127.0.0.1:5001
```
---
## 🎨 功能亮点
### ✨ 核心功能
- 🎯 **交互式图形**: 拖动、缩放、点击
- 🎨 **颜色分类**: 5种节点类型自动上色
- 🔍 **智能搜索**: 快速定位相关记忆
- 🔧 **灵活过滤**: 按节点类型筛选
- 📊 **实时统计**: 节点、边、记忆数量
- 💾 **数据导出**: JSON格式导出
### 📂 独立版特色 (推荐)
-**秒速启动**: 2秒内完成
- 📁 **文件切换**: 浏览所有历史数据
- 🔄 **自动搜索**: 智能查找数据文件
- 💚 **低资源**: 占用资源极少
### 🔥 完整版特色
- 🔴 **实时数据**: 与Bot同步
- 🔄 **自动更新**: 无需刷新
- 🛠️ **完整功能**: 使用全部API
---
## 📊 界面预览
```
┌─────────────────────────────────────────────────────────┐
│ 侧边栏 │ 主区域 │
│ ┌─────────────────────┐ │ ┌───────────────────────┐ │
│ │ 📂 数据文件 │ │ │ 🔄 📐 💾 控制按钮 │ │
│ │ [选择] [刷新] │ │ └───────────────────────┘ │
│ │ 📄 当前: xxx.json │ │ ┌───────────────────────┐ │
│ └─────────────────────┘ │ │ │ │
│ │ │ 交互式图形可视化 │ │
│ ┌─────────────────────┐ │ │ │ │
│ │ 🔍 搜索记忆 │ │ │ 🔴 主体 🔵 主题 │ │
│ │ [...........] [搜索] │ │ │ 🟢 客体 🟠 属性 │ │
│ └─────────────────────┘ │ │ 🟣 值 │ │
│ │ │ │ │
│ 📊 统计: 12节点 15边 │ │ 可拖动、缩放、点击 │ │
│ │ │ │ │
│ 🎨 节点类型图例 │ └───────────────────────┘ │
│ 🔧 过滤器 │ │
节点信息 │ │
└─────────────────────────────────────────────────────────┘
```
---
## 🎯 快速命令
```powershell
# 统一启动 (推荐)
.\visualizer.ps1
# 生成示例数据
.\.venv\Scripts\python.exe generate_sample_data.py
# 独立版 (端口 5001)
.\start_visualizer.ps1
# 完整版 (端口 5000)
.\.venv\Scripts\python.exe tools/memory_visualizer/visualizer_server.py
```
---
## 📖 文档索引
### 快速参考 (必读 ⭐)
- **VISUALIZER_README.md** - 快速参考卡片
- **VISUALIZER_GUIDE.md** - 完整使用指南
### 详细文档
- **tools/memory_visualizer/README.md** - 技术文档
- **tools/memory_visualizer/QUICKSTART.md** - 快速入门
- **tools/memory_visualizer/CHANGELOG.md** - 版本历史
---
## 💡 使用建议
### 🎯 对于首次使用者
1. 运行 `.\visualizer.ps1`
2. 选择 `3` 生成示例数据
3. 选择 `1` 启动独立版
4. 打开浏览器访问 http://127.0.0.1:5001
5. 开始探索!
### 🔧 对于开发者
1. 运行Bot积累真实数据
2. 启动完整版可视化: `.\visualizer.ps1``2`
3. 实时查看记忆图变化
4. 调试和优化
### 📊 对于数据分析
1. 使用独立版查看历史数据
2. 切换不同时期的数据文件
3. 使用搜索和过滤功能
4. 导出数据进行分析
---
## 🐛 常见问题
### Q: 未找到数据文件?
**A**: 运行 `.\visualizer.ps1` 选择 `3` 生成示例数据
### Q: 端口被占用?
**A**: 修改对应服务器文件中的端口号,或关闭占用端口的程序
### Q: 两个版本有什么区别?
**A**:
- **独立版**: 快速,读文件,可切换,推荐日常使用
- **完整版**: 实时,用内存,完整功能,推荐开发调试
### Q: 图形显示混乱?
**A**:
1. 使用过滤器减少节点
2. 点击"适应窗口"
3. 刷新页面
---
## 🎉 开始使用
### 立即启动
```powershell
.\visualizer.ps1
```
### 访问地址
- 独立版: http://127.0.0.1:5001
- 完整版: http://127.0.0.1:5000
---
## 🤝 反馈与支持
如有问题或建议,请查看:
- 📖 `VISUALIZER_GUIDE.md` - 完整使用指南
- 📝 `tools/memory_visualizer/README.md` - 技术文档
---
## 🌟 特别感谢
感谢你使用 MoFox Bot 记忆图可视化工具!
**享受探索记忆图的乐趣!** 🚀🦊
---
_最后更新: 2025-11-06_

View File

@@ -1,159 +0,0 @@
# 🎯 记忆图可视化工具 - 快速参考
## 🚀 快速启动
### 推荐方式 (交互式菜单)
```powershell
.\visualizer.ps1
```
然后选择:
- **选项 1**: 独立版 (快速,推荐) ⭐
- **选项 2**: 完整版 (实时数据)
- **选项 3**: 生成示例数据
---
## 📋 各版本对比
| 特性 | 独立版 ⭐ | 完整版 |
|------|---------|--------|
| **启动速度** | 🚀 快速 (2秒) | ⏱️ 较慢 (5-10秒) |
| **数据源** | 📂 文件 | 💾 内存 (实时) |
| **文件切换** | ✅ 支持 | ❌ 不支持 |
| **资源占用** | 💚 低 | 💛 中等 |
| **端口** | 5001 | 5000 |
| **适用场景** | 查看历史数据、调试 | 实时监控、开发 |
---
## 🔧 手动启动命令
### 独立版 (推荐)
```powershell
# Windows
.\start_visualizer.ps1
# 或直接运行
.\.venv\Scripts\python.exe tools/memory_visualizer/visualizer_simple.py
```
访问: http://127.0.0.1:5001
### 完整版
```powershell
.\.venv\Scripts\python.exe tools/memory_visualizer/visualizer_server.py
```
访问: http://127.0.0.1:5000
### 生成示例数据
```powershell
.\.venv\Scripts\python.exe generate_sample_data.py
```
---
## 📊 功能一览
### 🎨 可视化功能
- ✅ 交互式图形 (拖动、缩放、点击)
- ✅ 节点类型颜色分类
- ✅ 实时搜索和过滤
- ✅ 统计信息展示
- ✅ 节点详情查看
### 📂 数据管理
- ✅ 自动搜索数据文件
- ✅ 多文件切换 (独立版)
- ✅ 数据导出 (JSON格式)
- ✅ 文件信息显示
---
## 🎯 使用场景
### 1⃣ 首次使用
```powershell
# 1. 生成示例数据
.\visualizer.ps1
# 选择: 3
# 2. 启动可视化
.\visualizer.ps1
# 选择: 1
# 3. 打开浏览器
# 访问: http://127.0.0.1:5001
```
### 2⃣ 查看实际数据
```powershell
# 先运行Bot生成记忆
# 然后启动可视化
.\visualizer.ps1
# 选择: 1 (独立版) 或 2 (完整版)
```
### 3⃣ 调试记忆系统
```powershell
# 使用完整版,实时查看变化
.\visualizer.ps1
# 选择: 2
```
---
## 🐛 故障排除
### ❌ 问题: 未找到数据文件
**解决**:
```powershell
.\visualizer.ps1
# 选择 3 生成示例数据
```
### ❌ 问题: 端口被占用
**解决**:
- 独立版: 修改 `visualizer_simple.py` 中的 `port=5001`
- 完整版: 修改 `visualizer_server.py` 中的 `port=5000`
### ❌ 问题: 数据加载失败
**可能原因**:
- 数据文件格式不正确
- 文件损坏
**解决**:
1. 检查 `data/memory_graph/` 目录
2. 重新生成示例数据
3. 查看终端错误信息
---
## 📚 相关文档
- **完整指南**: `VISUALIZER_GUIDE.md`
- **快速入门**: `tools/memory_visualizer/QUICKSTART.md`
- **详细文档**: `tools/memory_visualizer/README.md`
- **更新日志**: `tools/memory_visualizer/CHANGELOG.md`
---
## 💡 提示
1. **首次使用**: 先生成示例数据 (选项 3)
2. **查看历史**: 使用独立版,可以切换不同数据文件
3. **实时监控**: 使用完整版与Bot同时运行
4. **性能优化**: 大型图使用过滤器和搜索
5. **快捷键**:
- `Ctrl + 滚轮`: 缩放
- 拖动空白: 移动画布
- 点击节点: 查看详情
---
## 🎉 开始探索!
```powershell
.\visualizer.ps1
```
享受你的记忆图之旅!🚀🦊

View File

@@ -1,9 +0,0 @@
# 记忆图可视化工具依赖
# Web框架
flask>=2.3.0
flask-cors>=4.0.0
# 其他依赖由主项目提供
# - src.memory_graph
# - src.config

View File

@@ -1,38 +0,0 @@
#!/usr/bin/env python3
"""
记忆图可视化工具启动脚本
快速启动记忆图可视化Web服务器
"""
import sys
from pathlib import Path
# 添加项目根目录到路径
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
from tools.memory_visualizer.visualizer_server import run_server
if __name__ == '__main__':
print("=" * 60)
print("🦊 MoFox Bot - 记忆图可视化工具")
print("=" * 60)
print()
print("📊 启动可视化服务器...")
print("🌐 访问地址: http://127.0.0.1:5000")
print("⏹️ 按 Ctrl+C 停止服务器")
print()
print("=" * 60)
try:
run_server(
host='127.0.0.1',
port=5000,
debug=True
)
except KeyboardInterrupt:
print("\n\n👋 服务器已停止")
except Exception as e:
print(f"\n❌ 启动失败: {e}")
sys.exit(1)

View File

@@ -1,39 +0,0 @@
"""
快速启动脚本 - 记忆图可视化工具 (独立版)
使用说明:
1. 直接运行此脚本启动可视化服务器
2. 工具会自动搜索可用的数据文件
3. 如果找到多个文件,会使用最新的文件
4. 你也可以在Web界面中选择其他文件
"""
import sys
from pathlib import Path
# 添加项目根目录
project_root = Path(__file__).parent
sys.path.insert(0, str(project_root))
if __name__ == '__main__':
print("=" * 70)
print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
print("=" * 70)
print()
print("✨ 特性:")
print(" • 自动搜索可用的数据文件")
print(" • 支持在Web界面中切换文件")
print(" • 快速启动,无需完整初始化")
print()
print("=" * 70)
try:
from tools.memory_visualizer.visualizer_simple import run_server
run_server(host='127.0.0.1', port=5001, debug=True)
except KeyboardInterrupt:
print("\n\n👋 服务器已停止")
except Exception as e:
print(f"\n❌ 启动失败: {e}")
import traceback
traceback.print_exc()
sys.exit(1)

View File

@@ -1,53 +0,0 @@
@echo off
REM 记忆图可视化工具启动脚本 - CMD版本
echo ======================================================================
echo 🦊 MoFox Bot - 记忆图可视化工具
echo ======================================================================
echo.
REM 检查虚拟环境
set VENV_PYTHON=.venv\Scripts\python.exe
if not exist "%VENV_PYTHON%" (
echo ❌ 未找到虚拟环境: %VENV_PYTHON%
echo.
echo 请先创建虚拟环境:
echo python -m venv .venv
echo .venv\Scripts\activate.bat
echo pip install -r requirements.txt
echo.
exit /b 1
)
echo ✅ 使用虚拟环境: %VENV_PYTHON%
echo.
REM 检查依赖
echo 🔍 检查依赖...
"%VENV_PYTHON%" -c "import flask; import flask_cors" 2>nul
if errorlevel 1 (
echo ⚠️ 缺少依赖,正在安装...
"%VENV_PYTHON%" -m pip install flask flask-cors --quiet
if errorlevel 1 (
echo ❌ 安装依赖失败
exit /b 1
)
echo ✅ 依赖安装完成
)
echo ✅ 依赖检查完成
echo.
REM 显示信息
echo 📊 启动可视化服务器...
echo 🌐 访问地址: http://127.0.0.1:5001
echo ⏹️ 按 Ctrl+C 停止服务器
echo.
echo ======================================================================
echo.
REM 启动服务器
"%VENV_PYTHON%" "tools\memory_visualizer\visualizer_simple.py"
echo.
echo 👋 服务器已停止

View File

@@ -1,65 +0,0 @@
#!/usr/bin/env pwsh
# 记忆图可视化工具启动脚本 - PowerShell版本
Write-Host "=" -NoNewline -ForegroundColor Cyan
Write-Host ("=" * 69) -ForegroundColor Cyan
Write-Host "🦊 MoFox Bot - 记忆图可视化工具" -ForegroundColor Yellow
Write-Host "=" -NoNewline -ForegroundColor Cyan
Write-Host ("=" * 69) -ForegroundColor Cyan
Write-Host ""
# 检查虚拟环境
$venvPath = ".venv\Scripts\python.exe"
if (-not (Test-Path $venvPath)) {
Write-Host "❌ 未找到虚拟环境: $venvPath" -ForegroundColor Red
Write-Host ""
Write-Host "请先创建虚拟环境:" -ForegroundColor Yellow
Write-Host " python -m venv .venv" -ForegroundColor Cyan
Write-Host " .\.venv\Scripts\Activate.ps1" -ForegroundColor Cyan
Write-Host " pip install -r requirements.txt" -ForegroundColor Cyan
Write-Host ""
exit 1
}
Write-Host "✅ 使用虚拟环境: $venvPath" -ForegroundColor Green
Write-Host ""
# 检查依赖
Write-Host "🔍 检查依赖..." -ForegroundColor Cyan
& $venvPath -c "import flask; import flask_cors" 2>$null
if ($LASTEXITCODE -ne 0) {
Write-Host "⚠️ 缺少依赖,正在安装..." -ForegroundColor Yellow
& $venvPath -m pip install flask flask-cors --quiet
if ($LASTEXITCODE -ne 0) {
Write-Host "❌ 安装依赖失败" -ForegroundColor Red
exit 1
}
Write-Host "✅ 依赖安装完成" -ForegroundColor Green
}
Write-Host "✅ 依赖检查完成" -ForegroundColor Green
Write-Host ""
# 显示信息
Write-Host "📊 启动可视化服务器..." -ForegroundColor Cyan
Write-Host "🌐 访问地址: " -NoNewline -ForegroundColor White
Write-Host "http://127.0.0.1:5001" -ForegroundColor Blue
Write-Host "⏹️ 按 Ctrl+C 停止服务器" -ForegroundColor Yellow
Write-Host ""
Write-Host "=" -NoNewline -ForegroundColor Cyan
Write-Host ("=" * 69) -ForegroundColor Cyan
Write-Host ""
# 启动服务器
try {
& $venvPath "tools\memory_visualizer\visualizer_simple.py"
}
catch {
Write-Host ""
Write-Host "❌ 启动失败: $_" -ForegroundColor Red
exit 1
}
finally {
Write-Host ""
Write-Host "👋 服务器已停止" -ForegroundColor Yellow
}

View File

@@ -1,53 +0,0 @@
#!/bin/bash
# 记忆图可视化工具启动脚本 - Bash版本 (Linux/Mac)
echo "======================================================================"
echo "🦊 MoFox Bot - 记忆图可视化工具"
echo "======================================================================"
echo ""
# 检查虚拟环境
VENV_PYTHON=".venv/bin/python"
if [ ! -f "$VENV_PYTHON" ]; then
echo "❌ 未找到虚拟环境: $VENV_PYTHON"
echo ""
echo "请先创建虚拟环境:"
echo " python -m venv .venv"
echo " source .venv/bin/activate"
echo " pip install -r requirements.txt"
echo ""
exit 1
fi
echo "✅ 使用虚拟环境: $VENV_PYTHON"
echo ""
# 检查依赖
echo "🔍 检查依赖..."
$VENV_PYTHON -c "import flask; import flask_cors" 2>/dev/null
if [ $? -ne 0 ]; then
echo "⚠️ 缺少依赖,正在安装..."
$VENV_PYTHON -m pip install flask flask-cors --quiet
if [ $? -ne 0 ]; then
echo "❌ 安装依赖失败"
exit 1
fi
echo "✅ 依赖安装完成"
fi
echo "✅ 依赖检查完成"
echo ""
# 显示信息
echo "📊 启动可视化服务器..."
echo "🌐 访问地址: http://127.0.0.1:5001"
echo "⏹️ 按 Ctrl+C 停止服务器"
echo ""
echo "======================================================================"
echo ""
# 启动服务器
$VENV_PYTHON "tools/memory_visualizer/visualizer_simple.py"
echo ""
echo "👋 服务器已停止"

View File

@@ -1,59 +0,0 @@
# 记忆图可视化工具统一启动脚本
param(
[switch]$Simple,
[switch]$Full,
[switch]$Generate,
[switch]$Test
)
$ScriptDir = Split-Path -Parent $MyInvocation.MyCommand.Path
$ProjectRoot = Split-Path -Parent (Split-Path -Parent $ScriptDir)
Set-Location $ProjectRoot
function Get-Python {
$paths = @(".venv\Scripts\python.exe", "venv\Scripts\python.exe")
foreach ($p in $paths) {
if (Test-Path $p) { return $p }
}
return $null
}
$python = Get-Python
if (-not $python) {
Write-Host "ERROR: Virtual environment not found" -ForegroundColor Red
exit 1
}
if ($Simple) {
Write-Host "Starting Simple Server on http://127.0.0.1:5001" -ForegroundColor Green
& $python "$ScriptDir\visualizer_simple.py"
}
elseif ($Full) {
Write-Host "Starting Full Server on http://127.0.0.1:5000" -ForegroundColor Green
& $python "$ScriptDir\visualizer_server.py"
}
elseif ($Generate) {
& $python "$ScriptDir\generate_sample_data.py"
}
elseif ($Test) {
& $python "$ScriptDir\test_visualizer.py"
}
else {
Write-Host "MoFox Bot - Memory Graph Visualizer" -ForegroundColor Cyan
Write-Host ""
Write-Host "[1] Start Simple Server (Recommended)"
Write-Host "[2] Start Full Server"
Write-Host "[3] Generate Test Data"
Write-Host "[4] Run Tests"
Write-Host "[Q] Quit"
Write-Host ""
$choice = Read-Host "Select"
switch ($choice) {
"1" { & $python "$ScriptDir\visualizer_simple.py" }
"2" { & $python "$ScriptDir\visualizer_server.py" }
"3" { & $python "$ScriptDir\generate_sample_data.py" }
"4" { & $python "$ScriptDir\test_visualizer.py" }
default { exit 0 }
}
}

View File

@@ -1,356 +0,0 @@
"""
记忆图可视化服务器
提供 Web API 用于可视化记忆图数据
"""
import asyncio
import orjson
import logging
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional
from flask import Flask, jsonify, render_template, request
from flask_cors import CORS
# 添加项目根目录到 Python 路径
import sys
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from src.memory_graph.manager import MemoryManager
from src.memory_graph.models import EdgeType, MemoryType, NodeType
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
CORS(app) # 允许跨域请求
# 全局记忆管理器
memory_manager: Optional[MemoryManager] = None
def init_memory_manager():
"""初始化记忆管理器"""
global memory_manager
if memory_manager is None:
try:
memory_manager = MemoryManager()
# 在新的事件循环中初始化
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(memory_manager.initialize())
logger.info("记忆管理器初始化成功")
except Exception as e:
logger.error(f"初始化记忆管理器失败: {e}")
raise
@app.route('/')
def index():
"""主页面"""
return render_template('visualizer.html')
@app.route('/api/graph/full')
def get_full_graph():
"""
获取完整记忆图数据
返回所有节点和边,格式化为前端可用的结构
"""
try:
if memory_manager is None:
init_memory_manager()
# 获取所有记忆
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 获取所有记忆
all_memories = memory_manager.graph_store.get_all_memories()
# 构建节点和边数据
nodes_dict = {} # {node_id: node_data}
edges_dict = {} # {edge_id: edge_data} - 使用字典去重
memory_info = []
for memory in all_memories:
# 添加记忆信息
memory_info.append({
'id': memory.id,
'type': memory.memory_type.value,
'importance': memory.importance,
'activation': memory.activation,
'status': memory.status.value,
'created_at': memory.created_at.isoformat(),
'text': memory.to_text(),
'access_count': memory.access_count,
})
# 处理节点
for node in memory.nodes:
if node.id not in nodes_dict:
nodes_dict[node.id] = {
'id': node.id,
'label': node.content,
'type': node.node_type.value,
'group': node.node_type.name, # 用于颜色分组
'title': f"{node.node_type.value}: {node.content}",
'metadata': node.metadata,
'created_at': node.created_at.isoformat(),
}
# 处理边 - 使用字典自动去重
for edge in memory.edges:
edge_id = edge.id
# 如果ID已存在生成唯一ID
counter = 1
original_edge_id = edge_id
while edge_id in edges_dict:
edge_id = f"{original_edge_id}_{counter}"
counter += 1
edges_dict[edge_id] = {
'id': edge_id,
'from': edge.source_id,
'to': edge.target_id,
'label': edge.relation,
'type': edge.edge_type.value,
'importance': edge.importance,
'title': f"{edge.edge_type.value}: {edge.relation}",
'arrows': 'to',
'memory_id': memory.id,
}
nodes_list = list(nodes_dict.values())
edges_list = list(edges_dict.values())
return jsonify({
'success': True,
'data': {
'nodes': nodes_list,
'edges': edges_list,
'memories': memory_info,
'stats': {
'total_nodes': len(nodes_list),
'total_edges': len(edges_list),
'total_memories': len(all_memories),
}
}
})
except Exception as e:
logger.error(f"获取图数据失败: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/memory/<memory_id>')
def get_memory_detail(memory_id: str):
"""
获取特定记忆的详细信息
Args:
memory_id: 记忆ID
"""
try:
if memory_manager is None:
init_memory_manager()
memory = memory_manager.graph_store.get_memory_by_id(memory_id)
if memory is None:
return jsonify({
'success': False,
'error': '记忆不存在'
}), 404
return jsonify({
'success': True,
'data': memory.to_dict()
})
except Exception as e:
logger.error(f"获取记忆详情失败: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/search')
def search_memories():
"""
搜索记忆
Query参数:
- q: 搜索关键词
- type: 记忆类型过滤
- limit: 返回数量限制
"""
try:
if memory_manager is None:
init_memory_manager()
query = request.args.get('q', '')
memory_type = request.args.get('type', None)
limit = int(request.args.get('limit', 50))
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# 执行搜索
results = loop.run_until_complete(
memory_manager.search_memories(
query=query,
top_k=limit
)
)
# 构建返回数据
memories = []
for memory in results:
memories.append({
'id': memory.id,
'text': memory.to_text(),
'type': memory.memory_type.value,
'importance': memory.importance,
'created_at': memory.created_at.isoformat(),
})
return jsonify({
'success': True,
'data': {
'results': memories,
'count': len(memories),
}
})
except Exception as e:
logger.error(f"搜索失败: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/stats')
def get_statistics():
"""
获取记忆图统计信息
"""
try:
if memory_manager is None:
init_memory_manager()
# 获取统计信息
all_memories = memory_manager.graph_store.get_all_memories()
all_nodes = set()
all_edges = 0
for memory in all_memories:
for node in memory.nodes:
all_nodes.add(node.id)
all_edges += len(memory.edges)
stats = {
'total_memories': len(all_memories),
'total_nodes': len(all_nodes),
'total_edges': all_edges,
'node_types': {},
'memory_types': {},
}
# 统计节点类型分布
for memory in all_memories:
mem_type = memory.memory_type.value
stats['memory_types'][mem_type] = stats['memory_types'].get(mem_type, 0) + 1
for node in memory.nodes:
node_type = node.node_type.value
stats['node_types'][node_type] = stats['node_types'].get(node_type, 0) + 1
return jsonify({
'success': True,
'data': stats
})
except Exception as e:
logger.error(f"获取统计信息失败: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/files')
def list_files():
"""
列出所有可用的数据文件
注意: 完整版服务器直接使用内存中的数据,不支持文件切换
"""
try:
from pathlib import Path
data_dir = Path("data/memory_graph")
files = []
if data_dir.exists():
for f in data_dir.glob("*.json"):
stat = f.stat()
files.append({
'path': str(f),
'name': f.name,
'size': stat.st_size,
'size_kb': round(stat.st_size / 1024, 2),
'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
'modified_readable': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S'),
'is_current': True # 完整版始终使用内存数据
})
return jsonify({
'success': True,
'files': files,
'count': len(files),
'current_file': 'memory_manager (实时数据)',
'note': '完整版服务器使用实时内存数据,如需切换文件请使用独立版服务器'
})
except Exception as e:
logger.error(f"获取文件列表失败: {e}", exc_info=True)
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/reload')
def reload_data():
"""
重新加载数据
"""
return jsonify({
'success': True,
'message': '完整版服务器使用实时数据,无需重新加载',
'note': '数据始终是最新的'
})
def run_server(host: str = '127.0.0.1', port: int = 5000, debug: bool = False):
"""
启动可视化服务器
Args:
host: 服务器地址
port: 端口号
debug: 是否开启调试模式
"""
logger.info(f"启动记忆图可视化服务器: http://{host}:{port}")
app.run(host=host, port=port, debug=debug)
if __name__ == '__main__':
run_server(debug=True)
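For reference, the endpoints defined above can be exercised with a tiny client. A sketch using `requests` (assumed to be installed) against the server's default host and port:

```python
import requests

BASE = "http://127.0.0.1:5000"

graph = requests.get(f"{BASE}/api/graph/full", timeout=10).json()
if graph["success"]:
    stats = graph["data"]["stats"]
    print(f"{stats['total_nodes']} nodes, {stats['total_edges']} edges, {stats['total_memories']} memories")

hits = requests.get(f"{BASE}/api/search", params={"q": "学习", "limit": 10}, timeout=10).json()
for memory in hits.get("data", {}).get("results", []):
    print(memory["created_at"], memory["text"])
```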

View File

@@ -1,480 +0,0 @@
"""
记忆图可视化 - 独立版本
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
"""
import orjson
import sys
from pathlib import Path
from datetime import datetime
from typing import Any, Dict, List, Set
from pathlib import Path
from typing import Any, Dict, List, Optional, Set
# 添加项目根目录
project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root))
from flask import Flask, jsonify, render_template_string, request, send_from_directory
from flask_cors import CORS
app = Flask(__name__)
CORS(app)
# 数据缓存
graph_data_cache = None
data_dir = project_root / "data" / "memory_graph"
current_data_file = None # 当前选择的数据文件
def find_available_data_files() -> List[Path]:
"""查找所有可用的记忆图数据文件"""
files = []
if not data_dir.exists():
return files
# 查找多种可能的文件名
possible_files = [
"graph_store.json",
"memory_graph.json",
"graph_data.json",
]
for filename in possible_files:
file_path = data_dir / filename
if file_path.exists():
files.append(file_path)
# 查找所有备份文件
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern):
if backup_file not in files:
files.append(backup_file)
# 查找backups子目录
backups_dir = data_dir / "backups"
if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files:
files.append(backup_file)
# 查找data/backup目录
backup_dir = data_dir.parent / "backup"
if backup_dir.exists():
for backup_file in backup_dir.glob("**/graph_*.json"):
if backup_file not in files:
files.append(backup_file)
for backup_file in backup_dir.glob("**/memory_*.json"):
if backup_file not in files:
files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
"""从磁盘加载图数据"""
global graph_data_cache, current_data_file
# 如果指定了新文件,清除缓存
if file_path is not None and file_path != current_data_file:
graph_data_cache = None
current_data_file = file_path
if graph_data_cache is not None:
return graph_data_cache
try:
# 确定要加载的文件
if current_data_file is not None:
graph_file = current_data_file
else:
# 尝试查找可用的数据文件
available_files = find_available_data_files()
if not available_files:
print(f"⚠️ 未找到任何图数据文件")
print(f"📂 搜索目录: {data_dir}")
return {
"nodes": [],
"edges": [],
"memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": "未找到数据文件",
"available_files": []
}
# 使用最新的文件
graph_file = available_files[0]
current_data_file = graph_file
print(f"📂 自动选择最新文件: {graph_file}")
if not graph_file.exists():
print(f"⚠️ 图数据文件不存在: {graph_file}")
return {
"nodes": [],
"edges": [],
"memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": f"文件不存在: {graph_file}"
}
print(f"📂 加载图数据: {graph_file}")
with open(graph_file, 'r', encoding='utf-8') as f:
data = orjson.loads(f.read())
# 解析数据
nodes_dict = {}
edges_list = []
memory_info = []
# 实际文件格式是 {nodes: [], edges: [], metadata: {}}
# 不是 {memories: [{nodes: [], edges: []}]}
nodes = data.get("nodes", [])
edges = data.get("edges", [])
metadata = data.get("metadata", {})
print(f"✅ 找到 {len(nodes)} 个节点, {len(edges)} 条边")
# 处理节点
for node in nodes:
node_id = node.get('id', '')
if node_id and node_id not in nodes_dict:
memory_ids = node.get('metadata', {}).get('memory_ids', [])
nodes_dict[node_id] = {
'id': node_id,
'label': node.get('content', ''),
'type': node.get('node_type', ''),
'group': extract_group_from_type(node.get('node_type', '')),
'title': f"{node.get('node_type', '')}: {node.get('content', '')}",
'metadata': node.get('metadata', {}),
'created_at': node.get('created_at', ''),
'memory_ids': memory_ids,
}
# 处理边 - 使用集合去重避免重复的边ID
existing_edge_ids = set()
for edge in edges:
# 边的ID字段可能是 'id' 或 'edge_id'
edge_id = edge.get('edge_id') or edge.get('id', '')
# 如果ID为空或已存在跳过这条边
if not edge_id or edge_id in existing_edge_ids:
continue
existing_edge_ids.add(edge_id)
memory_id = edge.get('metadata', {}).get('memory_id', '')
# 注意: GraphStore 保存的格式使用 'source'/'target', 不是 'source_id'/'target_id'
edges_list.append({
'id': edge_id,
'from': edge.get('source', edge.get('source_id', '')),
'to': edge.get('target', edge.get('target_id', '')),
'label': edge.get('relation', ''),
'type': edge.get('edge_type', ''),
'importance': edge.get('importance', 0.5),
'title': f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
'arrows': 'to',
'memory_id': memory_id,
})
# 从元数据中获取统计信息
stats = metadata.get('statistics', {})
total_memories = stats.get('total_memories', 0)
# TODO: 如果需要记忆详细信息,需要从其他地方加载
# 目前只有节点和边的数据
graph_data_cache = {
'nodes': list(nodes_dict.values()),
'edges': edges_list,
'memories': memory_info, # 空列表,因为文件中没有记忆详情
'stats': {
'total_nodes': len(nodes_dict),
'total_edges': len(edges_list),
'total_memories': total_memories,
},
'current_file': str(graph_file),
'file_size': graph_file.stat().st_size,
'file_modified': datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
}
print(f"📊 统计: {len(nodes_dict)} 个节点, {len(edges_list)} 条边, {total_memories} 条记忆")
print(f"📄 数据文件: {graph_file} ({graph_file.stat().st_size / 1024:.2f} KB)")
return graph_data_cache
except Exception as e:
print(f"❌ 加载失败: {e}")
import traceback
traceback.print_exc()
return {"nodes": [], "edges": [], "memories": [], "stats": {}}
def extract_group_from_type(node_type: str) -> str:
"""从节点类型提取分组名"""
# 假设类型格式为 "主体" 或 "SUBJECT"
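    # 例如 extract_group_from_type("主体") -> "SUBJECT"未映射的类型原样返回
    # 见下方 .get(node_type, node_type) 的回退逻辑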
type_mapping = {
'主体': 'SUBJECT',
'主题': 'TOPIC',
'客体': 'OBJECT',
'属性': 'ATTRIBUTE',
        '值': 'VALUE',
}
return type_mapping.get(node_type, node_type)
def generate_memory_text(memory: Dict[str, Any]) -> str:
"""生成记忆的文本描述"""
try:
nodes = {n['id']: n for n in memory.get('nodes', [])}
edges = memory.get('edges', [])
subject_id = memory.get('subject_id', '')
if not subject_id or subject_id not in nodes:
return f"[记忆 {memory.get('id', '')[:8]}]"
parts = [nodes[subject_id]['content']]
# 找主题节点
for edge in edges:
if edge.get('edge_type') == '记忆类型' and edge.get('source_id') == subject_id:
topic_id = edge.get('target_id', '')
if topic_id in nodes:
parts.append(nodes[topic_id]['content'])
# 找客体
for e2 in edges:
if e2.get('edge_type') == '核心关系' and e2.get('source_id') == topic_id:
obj_id = e2.get('target_id', '')
if obj_id in nodes:
parts.append(f"{e2.get('relation', '')} {nodes[obj_id]['content']}")
break
break
return " ".join(parts)
except Exception:
return f"[记忆 {memory.get('id', '')[:8]}]"
# 从模板文件读取 HTML 模板使用 read_text 避免遗留未关闭的文件句柄
HTML_TEMPLATE = (project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html").read_text(encoding='utf-8')
@app.route('/')
def index():
"""主页面"""
return render_template_string(HTML_TEMPLATE)
@app.route('/api/graph/full')
def get_full_graph():
"""获取完整记忆图数据"""
try:
data = load_graph_data()
return jsonify({
'success': True,
'data': data
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/memory/<memory_id>')
def get_memory_detail(memory_id: str):
"""获取记忆详情"""
try:
data = load_graph_data()
memory = next((m for m in data['memories'] if m['id'] == memory_id), None)
if memory is None:
return jsonify({
'success': False,
'error': '记忆不存在'
}), 404
return jsonify({
'success': True,
'data': memory
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/search')
def search_memories():
"""搜索记忆"""
try:
query = request.args.get('q', '').lower()
limit = int(request.args.get('limit', 50))
data = load_graph_data()
# 简单的文本匹配搜索
results = []
for memory in data['memories']:
text = memory.get('text', '').lower()
if query in text:
results.append(memory)
return jsonify({
'success': True,
'data': {
'results': results[:limit],
'count': len(results),
}
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/stats')
def get_statistics():
"""获取统计信息"""
try:
data = load_graph_data()
# 扩展统计信息
node_types = {}
memory_types = {}
for node in data['nodes']:
node_type = node.get('type', 'Unknown')
node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data['memories']:
mem_type = memory.get('type', 'Unknown')
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get('stats', {})
stats['node_types'] = node_types
stats['memory_types'] = memory_types
return jsonify({
'success': True,
'data': stats
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/reload')
def reload_data():
"""重新加载数据"""
global graph_data_cache
graph_data_cache = None
data = load_graph_data()
return jsonify({
'success': True,
'message': '数据已重新加载',
'stats': data.get('stats', {})
})
@app.route('/api/files')
def list_files():
"""列出所有可用的数据文件"""
try:
files = find_available_data_files()
file_list = []
for f in files:
stat = f.stat()
file_list.append({
'path': str(f),
'name': f.name,
'size': stat.st_size,
'size_kb': round(stat.st_size / 1024, 2),
'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
'modified_readable': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S'),
'is_current': str(f) == str(current_data_file) if current_data_file else False
})
return jsonify({
'success': True,
'files': file_list,
'count': len(file_list),
'current_file': str(current_data_file) if current_data_file else None
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
@app.route('/api/select_file', methods=['POST'])
def select_file():
"""选择要加载的数据文件"""
global graph_data_cache, current_data_file
try:
data = request.get_json()
file_path = data.get('file_path')
if not file_path:
return jsonify({
'success': False,
'error': '未提供文件路径'
}), 400
file_path = Path(file_path)
if not file_path.exists():
return jsonify({
'success': False,
'error': f'文件不存在: {file_path}'
}), 404
# 清除缓存并加载新文件
graph_data_cache = None
current_data_file = file_path
graph_data = load_graph_data(file_path)
return jsonify({
'success': True,
'message': f'已切换到文件: {file_path.name}',
'stats': graph_data.get('stats', {})
})
except Exception as e:
return jsonify({
'success': False,
'error': str(e)
}), 500
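# 接口调用示例仅作示意端点路径来自上方路由host/port 为 run_server() 的默认值
#   curl http://127.0.0.1:5001/api/graph/full
#   curl "http://127.0.0.1:5001/api/search?q=关键词&limit=10"
#   curl http://127.0.0.1:5001/api/files
#   curl -X POST http://127.0.0.1:5001/api/select_file \
#        -H "Content-Type: application/json" \
#        -d '{"file_path": "data/memory_graph/graph_store.json"}'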
def run_server(host: str = '127.0.0.1', port: int = 5001, debug: bool = False):
"""启动服务器"""
print("=" * 60)
print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
print("=" * 60)
print(f"📂 数据目录: {data_dir}")
print(f"🌐 访问地址: http://{host}:{port}")
print("⏹️ 按 Ctrl+C 停止服务器")
print("=" * 60)
print()
# 预加载数据
load_graph_data()
app.run(host=host, port=port, debug=debug)
if __name__ == '__main__':
try:
run_server(debug=True)
except KeyboardInterrupt:
print("\n\n👋 服务器已停止")
except Exception as e:
print(f"\n❌ 启动失败: {e}")
sys.exit(1)

View File

@@ -1,16 +0,0 @@
#!/usr/bin/env pwsh
# ======================================================================
# 记忆图可视化工具 - 快捷启动脚本
# ======================================================================
# 此脚本是快捷方式,实际脚本位于 tools/memory_visualizer/ 目录
# ======================================================================
$visualizerScript = Join-Path $PSScriptRoot "tools\memory_visualizer\visualizer.ps1"
if (Test-Path $visualizerScript) {
& $visualizerScript @args
} else {
Write-Host "❌ 错误:找不到可视化工具脚本" -ForegroundColor Red
Write-Host " 预期位置: $visualizerScript" -ForegroundColor Yellow
exit 1
}