明天好像没什么
2025-11-07 21:01:45 +08:00
parent 80b040da2f
commit c8d7c09625
49 changed files with 854 additions and 872 deletions

View File

@@ -17,19 +17,19 @@
 import argparse
 import json
-from pathlib import Path
-from typing import Dict, Any, List, Tuple
 import logging
+from pathlib import Path
+from typing import Any

 import orjson

 # 配置日志
 logging.basicConfig(
     level=logging.INFO,
-    format='%(asctime)s - %(levelname)s - %(message)s',
+    format="%(asctime)s - %(levelname)s - %(message)s",
     handlers=[
         logging.StreamHandler(),
-        logging.FileHandler('embedding_cleanup.log', encoding='utf-8')
+        logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
     ]
 )
 logger = logging.getLogger(__name__)

@@ -49,13 +49,13 @@ class EmbeddingCleaner:
         self.cleaned_files = []
         self.errors = []
         self.stats = {
-            'files_processed': 0,
-            'embedings_removed': 0,
-            'bytes_saved': 0,
-            'nodes_processed': 0
+            "files_processed": 0,
+            "embedings_removed": 0,
+            "bytes_saved": 0,
+            "nodes_processed": 0
         }

-    def find_json_files(self) -> List[Path]:
+    def find_json_files(self) -> list[Path]:
         """查找可能包含向量数据的 JSON 文件"""
         json_files = []

@@ -65,7 +65,7 @@ class EmbeddingCleaner:
             json_files.append(memory_graph_file)

         # 测试数据文件
-        test_dir = self.data_dir / "test_*"
+        self.data_dir / "test_*"
         for test_path in self.data_dir.glob("test_*/memory_graph.json"):
             if test_path.exists():
                 json_files.append(test_path)

@@ -82,7 +82,7 @@ class EmbeddingCleaner:
         logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
         return json_files

-    def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int:
+    def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
         """
         分析数据中的 embedding 字段数量

@@ -97,7 +97,7 @@ class EmbeddingCleaner:
         def count_embeddings(obj):
             nonlocal embedding_count
             if isinstance(obj, dict):
-                if 'embedding' in obj:
+                if "embedding" in obj:
                     embedding_count += 1
                 for value in obj.values():
                     count_embeddings(value)

@@ -108,7 +108,7 @@ class EmbeddingCleaner:
         count_embeddings(data)
         return embedding_count

-    def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]:
+    def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
         """
         从数据中移除 embedding 字段

@@ -123,8 +123,8 @@ class EmbeddingCleaner:
         def remove_embeddings(obj):
             nonlocal removed_count
             if isinstance(obj, dict):
-                if 'embedding' in obj:
-                    del obj['embedding']
+                if "embedding" in obj:
+                    del obj["embedding"]
                     removed_count += 1
                 for value in obj.values():
                     remove_embeddings(value)

@@ -162,14 +162,14 @@ class EmbeddingCleaner:
             data = orjson.loads(original_content)
         except orjson.JSONDecodeError:
             # 回退到标准 json
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, encoding="utf-8") as f:
                 data = json.load(f)

         # 分析 embedding 数据
         embedding_count = self.analyze_embedding_in_data(data)
         if embedding_count == 0:
-            logger.info(f" ✓ 文件中没有 embedding 数据,跳过")
+            logger.info(" ✓ 文件中没有 embedding 数据,跳过")
             return True

         logger.info(f" 发现 {embedding_count} 个 embedding 字段")

@@ -193,30 +193,30 @@ class EmbeddingCleaner:
                 cleaned_data,
                 indent=2,
                 ensure_ascii=False
-            ).encode('utf-8')
+            ).encode("utf-8")
             cleaned_size = len(cleaned_content)
             bytes_saved = original_size - cleaned_size

             # 原子写入
-            temp_file = file_path.with_suffix('.tmp')
+            temp_file = file_path.with_suffix(".tmp")
             temp_file.write_bytes(cleaned_content)
             temp_file.replace(file_path)

-            logger.info(f" ✓ 清理完成:")
+            logger.info(" ✓ 清理完成:")
             logger.info(f" - 移除 embedding 字段: {removed_count}")
             logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
             logger.info(f" - 新文件大小: {cleaned_size:,} 字节")

             # 更新统计
-            self.stats['embedings_removed'] += removed_count
-            self.stats['bytes_saved'] += bytes_saved
+            self.stats["embedings_removed"] += removed_count
+            self.stats["bytes_saved"] += bytes_saved
         else:
             logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
-            self.stats['embedings_removed'] += embedding_count
+            self.stats["embedings_removed"] += embedding_count

-        self.stats['files_processed'] += 1
+        self.stats["files_processed"] += 1
         self.cleaned_files.append(file_path)
         return True

@@ -236,12 +236,12 @@ class EmbeddingCleaner:
            节点数量
         """
         try:
-            with open(file_path, 'r', encoding='utf-8') as f:
+            with open(file_path, encoding="utf-8") as f:
                 data = json.load(f)

             node_count = 0
-            if 'nodes' in data and isinstance(data['nodes'], list):
-                node_count = len(data['nodes'])
+            if "nodes" in data and isinstance(data["nodes"], list):
+                node_count = len(data["nodes"])

             return node_count

@@ -268,7 +268,7 @@ class EmbeddingCleaner:
         # 统计总节点数
         total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
-        self.stats['nodes_processed'] = total_nodes
+        self.stats["nodes_processed"] = total_nodes

         logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")

@@ -295,8 +295,8 @@ class EmbeddingCleaner:
         if not dry_run:
             logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
-            if self.stats['bytes_saved'] > 0:
-                mb_saved = self.stats['bytes_saved'] / 1024 / 1024
+            if self.stats["bytes_saved"] > 0:
+                mb_saved = self.stats["bytes_saved"] / 1024 / 1024
                 logger.info(f"节省空间: {mb_saved:.2f} MB")

         if self.errors:

@@ -342,7 +342,7 @@ def main():
     print(" 请确保向量数据库正在正常工作。")
     print()
     response = input("确认继续?(yes/no): ")
-    if response.lower() not in ['yes', 'y', '']:
+    if response.lower() not in ["yes", "y", ""]:
         print("操作已取消")
         return
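
The script's core pattern — recurse through the parsed JSON deleting every `embedding` key, then write to a temp file and `replace()` it over the original — distilled into a minimal, self-contained sketch (the helper names here are illustrative, not the script's own):

    import json
    from pathlib import Path
    from typing import Any


    def strip_key(obj: Any, key: str = "embedding") -> int:
        """Recursively delete `key` from nested dicts/lists in place; return count removed."""
        removed = 0
        if isinstance(obj, dict):
            if key in obj:
                del obj[key]
                removed += 1
            for value in obj.values():
                removed += strip_key(value, key)
        elif isinstance(obj, list):
            for item in obj:
                removed += strip_key(item, key)
        return removed


    def clean_file(path: Path) -> int:
        data = json.loads(path.read_text(encoding="utf-8"))
        removed = strip_key(data)
        tmp = path.with_suffix(".tmp")
        tmp.write_text(json.dumps(data, ensure_ascii=False, indent=2), encoding="utf-8")
        tmp.replace(path)  # atomic rename, so a crash never leaves a half-written file
        return removed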

View File

@@ -22,7 +22,6 @@ import argparse
 import asyncio
 import gc
 import json
-import os
 import subprocess
 import sys
 import threading

@@ -30,7 +29,6 @@ import time
 from collections import defaultdict
 from datetime import datetime
 from pathlib import Path
-from typing import Dict, List, Optional

 import psutil

@@ -101,10 +99,10 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
             if children:
                 print(f" 子进程: {info['children_count']}")
                 print(f" 子进程内存: {info['children_mem_mb']:.2f} MB")
-                total_mem = info['rss_mb'] + info['children_mem_mb']
+                total_mem = info["rss_mb"] + info["children_mem_mb"]
                 print(f" 总内存: {total_mem:.2f} MB")

-                print(f"\n 📋 子进程详情:")
+                print("\n 📋 子进程详情:")
                 for idx, child in enumerate(children, 1):
                     try:
                         child_mem = child.memory_info().rss / 1024 / 1024

@@ -119,13 +117,13 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
             if len(history) > 1:
                 prev = history[-2]
-                rss_diff = info['rss_mb'] - prev['rss_mb']
-                print(f"\n变化:")
+                rss_diff = info["rss_mb"] - prev["rss_mb"]
+                print("\n变化:")
                 print(f" RSS: {rss_diff:+.2f} MB")
                 if rss_diff > 10:
-                    print(f" ⚠️ 内存增长较快!")
-                if info['rss_mb'] > 1000:
-                    print(f" ⚠️ 内存使用超过 1GB")
+                    print(" ⚠️ 内存增长较快!")
+                if info["rss_mb"] > 1000:
+                    print(" ⚠️ 内存使用超过 1GB")

             print(f"{'=' * 80}\n")
             await asyncio.sleep(interval)

@@ -163,7 +161,7 @@ def save_process_history(history: list, pid: int):
             f.write(f"RSS: {info['rss_mb']:.2f} MB\n")
             f.write(f"VMS: {info['vms_mb']:.2f} MB\n")
             f.write(f"占比: {info['percent']:.2f}%\n")
-            if info['children_count'] > 0:
+            if info["children_count"] > 0:
                 f.write(f"子进程: {info['children_count']}\n")
                 f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n")
             f.write("\n")

@@ -264,7 +262,7 @@ async def run_monitor_mode(interval: int):
 class ObjectMemoryProfiler:
     """对象级内存分析器"""

-    def __init__(self, interval: int = 10, output_file: Optional[str] = None, object_limit: int = 20):
+    def __init__(self, interval: int = 10, output_file: str | None = None, object_limit: int = 20):
         self.interval = interval
         self.output_file = output_file
         self.object_limit = object_limit

@@ -274,7 +272,7 @@ class ObjectMemoryProfiler:
         self.tracker = tracker.SummaryTracker()
         self.iteration = 0

-    def get_object_stats(self) -> Dict:
+    def get_object_stats(self) -> dict:
         """获取当前进程的对象统计(所有线程)"""
         if not PYMPLER_AVAILABLE:
             return {}

@@ -317,7 +315,7 @@
             print(f"❌ 获取对象统计失败: {e}")
             return {}

-    def _get_module_stats(self, all_objects: list) -> Dict:
+    def _get_module_stats(self, all_objects: list) -> dict:
         """统计各模块的内存占用"""
         module_mem = defaultdict(lambda: {"count": 0, "size": 0})

@@ -329,7 +327,7 @@
                 if module_name:
                     # 获取顶级模块名(例如 src.chat.xxx -> src
-                    top_module = module_name.split('.')[0]
+                    top_module = module_name.split(".")[0]

                     obj_size = sys.getsizeof(obj)
                     module_mem[top_module]["count"] += 1

@@ -351,7 +349,7 @@
             "total_modules": len(module_mem)
         }

-    def print_stats(self, stats: Dict, iteration: int):
+    def print_stats(self, stats: dict, iteration: int):
         """打印统计信息"""
         print("\n" + "=" * 80)
         print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}")

@@ -374,8 +372,8 @@
             print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}")

-        if "module_stats" in stats and stats["module_stats"]:
-            print(f"\n📚 模块内存占用 (前 20 个模块):\n")
+        if stats.get("module_stats"):
+            print("\n📚 模块内存占用 (前 20 个模块):\n")
             print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}")
             print("-" * 80)

@@ -402,7 +400,7 @@
         if "gc_stats" in stats:
             gc_stats = stats["gc_stats"]
-            print(f"\n🗑️ 垃圾回收:")
+            print("\n🗑️ 垃圾回收:")
             print(f" 代 0: {gc_stats['collections'][0]:,}")
             print(f" 代 1: {gc_stats['collections'][1]:,}")
             print(f" 代 2: {gc_stats['collections'][2]:,}")

@@ -423,7 +421,7 @@
             self.tracker.print_diff()
             print("-" * 80)

-    def save_to_file(self, stats: Dict):
+    def save_to_file(self, stats: dict):
         """保存统计信息到文件"""
         if not self.output_file:
             return

@@ -441,7 +439,7 @@
             for obj_type, obj_count, obj_size in stats["summary"]:
                 f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n")

-            if "module_stats" in stats and stats["module_stats"]:
+            if stats.get("module_stats"):
                 f.write("\n模块统计 (前 20 个):\n")
                 for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
                     f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n")

@@ -452,7 +450,7 @@
         # 保存 JSONL
         jsonl_path = str(self.output_file) + ".jsonl"
         record = {
-            "timestamp": time.strftime('%Y-%m-%d %H:%M:%S'),
+            "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
             "iteration": self.iteration,
             "total_objects": stats.get("total_objects", 0),
             "threads": stats.get("threads", []),

@@ -479,7 +477,7 @@
         self.running = True

         def monitor_loop():
-            print(f"🚀 对象分析器已启动")
+            print("🚀 对象分析器已启动")
             print(f" 监控间隔: {self.interval}")
             print(f" 对象类型限制: {self.object_limit}")
             print(f" 输出文件: {self.output_file or ''}")

@@ -506,14 +504,14 @@
         monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
         monitor_thread.start()
-        print(f"✓ 监控线程已启动\n")
+        print("✓ 监控线程已启动\n")

     def stop(self):
         """停止监控"""
         self.running = False


-def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
+def run_objects_mode(interval: int, output: str | None, object_limit: int):
     """对象分析模式主函数"""
     if not PYMPLER_AVAILABLE:
         print("❌ pympler 未安装,无法使用对象分析模式")

@@ -549,9 +547,9 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
     try:
         import bot

-        if hasattr(bot, 'main_async'):
+        if hasattr(bot, "main_async"):
             asyncio.run(bot.main_async())
-        elif hasattr(bot, 'main'):
+        elif hasattr(bot, "main"):
             bot.main()
         else:
             print("⚠️ bot.py 未找到 main_async() 或 main() 函数")

@@ -577,10 +575,10 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
 # 可视化模式
 # ============================================================================

-def load_jsonl(path: Path) -> List[Dict]:
+def load_jsonl(path: Path) -> list[dict]:
     """加载 JSONL 文件"""
     snapshots = []
-    with open(path, "r", encoding="utf-8") as f:
+    with open(path, encoding="utf-8") as f:
         for line in f:
             line = line.strip()
             if not line:

@@ -592,7 +590,7 @@ def load_jsonl(path: Path) -> List[Dict]:
     return snapshots

-def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
+def aggregate_top_types(snapshots: list[dict], top_n: int = 10):
     """聚合前 N 个对象类型的时间序列"""
     type_max = defaultdict(int)
     for snap in snapshots:

@@ -622,7 +620,7 @@ def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
     return times, series

-def plot_series(times: List, series: Dict, output: Path, top_n: int):
+def plot_series(times: list, series: dict, output: Path, top_n: int):
     """绘制时间序列图"""
     plt.figure(figsize=(14, 8))
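
For context, this profiler's object tracking rides on pympler's `SummaryTracker`; the minimal usage pattern (standalone, assuming `pympler` is installed) looks like:

    from pympler import tracker

    tr = tracker.SummaryTracker()           # snapshot of live objects at creation time
    leak = [bytearray(1024) for _ in range(1000)]
    tr.print_diff()                         # per-type object/size deltas since the snapshot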

View File

@@ -1,5 +1,4 @@
 import os
-import random
 import time
 from datetime import datetime
 from typing import Any

View File

@@ -127,7 +127,7 @@ class StyleLearner:
         # 最后使用时间(越近越好)
         last_used = self.learning_stats["style_last_used"].get(style_id, 0)
-        time_since_used = current_time - last_used if last_used > 0 else float('inf')
+        time_since_used = current_time - last_used if last_used > 0 else float("inf")

         # 综合分数:使用次数越多越好,距离上次使用时间越短越好
         # 使用对数来平滑使用次数的影响

View File

@@ -8,7 +8,6 @@ from datetime import datetime
 from typing import Any

 import numpy as np
-import orjson
 from sqlalchemy import select

 from src.common.config_helpers import resolve_embedding_dimension

@@ -605,7 +604,7 @@ class BotInterestManager:
             )

         # 如果有新生成的扩展embedding保存到缓存文件
-        if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated:
+        if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
             await self._save_embedding_cache_to_file(self.current_interests.personality_id)
             self._new_expanded_embeddings_generated = False
             logger.debug("💾 已保存新生成的扩展embedding到缓存文件")

@@ -670,59 +669,59 @@ class BotInterestManager:
         tag_lower = tag_name.lower()

         # 技术编程类标签(具体化描述)
-        if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']):
-            if 'python' in tag_lower:
-                return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
-            elif '算法' in tag_lower:
-                return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
-            elif '代码' in tag_lower or '被窝' in tag_lower:
-                return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧"
+        if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
+            if "python" in tag_lower:
+                return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
+            elif "算法" in tag_lower:
+                return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
+            elif "代码" in tag_lower or "被窝" in tag_lower:
+                return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
             else:
-                return f"讨论编程开发、软件技术、代码编写、技术实现"
+                return "讨论编程开发、软件技术、代码编写、技术实现"
         # 情感表达类标签(具体化为真实对话场景)
-        elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '', '卖萌']):
-            return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
+        elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "", "卖萌"]):
+            return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
         # 游戏娱乐类标签(具体游戏场景)
-        elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '', '']):
-            return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
+        elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "", ""]):
+            return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
         # 动漫影视类标签(具体观看行为)
-        elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']):
+        elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
             # 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西"
-            if '' in tag_lower or '新番' in tag_lower:
-                return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
+            if "" in tag_lower or "新番" in tag_lower:
+                return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
             else:
-                return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
+                return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
         # 社交平台类标签(具体平台行为)
-        elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']):
-            if '吃瓜' in tag_lower:
-                return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
+        elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
+            if "吃瓜" in tag_lower:
+                return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
             else:
-                return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
+                return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
         # 生活日常类标签(具体萌宠场景)
-        elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']):
-            return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
+        elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
+            return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
         # 状态心情类标签(具体情绪状态)
-        elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']):
-            if '社恐' in tag_lower:
-                return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
-            elif '深夜' in tag_lower:
-                return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
+        elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
+            if "社恐" in tag_lower:
+                return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
+            elif "深夜" in tag_lower:
+                return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
             else:
-                return f"表达当前心情状态、个人感受、生活状态"
+                return "表达当前心情状态、个人感受、生活状态"
         # 物品装备类标签(具体使用场景)
-        elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']):
-            return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
+        elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
+            return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
         # 互动关系类标签
-        elif any(word in tag_lower for word in ['拾风', '互怼', '互动']):
-            return f"聊天互动、开玩笑、友好互怼、日常对话交流"
+        elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
+            return "聊天互动、开玩笑、友好互怼、日常对话交流"
         # 默认:尽量具体化
         else:

@@ -1011,9 +1010,10 @@ class BotInterestManager:
     async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
         """从文件加载embedding缓存"""
         try:
-            import orjson
             from pathlib import Path

+            import orjson
+
             cache_dir = Path("data/embedding")
             cache_dir.mkdir(parents=True, exist_ok=True)
             cache_file = cache_dir / f"{personality_id}_embeddings.json"

@@ -1053,9 +1053,10 @@ class BotInterestManager:
     async def _save_embedding_cache_to_file(self, personality_id: str):
         """保存embedding缓存到文件包括扩展标签的embedding"""
         try:
-            import orjson
-            from pathlib import Path
             from datetime import datetime
+            from pathlib import Path
+
+            import orjson

             cache_dir = Path("data/embedding")
             cache_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
 __all__ = [
     "MessageManager",
-    "SingleStreamContextManager",
     "SchedulerDispatcher",
+    "SingleStreamContextManager",
     "message_manager",
     "scheduler_dispatcher",
 ]

View File

@@ -73,7 +73,7 @@ class SingleStreamContextManager:
         cache_enabled = global_config.chat.enable_message_cache
         use_cache_system = message_manager.is_running and cache_enabled
         if not cache_enabled:
-            logger.debug(f"消息缓存系统已在配置中禁用")
+            logger.debug("消息缓存系统已在配置中禁用")
     except Exception as e:
         logger.debug(f"MessageManager不可用使用直接添加: {e}")
         use_cache_system = False

View File

@@ -4,13 +4,11 @@
 """

 import asyncio
-import random
 import time
 from collections import defaultdict, deque
 from typing import TYPE_CHECKING, Any

 from src.chat.chatter_manager import ChatterManager
-from src.chat.message_receive.chat_stream import ChatStream
 from src.chat.planner_actions.action_manager import ChatterActionManager
 from src.common.data_models.database_data_model import DatabaseMessages
 from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats

View File

@@ -367,7 +367,7 @@ class ChatBot:
         message_segment = message_data.get("message_segment")
         if message_segment and isinstance(message_segment, dict):
             if message_segment.get("type") == "adapter_response":
-                logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理")
+                logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
                 await self._handle_adapter_response_from_dict(message_segment.get("data"))
                 return

View File

@@ -557,7 +557,7 @@ class DefaultReplyer:
         participants = []
         try:
             # 尝试从聊天流中获取参与者信息
-            if hasattr(stream, 'chat_history_manager'):
+            if hasattr(stream, "chat_history_manager"):
                 history_manager = stream.chat_history_manager
                 # 获取最近的参与者列表
                 recent_records = history_manager.get_memory_chat_history(

@@ -586,9 +586,9 @@ class DefaultReplyer:
         formatted_history = ""
         if chat_history:
             # 移除过长的历史记录,只保留最近部分
-            lines = chat_history.strip().split('\n')
+            lines = chat_history.strip().split("\n")
             recent_lines = lines[-10:] if len(lines) > 10 else lines
-            formatted_history = '\n'.join(recent_lines)
+            formatted_history = "\n".join(recent_lines)

         query_context = {
             "chat_history": formatted_history,

@@ -725,7 +725,7 @@ class DefaultReplyer:
         for tool_result in tool_results:
             tool_name = tool_result.get("tool_name", "unknown")
             content = tool_result.get("content", "")
-            result_type = tool_result.get("type", "tool_result")
+            tool_result.get("type", "tool_result")

             # 不进行截断,让工具自己处理结果长度
             current_results_parts.append(f"- **{tool_name}**: {content}")

@@ -1913,7 +1913,6 @@ class DefaultReplyer:
             return

         # 使用统一记忆系统存储记忆
-        from src.chat.memory_system import get_memory_system

         stream = self.chat_stream
         user_info_obj = getattr(stream, "user_info", None)

@@ -2036,7 +2035,7 @@ class DefaultReplyer:
             timestamp=time.time(),
             limit=int(global_config.chat.max_context_size),
         )
-        chat_history = await build_readable_messages(
+        await build_readable_messages(
             message_list_before_short,
             replace_bot_name=True,
             merge_messages=False,
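
The history-trimming hunk above keeps only the tail of the rendered chat log; in isolation the idiom is:

    chat_history = "\n".join(f"message {i}" for i in range(25))   # stand-in history
    lines = chat_history.strip().split("\n")
    recent_lines = lines[-10:] if len(lines) > 10 else lines      # at most the last 10 lines
    formatted_history = "\n".join(recent_lines)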

View File

@@ -91,13 +91,11 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
     # 1. 判断是否为私聊(强提及)
     group_info = getattr(message, "group_info", None)
     if not group_info or not getattr(group_info, "group_id", None):
-        is_private = True
         mention_type = 2
         logger.debug("检测到私聊消息 - 强提及")

     # 2. 判断是否被@(强提及)
     if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
-        is_at = True
         mention_type = 2
         logger.debug("检测到@提及 - 强提及")

@@ -108,7 +106,6 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
         rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
         processed_text,
     ):
-        is_replied = True
         mention_type = 2
         logger.debug("检测到回复消息 - 强提及")

@@ -122,14 +119,12 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
     # 检查bot主名字
     if global_config.bot.nickname in message_content:
-        is_text_mentioned = True
         mention_type = 1
         logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
     # 如果主名字没匹配,再检查别名
     elif nicknames:
         for alias_name in nicknames:
             if alias_name in message_content:
-                is_text_mentioned = True
                 mention_type = 1
                 logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
                 break
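
The strong-mention check here is a single regex over the preprocessed text; a standalone sketch with an illustrative account value:

    import re

    qq_account = "123456"  # illustrative; the real value comes from global_config.bot.qq_account
    processed_text = "在吗 @<MoFox:123456> 看一下这个"

    mention_type = 0
    if re.search(rf"@<(.+?):{qq_account}>", processed_text):
        mention_type = 2  # strong mention, as in the hunk above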

View File

@@ -374,7 +374,7 @@ class CacheManager:
         # 简化的健康统计,不包含内存监控(因为相关属性未定义)
         return {
             "l1_count": len(self.l1_kv_cache),
-            "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0,
+            "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
             "tool_stats": {
                 "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
                 "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),

@@ -397,7 +397,7 @@ class CacheManager:
             warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")

         # 检查向量索引大小
-        vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0
+        vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
         if isinstance(vector_count, int) and vector_count > 500:
             warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -539,7 +539,6 @@ class AdaptiveBatchScheduler:
     def _set_cache(self, cache_key: str, result: Any) -> None:
         """设置缓存(改进版,带大小限制和内存统计)"""
-        import sys

         # 🔧 检查缓存大小限制
         if len(self._result_cache) >= self._cache_max_size:

View File

@@ -4,9 +4,10 @@
 提供比 sys.getsizeof() 更准确的内存占用估算方法
 """

-import sys
 import pickle
+import sys
 from typing import Any

 import numpy as np

@@ -44,15 +45,15 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int:
                    for k, v in obj.items())

     # 列表、元组、集合:递归计算所有元素
-    elif isinstance(obj, (list, tuple, set, frozenset)):
+    elif isinstance(obj, list | tuple | set | frozenset):
         size += sum(get_accurate_size(item, seen) for item in obj)

     # 有 __dict__ 的对象:递归计算属性
-    elif hasattr(obj, '__dict__'):
+    elif hasattr(obj, "__dict__"):
         size += get_accurate_size(obj.__dict__, seen)

     # 其他可迭代对象
-    elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)):
+    elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
         try:
             size += sum(get_accurate_size(item, seen) for item in obj)
         except:

@@ -116,7 +117,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
     size = sys.getsizeof(obj)

     # 简单类型直接返回
-    if isinstance(obj, (int, float, bool, type(None), str, bytes, bytearray)):
+    if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
         return size

     # NumPy 数组特殊处理

@@ -144,7 +145,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
         return size

     # 列表、元组、集合递归
-    if isinstance(obj, (list, tuple, set, frozenset)):
+    if isinstance(obj, list | tuple | set | frozenset):
         items = list(obj)
         if sample_large and len(items) > 100:
             # 大容器采样前50 + 中间50 + 最后50

@@ -162,7 +163,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
         return size

     # 有 __dict__ 的对象
-    if hasattr(obj, '__dict__'):
+    if hasattr(obj, "__dict__"):
         size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
         return size
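
The isinstance rewrites in this file rely on Python 3.10+ accepting PEP 604 unions as the second argument; the two spellings are interchangeable:

    value = [1, 2, 3]
    assert isinstance(value, list | tuple | set | frozenset)      # PEP 604 union (3.10+)
    assert isinstance(value, (list, tuple, set, frozenset))       # classic tuple form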

View File

@@ -2,7 +2,6 @@ import os
 import shutil
 import sys
 from datetime import datetime
-from typing import Optional

 import tomlkit
 from pydantic import Field

@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
     notice: NoticeConfig = Field(..., description="Notice消息配置")
     emoji: EmojiConfig = Field(..., description="表情配置")
     expression: ExpressionConfig = Field(..., description="表达配置")
-    memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置")
+    memory: MemoryConfig | None = Field(default=None, description="记忆配置")
     mood: MoodConfig = Field(..., description="情绪配置")
     reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
     chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")
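
The `Optional[MemoryConfig]` → `MemoryConfig | None` change is purely notational for pydantic; a minimal sketch (a plain dict stands in for `MemoryConfig` to keep it self-contained):

    from pydantic import BaseModel, Field


    class SketchConfig(BaseModel):
        # Optional[X] and X | None denote the same type; default=None keeps the field optional.
        memory: dict | None = Field(default=None, description="记忆配置")


    print(SketchConfig().memory)  # None until explicitly configured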

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
         model_name = model_info.name
         retry_interval = api_provider.retry_interval

-        if isinstance(e, (NetworkConnectionError, ReqAbortException)):
+        if isinstance(e, NetworkConnectionError | ReqAbortException):
             return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
         elif isinstance(e, RespNotOkException):
             return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)

View File

@@ -100,10 +100,10 @@ class VectorStore:
             # 处理额外的元数据,将 list 转换为 JSON 字符串
             for key, value in node.metadata.items():
-                if isinstance(value, (list, dict)):
+                if isinstance(value, list | dict):
                     import orjson
                     metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
-                elif isinstance(value, (str, int, float, bool)) or value is None:
+                elif isinstance(value, str | int | float | bool) or value is None:
                     metadata[key] = value
                 else:
                     metadata[key] = str(value)

@@ -149,9 +149,9 @@ class VectorStore:
                     "created_at": n.created_at.isoformat(),
                 }
                 for key, value in n.metadata.items():
-                    if isinstance(value, (list, dict)):
+                    if isinstance(value, list | dict):
                         metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
-                    elif isinstance(value, (str, int, float, bool)) or value is None:
+                    elif isinstance(value, str | int | float | bool) or value is None:
                         metadata[key] = value  # type: ignore
                     else:
                         metadata[key] = str(value)
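
Both hunks apply the same per-value rule when flattening node metadata; factored into a hypothetical helper it reads:

    import orjson


    def sanitize_metadata_value(value):
        """Coerce an arbitrary metadata value into a vector-store-safe scalar."""
        if isinstance(value, list | dict):
            return orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
        if isinstance(value, str | int | float | bool) or value is None:
            return value
        return str(value)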

View File

@@ -4,8 +4,6 @@
 from __future__ import annotations

-import asyncio
-
 import numpy as np

 from src.common.logger import get_logger

View File

@@ -7,11 +7,12 @@
 """

 import atexit
-import orjson
 import os
 import threading
 from typing import Any, ClassVar

+import orjson
+
 from src.common.logger import get_logger

 # 获取日志记录器

@@ -125,7 +126,7 @@ class PluginStorage:
         try:
             with open(self.file_path, "w", encoding="utf-8") as f:
-                f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8'))
+                f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
             self._dirty = False  # 保存后重置标志
             logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
         except Exception as e:
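
A note on the `.decode("utf-8")` here: orjson always serializes to `bytes`, so a text-mode handle needs the decode, while a binary handle can skip it — a small sketch:

    import orjson

    data = {"count": 1, "标签": ["a", "b"]}
    raw = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)  # bytes

    with open("plugin_data.json", "wb") as f:   # binary write avoids the decode round-trip
        f.write(raw)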

View File

@@ -5,12 +5,12 @@ MCP Client Manager
 """

 import asyncio
-import orjson
 import shutil
 from pathlib import Path
 from typing import Any

 import mcp.types
+import orjson
 from fastmcp.client import Client, StdioTransport, StreamableHttpTransport

 from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
 """

 import time
-from typing import Any, Optional
-from dataclasses import dataclass, asdict, field
+from dataclasses import dataclass, field
+from typing import Any

 import orjson

-from src.common.logger import get_logger
 from src.common.cache_manager import tool_cache
+from src.common.logger import get_logger

 logger = get_logger("stream_tool_history")

@@ -18,10 +20,10 @@ class ToolCallRecord:
     """工具调用记录"""

     tool_name: str
     args: dict[str, Any]
-    result: Optional[dict[str, Any]] = None
+    result: dict[str, Any] | None = None
     status: str = "success"  # success, error, pending
     timestamp: float = field(default_factory=time.time)
-    execution_time: Optional[float] = None  # 执行耗时(秒)
+    execution_time: float | None = None  # 执行耗时(秒)
     cache_hit: bool = False  # 是否命中缓存
     result_preview: str = ""  # 结果预览
     error_message: str = ""  # 错误信息

@@ -32,9 +34,9 @@ class ToolCallRecord:
             content = self.result.get("content", "")
             if isinstance(content, str):
                 self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
-            elif isinstance(content, (list, dict)):
+            elif isinstance(content, list | dict):
                 try:
-                    self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..."
+                    self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
                 except Exception:
                     self.result_preview = str(content)[:500] + "..."
             else:

@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
         logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")

-    async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
+    async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
         """从缓存或历史记录中获取结果

         Args:

@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
         return None

     async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
-                           execution_time: Optional[float] = None,
-                           tool_file_path: Optional[str] = None,
-                           ttl: Optional[int] = None) -> None:
+                           execution_time: float | None = None,
+                           tool_file_path: str | None = None,
+                           ttl: int | None = None) -> None:
         """缓存工具调用结果

         Args:

@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
         except Exception as e:
             logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")

-    async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]:
+    async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
         """获取最近的历史记录

         Args:

@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
         self._history.clear()
         logger.info(f"[{self.chat_id}] 工具历史记录已清除")

-    def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]:
+    def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
         """在内存历史记录中搜索缓存

         Args:

@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
         return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")

-    def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]:
+    def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
         """提取语义查询参数

         Args:

@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
             return ""

         try:
-            args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8')
+            args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
             if len(args_str) > max_length:
                 args_str = args_str[:max_length] + "..."
             return args_str
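
The record's modernized annotations in one self-contained sketch — `field(default_factory=time.time)` stamps each instance at construction, while `X | None` replaces `Optional[X]`:

    import time
    from dataclasses import dataclass, field
    from typing import Any


    @dataclass
    class Record:  # trimmed stand-in for ToolCallRecord above
        tool_name: str
        args: dict[str, Any]
        result: dict[str, Any] | None = None              # was Optional[dict[str, Any]]
        timestamp: float = field(default_factory=time.time)
        execution_time: float | None = None               # was Optional[float]


    r = Record("web_search", {"query": "MoFox"})
    print(r.timestamp)  # set when the instance was created, not at class definition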

View File

@@ -1,5 +1,6 @@
 import inspect
 import time
+from dataclasses import asdict
 from typing import Any

 from src.chat.utils.prompt import Prompt, global_prompt_manager

@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
 from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
 from src.plugin_system.base.base_tool import BaseTool
 from src.plugin_system.core.global_announcement_manager import global_announcement_manager
-from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord
-from dataclasses import asdict
+from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager

 logger = get_logger("tool_use")

@@ -338,9 +338,8 @@ class ToolExecutor:
         if tool_instance and result and tool_instance.enable_cache:
             try:
                 tool_file_path = inspect.getfile(tool_instance.__class__)
-                semantic_query = None
                 if tool_instance.semantic_cache_query_key:
-                    semantic_query = function_args.get(tool_instance.semantic_cache_query_key)
+                    function_args.get(tool_instance.semantic_cache_query_key)

                 await self.history_manager.cache_result(
                     tool_name=tool_call.func_name,

View File

@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             return 0.0
         except asyncio.TimeoutError:
-            logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
+            logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
             return 0.5  # 超时时返回默认分值,避免丢失提及分和关系分
         except Exception as e:
             logger.warning(f"智能兴趣匹配失败: {e}")

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
 包含计划生成、过滤、执行等规划相关功能
 """

+from . import planner_prompts
 from .plan_executor import ChatterPlanExecutor
 from .plan_filter import ChatterPlanFilter
 from .plan_generator import ChatterPlanGenerator
 from .planner import ChatterActionPlanner
-from . import planner_prompts

-__all__ = ["ChatterActionPlanner", "planner_prompts", "ChatterPlanGenerator", "ChatterPlanFilter", "ChatterPlanExecutor"]
+__all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
 # 旧的Hippocampus系统已被移除现在使用增强记忆系统
 # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
 from src.chat.utils.chat_message_builder import (
-    build_readable_actions,
     build_readable_messages_with_id,
-    get_actions_by_timestamp_with_chat,
 )
 from src.chat.utils.prompt import global_prompt_manager
 from src.common.data_models.info_data_model import ActionPlannerInfo, Plan

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
     from src.common.data_models.message_manager_data_model import StreamContext

 # 导入提示词模块以确保其被初始化
-from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts

 logger = get_logger("planner")

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
 from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler

 __all__ = [
-    "ProactiveThinkingReplyHandler",
     "ProactiveThinkingMessageHandler",
-    "execute_proactive_thinking",
+    "ProactiveThinkingReplyHandler",
     "ProactiveThinkingScheduler",
+    "execute_proactive_thinking",
     "proactive_thinking_scheduler",
 ]

View File

@@ -3,7 +3,6 @@
 当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
 """

-import orjson
 from datetime import datetime
 from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.common.logger import get_logger
-from src.config.api_ada_configs import TaskConfig
 from src.llm_models.utils_model import LLMRequest
 from src.plugin_system.apis import config_api, generator_api, llm_api

View File

@@ -136,10 +136,10 @@ class QZoneService:
             logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
             api_client = await self._get_api_client(qq_account, stream_id)
             if not api_client:
-                logger.error(f"[DEBUG] API客户端获取失败返回错误")
+                logger.error("[DEBUG] API客户端获取失败返回错误")
                 return {"success": False, "message": "获取QZone API客户端失败"}

-            logger.info(f"[DEBUG] API客户端获取成功准备读取说说")
+            logger.info("[DEBUG] API客户端获取成功准备读取说说")
             num_to_read = self.get_config("read.read_number", 5)

             # 尝试执行如果Cookie失效则自动重试一次

@@ -186,7 +186,7 @@ class QZoneService:
                 # 检查是否是Cookie失效-3000错误
                 if "错误码: -3000" in error_msg and retry_count == 0:
-                    logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...")
+                    logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")

                     # 删除Cookie缓存文件
                     cookie_file = self.cookie_service._get_cookie_file_path(qq_account)

@@ -623,7 +623,7 @@ class QZoneService:
             logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
             return None

-        logger.info(f"[DEBUG] p_skey获取成功")
+        logger.info("[DEBUG] p_skey获取成功")
         gtk = self._generate_gtk(p_skey)
         uin = cookies.get("uin", "").lstrip("o")

@@ -1230,7 +1230,7 @@ class QZoneService:
             logger.error(f"监控好友动态失败: {e}", exc_info=True)
             return []

-        logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典")
+        logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
         return {
             "publish": _publish,
             "list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
 负责记录和管理已回复过的评论ID避免重复回复
 """

-import orjson
 import time
 from pathlib import Path
 from typing import Any

+import orjson
+
 from src.common.logger import get_logger

 logger = get_logger("MaiZone.ReplyTrackerService")

@@ -117,8 +118,8 @@ class ReplyTrackerService:
             temp_file = self.reply_record_file.with_suffix(".tmp")

             # 先写入临时文件
-            with open(temp_file, "w", encoding="utf-8") as f:
-                orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')
+            with open(temp_file, "w", encoding="utf-8"):
+                orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8")

             # 如果写入成功,重命名为正式文件
             if temp_file.stat().st_size > 0:  # 确保写入成功
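
For reference, the write-then-rename shape this service is aiming for — the serialized bytes must actually be written for the `st_size > 0` check to pass; a hedged sketch with illustrative names:

    import orjson
    from pathlib import Path


    def save_atomically(path: Path, data: dict) -> None:
        tmp = path.with_suffix(".tmp")
        payload = orjson.dumps(data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS)
        tmp.write_bytes(payload)          # the dumps() result has to reach the temp file
        if tmp.stat().st_size > 0:        # same sanity check as above
            tmp.replace(path)             # atomic rename over the old file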

View File

@@ -1,7 +1,6 @@
 import orjson
 import random
 import time
-import random
 import websockets as Server
 import uuid
 from maim_message import (

@@ -205,7 +204,7 @@ class SendHandler:
         # 发送响应回MoFox-Bot
         logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
         await self.send_adapter_command_response(raw_message_base, response, request_id)
-        logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
+        logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")

         if response.get("status") == "ok":
             logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
 """
 Metaso Search Engine (Chat Completions Mode)
 """

-import orjson
 from typing import Any

 import httpx
+import orjson

 from src.common.logger import get_logger
 from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
 Google Search via Serper.dev API
 """

-import aiohttp
 from typing import Any

+import aiohttp
+
 from src.common.logger import get_logger
 from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
 """

 from src.common.logger import get_logger
-from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin
+from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
 from src.plugin_system.apis import config_api

 from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
                 custom_args["num_results"] = custom_args.get("num_results", 5)

                 # 如果启用了answer模式且是Exa引擎使用answer_search方法
-                if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
+                if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
                     search_tasks.append(engine.answer_search(custom_args))
                 else:
                     search_tasks.append(engine.search(custom_args))

@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
                 custom_args["num_results"] = custom_args.get("num_results", 5)

                 # 如果启用了answer模式且是Exa引擎使用answer_search方法
-                if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
+                if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
                     logger.info("使用Exa答案模式进行搜索fallback策略")
                     results = await engine.answer_search(custom_args)
                 else:

@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
             custom_args["num_results"] = custom_args.get("num_results", 5)

             # 如果启用了answer模式且是Exa引擎使用answer_search方法
-            if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'):
+            if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
                 logger.info("使用Exa答案模式进行搜索")
                 results = await engine.answer_search(custom_args)
             else:
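
All three hunks share one dispatch shape: feature-detect `answer_search` on the engine and fall back to plain `search`. A runnable sketch with a stand-in engine (`FakeExaEngine` is invented for illustration):

    import asyncio


    class FakeExaEngine:
        async def search(self, args: dict) -> str:
            return f"web results for {args['query']}"

        async def answer_search(self, args: dict) -> str:
            return f"direct answer for {args['query']}"


    async def run_search(engine, args: dict, answer_mode: bool, engine_name: str):
        # capability check + fallback, mirroring the hunks above
        if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
            return await engine.answer_search(args)
        return await engine.search(args)


    print(asyncio.run(run_search(FakeExaEngine(), {"query": "MoFox"}, True, "exa")))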

View File

@@ -14,7 +14,7 @@ sys.path.insert(0, str(project_root))

 from tools.memory_visualizer.visualizer_server import run_server

-if __name__ == '__main__':
+if __name__ == "__main__":
     print("=" * 60)
     print("🦊 MoFox Bot - 记忆图可视化工具")
     print("=" * 60)

@@ -27,7 +27,7 @@ if __name__ == '__main__':
     try:
         run_server(
-            host='127.0.0.1',
+            host="127.0.0.1",
             port=5000,
             debug=True
         )

@@ -15,7 +15,7 @@ from pathlib import Path
 project_root = Path(__file__).parent
 sys.path.insert(0, str(project_root))

-if __name__ == '__main__':
+if __name__ == "__main__":
     print("=" * 70)
     print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
     print("=" * 70)
@@ -29,7 +29,7 @@ if __name__ == '__main__':
     try:
         from tools.memory_visualizer.visualizer_simple import run_server
-        run_server(host='127.0.0.1', port=5001, debug=True)
+        run_server(host="127.0.0.1", port=5001, debug=True)
     except KeyboardInterrupt:
         print("\n\n👋 服务器已停止")
     except Exception as e:

@@ -11,7 +11,6 @@ import logging
 import sys
 from datetime import datetime
 from pathlib import Path
-from typing import Optional

 from flask import Flask, jsonify, render_template, request
 from flask_cors import CORS
@@ -28,7 +27,7 @@ app = Flask(__name__)
 CORS(app)  # 允许跨域请求

 # 全局记忆管理器
-memory_manager: Optional[MemoryManager] = None
+memory_manager: MemoryManager | None = None


 def init_memory_manager():
@@ -189,7 +188,7 @@ def search_memories():
     init_memory_manager()

     query = request.args.get("q", "")
-    memory_type = request.args.get("type", None)
+    request.args.get("type", None)
     limit = int(request.args.get("limit", 50))

     loop = asyncio.new_event_loop()
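
Two notes on this file. First, the fix for the unused `memory_type` variable keeps a bare `request.args.get("type", None)` expression that does nothing — typical of an automated unused-assignment fix; deleting the line outright would be cleaner. Second, the typing changes here and in the next file follow PEP 585 and PEP 604: `Optional[X]` becomes `X | None`, and `typing.List`/`typing.Dict` become the built-in generics. The version floor matters: built-in generics need Python 3.9+ and `X | None` needs 3.10+ at runtime, unless annotation evaluation is deferred. A small self-contained sketch (function names are illustrative):

    from __future__ import annotations  # defers evaluation, so 3.8+ can parse these

    from pathlib import Path
    from typing import Any

    def find_files(root: Path) -> list[Path]:  # was List[Path]
        return sorted(root.glob("*.json"))

    def load(file_path: Path | None = None) -> dict[str, Any]:  # was Optional[Path] / Dict[str, Any]
        return {"path": str(file_path) if file_path else None}

    print(load())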

@@ -4,20 +4,18 @@
 直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
 """
-import orjson
 import sys
-from pathlib import Path
 from datetime import datetime
-from typing import Any, Dict, List, Set
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Set
+from typing import Any
+
+import orjson

 # 添加项目根目录
 project_root = Path(__file__).parent.parent.parent
 sys.path.insert(0, str(project_root))

-from flask import Flask, jsonify, render_template_string, request, send_from_directory
+from flask import Flask, jsonify, render_template_string, request
 from flask_cors import CORS

 app = Flask(__name__)
@@ -29,7 +27,7 @@ data_dir = project_root / "data" / "memory_graph"
 current_data_file = None  # 当前选择的数据文件

-def find_available_data_files() -> List[Path]:
+def find_available_data_files() -> list[Path]:
     """查找所有可用的记忆图数据文件"""
     files = []
@@ -74,7 +72,7 @@ def find_available_data_files() -> List[Path]:
     return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)

-def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
+def load_graph_data(file_path: Path | None = None) -> dict[str, Any]:
     """从磁盘加载图数据"""
     global graph_data_cache, current_data_file
@@ -94,7 +92,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
     # 尝试查找可用的数据文件
     available_files = find_available_data_files()
     if not available_files:
-        print(f"⚠️ 未找到任何图数据文件")
+        print("⚠️ 未找到任何图数据文件")
         print(f"📂 搜索目录: {data_dir}")
         return {
             "nodes": [],
@@ -121,7 +119,7 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
     }

     print(f"📂 加载图数据: {graph_file}")
-    with open(graph_file, 'r', encoding='utf-8') as f:
+    with open(graph_file, encoding="utf-8") as f:
         data = orjson.loads(f.read())

     # 解析数据
@@ -139,64 +137,64 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
     # 处理节点
     for node in nodes:
-        node_id = node.get('id', '')
+        node_id = node.get("id", "")
         if node_id and node_id not in nodes_dict:
-            memory_ids = node.get('metadata', {}).get('memory_ids', [])
+            memory_ids = node.get("metadata", {}).get("memory_ids", [])
             nodes_dict[node_id] = {
-                'id': node_id,
-                'label': node.get('content', ''),
-                'type': node.get('node_type', ''),
-                'group': extract_group_from_type(node.get('node_type', '')),
-                'title': f"{node.get('node_type', '')}: {node.get('content', '')}",
-                'metadata': node.get('metadata', {}),
-                'created_at': node.get('created_at', ''),
-                'memory_ids': memory_ids,
+                "id": node_id,
+                "label": node.get("content", ""),
+                "type": node.get("node_type", ""),
+                "group": extract_group_from_type(node.get("node_type", "")),
+                "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
+                "metadata": node.get("metadata", {}),
+                "created_at": node.get("created_at", ""),
+                "memory_ids": memory_ids,
             }

     # 处理边 - 使用集合去重避免重复的边ID
     existing_edge_ids = set()
     for edge in edges:
         # 边的ID字段可能是 'id' 或 'edge_id'
-        edge_id = edge.get('edge_id') or edge.get('id', '')
+        edge_id = edge.get("edge_id") or edge.get("id", "")
         # 如果ID为空或已存在跳过这条边
         if not edge_id or edge_id in existing_edge_ids:
             continue
         existing_edge_ids.add(edge_id)
-        memory_id = edge.get('metadata', {}).get('memory_id', '')
+        memory_id = edge.get("metadata", {}).get("memory_id", "")
         # 注意: GraphStore 保存的格式使用 'source'/'target', 不是 'source_id'/'target_id'
         edges_list.append({
-            'id': edge_id,
-            'from': edge.get('source', edge.get('source_id', '')),
-            'to': edge.get('target', edge.get('target_id', '')),
-            'label': edge.get('relation', ''),
-            'type': edge.get('edge_type', ''),
-            'importance': edge.get('importance', 0.5),
-            'title': f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
-            'arrows': 'to',
-            'memory_id': memory_id,
+            "id": edge_id,
+            "from": edge.get("source", edge.get("source_id", "")),
+            "to": edge.get("target", edge.get("target_id", "")),
+            "label": edge.get("relation", ""),
+            "type": edge.get("edge_type", ""),
+            "importance": edge.get("importance", 0.5),
+            "title": f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
+            "arrows": "to",
+            "memory_id": memory_id,
        })

     # 从元数据中获取统计信息
-    stats = metadata.get('statistics', {})
-    total_memories = stats.get('total_memories', 0)
+    stats = metadata.get("statistics", {})
+    total_memories = stats.get("total_memories", 0)

     # TODO: 如果需要记忆详细信息,需要从其他地方加载
     # 目前只有节点和边的数据
     graph_data_cache = {
-        'nodes': list(nodes_dict.values()),
-        'edges': edges_list,
-        'memories': memory_info,  # 空列表,因为文件中没有记忆详情
-        'stats': {
-            'total_nodes': len(nodes_dict),
-            'total_edges': len(edges_list),
-            'total_memories': total_memories,
-        },
-        'current_file': str(graph_file),
-        'file_size': graph_file.stat().st_size,
-        'file_modified': datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
+        "nodes": list(nodes_dict.values()),
+        "edges": edges_list,
+        "memories": memory_info,  # 空列表,因为文件中没有记忆详情
+        "stats": {
+            "total_nodes": len(nodes_dict),
+            "total_edges": len(edges_list),
+            "total_memories": total_memories,
+        },
+        "current_file": str(graph_file),
+        "file_size": graph_file.stat().st_size,
+        "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
     }

     print(f"📊 统计: {len(nodes_dict)} 个节点, {len(edges_list)} 条边, {total_memories} 条记忆")
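
The edge loop above dedupes by ID with a set, tolerating records that use either `edge_id` or `id`. The same pattern as a small reusable helper — a sketch that mirrors the field fallbacks in the hunk:

    from typing import Any

    def dedupe_edges(edges: list[dict[str, Any]]) -> list[dict[str, Any]]:
        seen: set[str] = set()
        unique = []
        for edge in edges:
            edge_id = edge.get("edge_id") or edge.get("id", "")
            if not edge_id or edge_id in seen:
                continue  # skip blank and duplicate IDs, as the server code does
            seen.add(edge_id)
            unique.append(edge)
        return unique

    print(dedupe_edges([{"id": "e1"}, {"edge_id": "e1"}, {"id": "e2"}]))  # e1 kept once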
@@ -214,38 +212,38 @@ def extract_group_from_type(node_type: str) -> str:
     """从节点类型提取分组名"""
     # 假设类型格式为 "主体" 或 "SUBJECT"
     type_mapping = {
-        '主体': 'SUBJECT',
-        '主题': 'TOPIC',
-        '客体': 'OBJECT',
-        '属性': 'ATTRIBUTE',
-        '值': 'VALUE',
+        "主体": "SUBJECT",
+        "主题": "TOPIC",
+        "客体": "OBJECT",
+        "属性": "ATTRIBUTE",
+        "值": "VALUE",
     }
     return type_mapping.get(node_type, node_type)

-def generate_memory_text(memory: Dict[str, Any]) -> str:
+def generate_memory_text(memory: dict[str, Any]) -> str:
     """生成记忆的文本描述"""
     try:
-        nodes = {n['id']: n for n in memory.get('nodes', [])}
-        edges = memory.get('edges', [])
-        subject_id = memory.get('subject_id', '')
+        nodes = {n["id"]: n for n in memory.get("nodes", [])}
+        edges = memory.get("edges", [])
+        subject_id = memory.get("subject_id", "")

         if not subject_id or subject_id not in nodes:
             return f"[记忆 {memory.get('id', '')[:8]}]"

-        parts = [nodes[subject_id]['content']]
+        parts = [nodes[subject_id]["content"]]

         # 找主题节点
         for edge in edges:
-            if edge.get('edge_type') == '记忆类型' and edge.get('source_id') == subject_id:
-                topic_id = edge.get('target_id', '')
+            if edge.get("edge_type") == "记忆类型" and edge.get("source_id") == subject_id:
+                topic_id = edge.get("target_id", "")
                 if topic_id in nodes:
-                    parts.append(nodes[topic_id]['content'])
+                    parts.append(nodes[topic_id]["content"])

                 # 找客体
                 for e2 in edges:
-                    if e2.get('edge_type') == '核心关系' and e2.get('source_id') == topic_id:
-                        obj_id = e2.get('target_id', '')
+                    if e2.get("edge_type") == "核心关系" and e2.get("source_id") == topic_id:
+                        obj_id = e2.get("target_id", "")
                         if obj_id in nodes:
                             parts.append(f"{e2.get('relation', '')} {nodes[obj_id]['content']}")
                             break
@@ -257,86 +255,86 @@ def generate_memory_text(memory: Dict[str, Any]) -> str:
 # 使用内嵌的HTML模板(与之前相同)
-HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", 'r', encoding='utf-8').read()
+HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", encoding="utf-8").read()

-@app.route('/')
+@app.route("/")
 def index():
     """主页面"""
     return render_template_string(HTML_TEMPLATE)
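
One thing the mode-argument autofix (`open(..., 'r')` → `open(...)`) cannot address: `HTML_TEMPLATE` is still read at import time with a bare `open(...).read()`, which never explicitly closes the file. `Path.read_text` is the equivalent one-liner that does — a sketch of the alternative, not what the commit ships:

    from pathlib import Path

    project_root = Path(__file__).parent.parent.parent  # as defined earlier in the module
    template = project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html"
    HTML_TEMPLATE = template.read_text(encoding="utf-8")  # opens, reads, and closes the file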
-@app.route('/api/graph/full')
+@app.route("/api/graph/full")
 def get_full_graph():
     """获取完整记忆图数据"""
     try:
         data = load_graph_data()
         return jsonify({
-            'success': True,
-            'data': data
+            "success": True,
+            "data": data
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-@app.route('/api/memory/<memory_id>')
+@app.route("/api/memory/<memory_id>")
 def get_memory_detail(memory_id: str):
     """获取记忆详情"""
     try:
         data = load_graph_data()
-        memory = next((m for m in data['memories'] if m['id'] == memory_id), None)
+        memory = next((m for m in data["memories"] if m["id"] == memory_id), None)

         if memory is None:
             return jsonify({
-                'success': False,
-                'error': '记忆不存在'
+                "success": False,
+                "error": "记忆不存在"
             }), 404

         return jsonify({
-            'success': True,
-            'data': memory
+            "success": True,
+            "data": memory
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-@app.route('/api/search')
+@app.route("/api/search")
 def search_memories():
     """搜索记忆"""
     try:
-        query = request.args.get('q', '').lower()
-        limit = int(request.args.get('limit', 50))
+        query = request.args.get("q", "").lower()
+        limit = int(request.args.get("limit", 50))
         data = load_graph_data()

         # 简单的文本匹配搜索
         results = []
-        for memory in data['memories']:
-            text = memory.get('text', '').lower()
+        for memory in data["memories"]:
+            text = memory.get("text", "").lower()
             if query in text:
                 results.append(memory)

         return jsonify({
-            'success': True,
-            'data': {
-                'results': results[:limit],
-                'count': len(results),
+            "success": True,
+            "data": {
+                "results": results[:limit],
+                "count": len(results),
             }
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-@app.route('/api/stats')
+@app.route("/api/stats")
 def get_statistics():
     """获取统计信息"""
     try:
@@ -346,43 +344,43 @@ def get_statistics():
         node_types = {}
         memory_types = {}

-        for node in data['nodes']:
-            node_type = node.get('type', 'Unknown')
+        for node in data["nodes"]:
+            node_type = node.get("type", "Unknown")
             node_types[node_type] = node_types.get(node_type, 0) + 1

-        for memory in data['memories']:
-            mem_type = memory.get('type', 'Unknown')
+        for memory in data["memories"]:
+            mem_type = memory.get("type", "Unknown")
             memory_types[mem_type] = memory_types.get(mem_type, 0) + 1

-        stats = data.get('stats', {})
-        stats['node_types'] = node_types
-        stats['memory_types'] = memory_types
+        stats = data.get("stats", {})
+        stats["node_types"] = node_types
+        stats["memory_types"] = memory_types

         return jsonify({
-            'success': True,
-            'data': stats
+            "success": True,
+            "data": stats
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-@app.route('/api/reload')
+@app.route("/api/reload")
 def reload_data():
     """重新加载数据"""
     global graph_data_cache
     graph_data_cache = None
     data = load_graph_data()
     return jsonify({
-        'success': True,
-        'message': '数据已重新加载',
-        'stats': data.get('stats', {})
+        "success": True,
+        "message": "数据已重新加载",
+        "stats": data.get("stats", {})
     })

-@app.route('/api/files')
+@app.route("/api/files")
 def list_files():
     """列出所有可用的数据文件"""
     try:
@@ -392,48 +390,48 @@ def list_files():
         for f in files:
             stat = f.stat()
             file_list.append({
-                'path': str(f),
-                'name': f.name,
-                'size': stat.st_size,
-                'size_kb': round(stat.st_size / 1024, 2),
-                'modified': datetime.fromtimestamp(stat.st_mtime).isoformat(),
-                'modified_readable': datetime.fromtimestamp(stat.st_mtime).strftime('%Y-%m-%d %H:%M:%S'),
-                'is_current': str(f) == str(current_data_file) if current_data_file else False
+                "path": str(f),
+                "name": f.name,
+                "size": stat.st_size,
+                "size_kb": round(stat.st_size / 1024, 2),
+                "modified": datetime.fromtimestamp(stat.st_mtime).isoformat(),
+                "modified_readable": datetime.fromtimestamp(stat.st_mtime).strftime("%Y-%m-%d %H:%M:%S"),
+                "is_current": str(f) == str(current_data_file) if current_data_file else False
             })

         return jsonify({
-            'success': True,
-            'files': file_list,
-            'count': len(file_list),
-            'current_file': str(current_data_file) if current_data_file else None
+            "success": True,
+            "files": file_list,
+            "count": len(file_list),
+            "current_file": str(current_data_file) if current_data_file else None
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-@app.route('/api/select_file', methods=['POST'])
+@app.route("/api/select_file", methods=["POST"])
 def select_file():
     """选择要加载的数据文件"""
     global graph_data_cache, current_data_file
     try:
         data = request.get_json()
-        file_path = data.get('file_path')
+        file_path = data.get("file_path")

         if not file_path:
             return jsonify({
-                'success': False,
-                'error': '未提供文件路径'
+                "success": False,
+                "error": "未提供文件路径"
             }), 400

         file_path = Path(file_path)
         if not file_path.exists():
             return jsonify({
-                'success': False,
-                'error': f'文件不存在: {file_path}'
+                "success": False,
+                "error": f"文件不存在: {file_path}"
             }), 404

         # 清除缓存并加载新文件
@@ -442,18 +440,18 @@ def select_file():
         graph_data = load_graph_data(file_path)

         return jsonify({
-            'success': True,
-            'message': f'已切换到文件: {file_path.name}',
-            'stats': graph_data.get('stats', {})
+            "success": True,
+            "message": f"已切换到文件: {file_path.name}",
+            "stats": graph_data.get("stats", {})
         })
     except Exception as e:
         return jsonify({
-            'success': False,
-            'error': str(e)
+            "success": False,
+            "error": str(e)
         }), 500

-def run_server(host: str = '127.0.0.1', port: int = 5001, debug: bool = False):
+def run_server(host: str = "127.0.0.1", port: int = 5001, debug: bool = False):
     """启动服务器"""
     print("=" * 60)
     print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
@@ -470,7 +468,7 @@ def run_server(host: str = '127.0.0.1', port: int = 5001, debug: bool = False):
     app.run(host=host, port=port, debug=debug)

-if __name__ == '__main__':
+if __name__ == "__main__":
     try:
         run_server(debug=True)
     except KeyboardInterrupt:
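
Every route in this file repeats the same try/except-to-`jsonify` envelope, which is why the quote fixes touch so many near-identical lines. If the file is revisited, that boilerplate could be factored into a decorator — a sketch assuming the same `{"success": ..., "data"/"error": ...}` contract, with an illustrative route name rather than one of the real endpoints:

    from functools import wraps

    from flask import Flask, jsonify

    app = Flask(__name__)

    def api_response(view):
        """Wrap a view so it returns the success envelope, or a 500 on error."""
        @wraps(view)
        def wrapper(*args, **kwargs):
            try:
                return jsonify({"success": True, "data": view(*args, **kwargs)})
            except Exception as e:  # blanket catch, mirroring the handlers above
                return jsonify({"success": False, "error": str(e)}), 500
        return wrapper

    @app.route("/api/ping")
    @api_response
    def ping():
        return {"pong": True}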