This commit is contained in:
明天好像没什么
2025-11-07 21:01:45 +08:00
parent 80b040da2f
commit c8d7c09625
49 changed files with 854 additions and 872 deletions

View File

@@ -17,19 +17,19 @@
import argparse import argparse
import json import json
from pathlib import Path
from typing import Dict, Any, List, Tuple
import logging import logging
from pathlib import Path
from typing import Any
import orjson import orjson
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[ handlers=[
logging.StreamHandler(), logging.StreamHandler(),
logging.FileHandler('embedding_cleanup.log', encoding='utf-8') logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
self.cleaned_files = [] self.cleaned_files = []
self.errors = [] self.errors = []
self.stats = { self.stats = {
'files_processed': 0, "files_processed": 0,
'embedings_removed': 0, "embedings_removed": 0,
'bytes_saved': 0, "bytes_saved": 0,
'nodes_processed': 0 "nodes_processed": 0
} }
def find_json_files(self) -> List[Path]: def find_json_files(self) -> list[Path]:
"""查找可能包含向量数据的 JSON 文件""" """查找可能包含向量数据的 JSON 文件"""
json_files = [] json_files = []
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
json_files.append(memory_graph_file) json_files.append(memory_graph_file)
# 测试数据文件 # 测试数据文件
test_dir = self.data_dir / "test_*" self.data_dir / "test_*"
for test_path in self.data_dir.glob("test_*/memory_graph.json"): for test_path in self.data_dir.glob("test_*/memory_graph.json"):
if test_path.exists(): if test_path.exists():
json_files.append(test_path) json_files.append(test_path)
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件") logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
return json_files return json_files
def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int: def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
""" """
分析数据中的 embedding 字段数量 分析数据中的 embedding 字段数量
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
def count_embeddings(obj): def count_embeddings(obj):
nonlocal embedding_count nonlocal embedding_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
embedding_count += 1 embedding_count += 1
for value in obj.values(): for value in obj.values():
count_embeddings(value) count_embeddings(value)
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
count_embeddings(data) count_embeddings(data)
return embedding_count return embedding_count
def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
""" """
从数据中移除 embedding 字段 从数据中移除 embedding 字段
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
def remove_embeddings(obj): def remove_embeddings(obj):
nonlocal removed_count nonlocal removed_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
del obj['embedding'] del obj["embedding"]
removed_count += 1 removed_count += 1
for value in obj.values(): for value in obj.values():
remove_embeddings(value) remove_embeddings(value)
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
data = orjson.loads(original_content) data = orjson.loads(original_content)
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
# 回退到标准 json # 回退到标准 json
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# 分析 embedding 数据 # 分析 embedding 数据
embedding_count = self.analyze_embedding_in_data(data) embedding_count = self.analyze_embedding_in_data(data)
if embedding_count == 0: if embedding_count == 0:
logger.info(f" ✓ 文件中没有 embedding 数据,跳过") logger.info(" ✓ 文件中没有 embedding 数据,跳过")
return True return True
logger.info(f" 发现 {embedding_count} 个 embedding 字段") logger.info(f" 发现 {embedding_count} 个 embedding 字段")
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
cleaned_data, cleaned_data,
indent=2, indent=2,
ensure_ascii=False ensure_ascii=False
).encode('utf-8') ).encode("utf-8")
cleaned_size = len(cleaned_content) cleaned_size = len(cleaned_content)
bytes_saved = original_size - cleaned_size bytes_saved = original_size - cleaned_size
# 原子写入 # 原子写入
temp_file = file_path.with_suffix('.tmp') temp_file = file_path.with_suffix(".tmp")
temp_file.write_bytes(cleaned_content) temp_file.write_bytes(cleaned_content)
temp_file.replace(file_path) temp_file.replace(file_path)
logger.info(f" ✓ 清理完成:") logger.info(" ✓ 清理完成:")
logger.info(f" - 移除 embedding 字段: {removed_count}") logger.info(f" - 移除 embedding 字段: {removed_count}")
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)") logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
logger.info(f" - 新文件大小: {cleaned_size:,} 字节") logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
# 更新统计 # 更新统计
self.stats['embedings_removed'] += removed_count self.stats["embedings_removed"] += removed_count
self.stats['bytes_saved'] += bytes_saved self.stats["bytes_saved"] += bytes_saved
else: else:
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段") logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
self.stats['embedings_removed'] += embedding_count self.stats["embedings_removed"] += embedding_count
self.stats['files_processed'] += 1 self.stats["files_processed"] += 1
self.cleaned_files.append(file_path) self.cleaned_files.append(file_path)
return True return True
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
节点数量 节点数量
""" """
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
node_count = 0 node_count = 0
if 'nodes' in data and isinstance(data['nodes'], list): if "nodes" in data and isinstance(data["nodes"], list):
node_count = len(data['nodes']) node_count = len(data["nodes"])
return node_count return node_count
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
# 统计总节点数 # 统计总节点数
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files) total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
self.stats['nodes_processed'] = total_nodes self.stats["nodes_processed"] = total_nodes
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点") logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
if not dry_run: if not dry_run:
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节") logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
if self.stats['bytes_saved'] > 0: if self.stats["bytes_saved"] > 0:
mb_saved = self.stats['bytes_saved'] / 1024 / 1024 mb_saved = self.stats["bytes_saved"] / 1024 / 1024
logger.info(f"节省空间: {mb_saved:.2f} MB") logger.info(f"节省空间: {mb_saved:.2f} MB")
if self.errors: if self.errors:
@@ -342,7 +342,7 @@ def main():
print(" 请确保向量数据库正在正常工作。") print(" 请确保向量数据库正在正常工作。")
print() print()
response = input("确认继续?(yes/no): ") response = input("确认继续?(yes/no): ")
if response.lower() not in ['yes', 'y', '']: if response.lower() not in ["yes", "y", ""]:
print("操作已取消") print("操作已取消")
return return
@@ -352,4 +352,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()

View File

@@ -10,10 +10,10 @@
示例: 示例:
# 进程监控(启动 bot 并监控) # 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10 python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计) # 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表 # 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15
""" """
@@ -22,7 +22,6 @@ import argparse
import asyncio import asyncio
import gc import gc
import json import json
import os
import subprocess import subprocess
import sys import sys
import threading import threading
@@ -30,7 +29,6 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional
import psutil import psutil
@@ -56,29 +54,29 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
if bot_process.pid is None: if bot_process.pid is None:
print("❌ Bot 进程 PID 为空") print("❌ Bot 进程 PID 为空")
return return
print(f"🔍 开始监控 Bot 内存PID: {bot_process.pid}") print(f"🔍 开始监控 Bot 内存PID: {bot_process.pid}")
print(f"监控间隔: {interval}") print(f"监控间隔: {interval}")
print("按 Ctrl+C 停止监控和 Bot\n") print("按 Ctrl+C 停止监控和 Bot\n")
try: try:
process = psutil.Process(bot_process.pid) process = psutil.Process(bot_process.pid)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
print("❌ 无法找到 Bot 进程") print("❌ 无法找到 Bot 进程")
return return
history = [] history = []
iteration = 0 iteration = 0
try: try:
while bot_process.poll() is None: while bot_process.poll() is None:
try: try:
mem_info = process.memory_info() mem_info = process.memory_info()
mem_percent = process.memory_percent() mem_percent = process.memory_percent()
children = process.children(recursive=True) children = process.children(recursive=True)
children_mem = sum(child.memory_info().rss for child in children) children_mem = sum(child.memory_info().rss for child in children)
info = { info = {
"timestamp": time.strftime("%H:%M:%S"), "timestamp": time.strftime("%H:%M:%S"),
"rss_mb": mem_info.rss / 1024 / 1024, "rss_mb": mem_info.rss / 1024 / 1024,
@@ -87,24 +85,24 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
"children_count": len(children), "children_count": len(children),
"children_mem_mb": children_mem / 1024 / 1024, "children_mem_mb": children_mem / 1024 / 1024,
} }
history.append(info) history.append(info)
iteration += 1 iteration += 1
print(f"{'=' * 80}") print(f"{'=' * 80}")
print(f"检查点 #{iteration} - {info['timestamp']}") print(f"检查点 #{iteration} - {info['timestamp']}")
print(f"Bot 进程 (PID: {bot_process.pid})") print(f"Bot 进程 (PID: {bot_process.pid})")
print(f" RSS: {info['rss_mb']:.2f} MB") print(f" RSS: {info['rss_mb']:.2f} MB")
print(f" VMS: {info['vms_mb']:.2f} MB") print(f" VMS: {info['vms_mb']:.2f} MB")
print(f" 占比: {info['percent']:.2f}%") print(f" 占比: {info['percent']:.2f}%")
if children: if children:
print(f" 子进程: {info['children_count']}") print(f" 子进程: {info['children_count']}")
print(f" 子进程内存: {info['children_mem_mb']:.2f} MB") print(f" 子进程内存: {info['children_mem_mb']:.2f} MB")
total_mem = info['rss_mb'] + info['children_mem_mb'] total_mem = info["rss_mb"] + info["children_mem_mb"]
print(f" 总内存: {total_mem:.2f} MB") print(f" 总内存: {total_mem:.2f} MB")
print(f"\n 📋 子进程详情:") print("\n 📋 子进程详情:")
for idx, child in enumerate(children, 1): for idx, child in enumerate(children, 1):
try: try:
child_mem = child.memory_info().rss / 1024 / 1024 child_mem = child.memory_info().rss / 1024 / 1024
@@ -116,30 +114,30 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
print(f" 命令: {child_cmdline}") print(f" 命令: {child_cmdline}")
except (psutil.NoSuchProcess, psutil.AccessDenied): except (psutil.NoSuchProcess, psutil.AccessDenied):
print(f" [{idx}] 无法访问进程信息") print(f" [{idx}] 无法访问进程信息")
if len(history) > 1: if len(history) > 1:
prev = history[-2] prev = history[-2]
rss_diff = info['rss_mb'] - prev['rss_mb'] rss_diff = info["rss_mb"] - prev["rss_mb"]
print(f"\n变化:") print("\n变化:")
print(f" RSS: {rss_diff:+.2f} MB") print(f" RSS: {rss_diff:+.2f} MB")
if rss_diff > 10: if rss_diff > 10:
print(f" ⚠️ 内存增长较快!") print(" ⚠️ 内存增长较快!")
if info['rss_mb'] > 1000: if info["rss_mb"] > 1000:
print(f" ⚠️ 内存使用超过 1GB") print(" ⚠️ 内存使用超过 1GB")
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
await asyncio.sleep(interval) await asyncio.sleep(interval)
except psutil.NoSuchProcess: except psutil.NoSuchProcess:
print("\n❌ Bot 进程已结束") print("\n❌ Bot 进程已结束")
break break
except Exception as e: except Exception as e:
print(f"\n❌ 监控出错: {e}") print(f"\n❌ 监控出错: {e}")
break break
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断监控") print("\n\n⚠️ 用户中断监控")
finally: finally:
if history and bot_process.pid: if history and bot_process.pid:
save_process_history(history, bot_process.pid) save_process_history(history, bot_process.pid)
@@ -149,25 +147,25 @@ def save_process_history(history: list, pid: int):
"""保存进程监控历史""" """保存进程监控历史"""
output_dir = Path("data/memory_diagnostics") output_dir = Path("data/memory_diagnostics")
output_dir.mkdir(parents=True, exist_ok=True) output_dir.mkdir(parents=True, exist_ok=True)
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
output_file = output_dir / f"process_monitor_{timestamp}_pid{pid}.txt" output_file = output_dir / f"process_monitor_{timestamp}_pid{pid}.txt"
with open(output_file, "w", encoding="utf-8") as f: with open(output_file, "w", encoding="utf-8") as f:
f.write("Bot 进程内存监控历史记录\n") f.write("Bot 进程内存监控历史记录\n")
f.write("=" * 80 + "\n\n") f.write("=" * 80 + "\n\n")
f.write(f"Bot PID: {pid}\n\n") f.write(f"Bot PID: {pid}\n\n")
for info in history: for info in history:
f.write(f"时间: {info['timestamp']}\n") f.write(f"时间: {info['timestamp']}\n")
f.write(f"RSS: {info['rss_mb']:.2f} MB\n") f.write(f"RSS: {info['rss_mb']:.2f} MB\n")
f.write(f"VMS: {info['vms_mb']:.2f} MB\n") f.write(f"VMS: {info['vms_mb']:.2f} MB\n")
f.write(f"占比: {info['percent']:.2f}%\n") f.write(f"占比: {info['percent']:.2f}%\n")
if info['children_count'] > 0: if info["children_count"] > 0:
f.write(f"子进程: {info['children_count']}\n") f.write(f"子进程: {info['children_count']}\n")
f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n") f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n")
f.write("\n") f.write("\n")
print(f"\n✅ 监控历史已保存到: {output_file}") print(f"\n✅ 监控历史已保存到: {output_file}")
@@ -182,28 +180,28 @@ async def run_monitor_mode(interval: int):
print(" 3. 显示子进程详细信息") print(" 3. 显示子进程详细信息")
print(" 4. 自动保存监控历史") print(" 4. 自动保存监控历史")
print("=" * 80 + "\n") print("=" * 80 + "\n")
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
bot_file = project_root / "bot.py" bot_file = project_root / "bot.py"
if not bot_file.exists(): if not bot_file.exists():
print(f"❌ 找不到 bot.py: {bot_file}") print(f"❌ 找不到 bot.py: {bot_file}")
return 1 return 1
# 检测虚拟环境 # 检测虚拟环境
venv_python = project_root / ".venv" / "Scripts" / "python.exe" venv_python = project_root / ".venv" / "Scripts" / "python.exe"
if not venv_python.exists(): if not venv_python.exists():
venv_python = project_root / ".venv" / "bin" / "python" venv_python = project_root / ".venv" / "bin" / "python"
if venv_python.exists(): if venv_python.exists():
python_exe = str(venv_python) python_exe = str(venv_python)
print(f"🐍 使用虚拟环境: {venv_python}") print(f"🐍 使用虚拟环境: {venv_python}")
else: else:
python_exe = sys.executable python_exe = sys.executable
print(f"⚠️ 未找到虚拟环境,使用当前 Python: {python_exe}") print(f"⚠️ 未找到虚拟环境,使用当前 Python: {python_exe}")
print(f"🤖 启动 Bot: {bot_file}") print(f"🤖 启动 Bot: {bot_file}")
bot_process = subprocess.Popen( bot_process = subprocess.Popen(
[python_exe, str(bot_file)], [python_exe, str(bot_file)],
cwd=str(project_root), cwd=str(project_root),
@@ -212,9 +210,9 @@ async def run_monitor_mode(interval: int):
text=True, text=True,
bufsize=1, bufsize=1,
) )
await asyncio.sleep(2) await asyncio.sleep(2)
if bot_process.poll() is not None: if bot_process.poll() is not None:
print("❌ Bot 启动失败") print("❌ Bot 启动失败")
if bot_process.stdout: if bot_process.stdout:
@@ -222,9 +220,9 @@ async def run_monitor_mode(interval: int):
if output: if output:
print(f"\nBot 输出:\n{output}") print(f"\nBot 输出:\n{output}")
return 1 return 1
print(f"✅ Bot 已启动 (PID: {bot_process.pid})\n") print(f"✅ Bot 已启动 (PID: {bot_process.pid})\n")
# 启动输出读取线程 # 启动输出读取线程
def read_bot_output(): def read_bot_output():
if bot_process.stdout: if bot_process.stdout:
@@ -233,15 +231,15 @@ async def run_monitor_mode(interval: int):
print(f"[Bot] {line}", end="") print(f"[Bot] {line}", end="")
except Exception: except Exception:
pass pass
output_thread = threading.Thread(target=read_bot_output, daemon=True) output_thread = threading.Thread(target=read_bot_output, daemon=True)
output_thread.start() output_thread.start()
try: try:
await monitor_bot_process(bot_process, interval) await monitor_bot_process(bot_process, interval)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断") print("\n\n⚠️ 用户中断")
if bot_process.poll() is None: if bot_process.poll() is None:
print("\n正在停止 Bot...") print("\n正在停止 Bot...")
bot_process.terminate() bot_process.terminate()
@@ -251,9 +249,9 @@ async def run_monitor_mode(interval: int):
print("⚠️ 强制终止 Bot...") print("⚠️ 强制终止 Bot...")
bot_process.kill() bot_process.kill()
bot_process.wait() bot_process.wait()
print("✅ Bot 已停止") print("✅ Bot 已停止")
return 0 return 0
@@ -263,8 +261,8 @@ async def run_monitor_mode(interval: int):
class ObjectMemoryProfiler: class ObjectMemoryProfiler:
"""对象级内存分析器""" """对象级内存分析器"""
def __init__(self, interval: int = 10, output_file: Optional[str] = None, object_limit: int = 20): def __init__(self, interval: int = 10, output_file: str | None = None, object_limit: int = 20):
self.interval = interval self.interval = interval
self.output_file = output_file self.output_file = output_file
self.object_limit = object_limit self.object_limit = object_limit
@@ -273,23 +271,23 @@ class ObjectMemoryProfiler:
if PYMPLER_AVAILABLE: if PYMPLER_AVAILABLE:
self.tracker = tracker.SummaryTracker() self.tracker = tracker.SummaryTracker()
self.iteration = 0 self.iteration = 0
def get_object_stats(self) -> Dict: def get_object_stats(self) -> dict:
"""获取当前进程的对象统计(所有线程)""" """获取当前进程的对象统计(所有线程)"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
return {} return {}
try: try:
gc.collect() gc.collect()
all_objects = muppy.get_objects() all_objects = muppy.get_objects()
sum_data = summary.summarize(all_objects) sum_data = summary.summarize(all_objects)
# 按总大小第3个元素降序排序 # 按总大小第3个元素降序排序
sorted_sum_data = sorted(sum_data, key=lambda x: x[2], reverse=True) sorted_sum_data = sorted(sum_data, key=lambda x: x[2], reverse=True)
# 按模块统计内存 # 按模块统计内存
module_stats = self._get_module_stats(all_objects) module_stats = self._get_module_stats(all_objects)
threads = threading.enumerate() threads = threading.enumerate()
thread_info = [ thread_info = [
{ {
@@ -299,13 +297,13 @@ class ObjectMemoryProfiler:
} }
for t in threads for t in threads
] ]
gc_stats = { gc_stats = {
"collections": gc.get_count(), "collections": gc.get_count(),
"garbage": len(gc.garbage), "garbage": len(gc.garbage),
"tracked": len(gc.get_objects()), "tracked": len(gc.get_objects()),
} }
return { return {
"summary": sorted_sum_data[:self.object_limit], "summary": sorted_sum_data[:self.object_limit],
"module_stats": module_stats, "module_stats": module_stats,
@@ -316,52 +314,52 @@ class ObjectMemoryProfiler:
except Exception as e: except Exception as e:
print(f"❌ 获取对象统计失败: {e}") print(f"❌ 获取对象统计失败: {e}")
return {} return {}
def _get_module_stats(self, all_objects: list) -> Dict: def _get_module_stats(self, all_objects: list) -> dict:
"""统计各模块的内存占用""" """统计各模块的内存占用"""
module_mem = defaultdict(lambda: {"count": 0, "size": 0}) module_mem = defaultdict(lambda: {"count": 0, "size": 0})
for obj in all_objects: for obj in all_objects:
try: try:
# 获取对象所属模块 # 获取对象所属模块
obj_type = type(obj) obj_type = type(obj)
module_name = obj_type.__module__ module_name = obj_type.__module__
if module_name: if module_name:
# 获取顶级模块名(例如 src.chat.xxx -> src # 获取顶级模块名(例如 src.chat.xxx -> src
top_module = module_name.split('.')[0] top_module = module_name.split(".")[0]
obj_size = sys.getsizeof(obj) obj_size = sys.getsizeof(obj)
module_mem[top_module]["count"] += 1 module_mem[top_module]["count"] += 1
module_mem[top_module]["size"] += obj_size module_mem[top_module]["size"] += obj_size
except Exception: except Exception:
# 忽略无法获取大小的对象 # 忽略无法获取大小的对象
continue continue
# 转换为列表并按大小排序 # 转换为列表并按大小排序
sorted_modules = sorted( sorted_modules = sorted(
[(mod, stats["count"], stats["size"]) [(mod, stats["count"], stats["size"])
for mod, stats in module_mem.items()], for mod, stats in module_mem.items()],
key=lambda x: x[2], key=lambda x: x[2],
reverse=True reverse=True
) )
return { return {
"top_modules": sorted_modules[:20], # 前20个模块 "top_modules": sorted_modules[:20], # 前20个模块
"total_modules": len(module_mem) "total_modules": len(module_mem)
} }
def print_stats(self, stats: Dict, iteration: int): def print_stats(self, stats: dict, iteration: int):
"""打印统计信息""" """打印统计信息"""
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}") print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}")
print("=" * 80) print("=" * 80)
if "summary" in stats: if "summary" in stats:
print(f"\n📦 对象统计 (前 {self.object_limit} 个类型):\n") print(f"\n📦 对象统计 (前 {self.object_limit} 个类型):\n")
print(f"{'类型':<50} {'数量':>12} {'总大小':>15}") print(f"{'类型':<50} {'数量':>12} {'总大小':>15}")
print("-" * 80) print("-" * 80)
for obj_type, obj_count, obj_size in stats["summary"]: for obj_type, obj_count, obj_size in stats["summary"]:
if obj_size >= 1024 * 1024 * 1024: if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB" size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -371,14 +369,14 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB" size_str = f"{obj_size / 1024:.2f} KB"
else: else:
size_str = f"{obj_size} B" size_str = f"{obj_size} B"
print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}") print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
print(f"\n📚 模块内存占用 (前 20 个模块):\n") print("\n📚 模块内存占用 (前 20 个模块):\n")
print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}") print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}")
print("-" * 80) print("-" * 80)
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]: for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
if obj_size >= 1024 * 1024 * 1024: if obj_size >= 1024 * 1024 * 1024:
size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB" size_str = f"{obj_size / 1024 / 1024 / 1024:.2f} GB"
@@ -388,46 +386,46 @@ class ObjectMemoryProfiler:
size_str = f"{obj_size / 1024:.2f} KB" size_str = f"{obj_size / 1024:.2f} KB"
else: else:
size_str = f"{obj_size} B" size_str = f"{obj_size} B"
print(f"{module_name:<40} {obj_count:>12,} {size_str:>15}") print(f"{module_name:<40} {obj_count:>12,} {size_str:>15}")
print(f"\n 总模块数: {stats['module_stats']['total_modules']}") print(f"\n 总模块数: {stats['module_stats']['total_modules']}")
if "threads" in stats: if "threads" in stats:
print(f"\n🧵 线程信息 ({len(stats['threads'])} 个):") print(f"\n🧵 线程信息 ({len(stats['threads'])} 个):")
for idx, t in enumerate(stats["threads"], 1): for idx, t in enumerate(stats["threads"], 1):
status = "" if t["alive"] else "" status = "" if t["alive"] else ""
daemon = "(守护)" if t["daemon"] else "" daemon = "(守护)" if t["daemon"] else ""
print(f" [{idx}] {status} {t['name']} {daemon}") print(f" [{idx}] {status} {t['name']} {daemon}")
if "gc_stats" in stats: if "gc_stats" in stats:
gc_stats = stats["gc_stats"] gc_stats = stats["gc_stats"]
print(f"\n🗑️ 垃圾回收:") print("\n🗑️ 垃圾回收:")
print(f" 代 0: {gc_stats['collections'][0]:,}") print(f" 代 0: {gc_stats['collections'][0]:,}")
print(f" 代 1: {gc_stats['collections'][1]:,}") print(f" 代 1: {gc_stats['collections'][1]:,}")
print(f" 代 2: {gc_stats['collections'][2]:,}") print(f" 代 2: {gc_stats['collections'][2]:,}")
print(f" 追踪对象: {gc_stats['tracked']:,}") print(f" 追踪对象: {gc_stats['tracked']:,}")
if "total_objects" in stats: if "total_objects" in stats:
print(f"\n📊 总对象数: {stats['total_objects']:,}") print(f"\n📊 总对象数: {stats['total_objects']:,}")
print("=" * 80 + "\n") print("=" * 80 + "\n")
def print_diff(self): def print_diff(self):
"""打印对象变化""" """打印对象变化"""
if not PYMPLER_AVAILABLE or not self.tracker: if not PYMPLER_AVAILABLE or not self.tracker:
return return
print("\n📈 对象变化分析:") print("\n📈 对象变化分析:")
print("-" * 80) print("-" * 80)
self.tracker.print_diff() self.tracker.print_diff()
print("-" * 80) print("-" * 80)
def save_to_file(self, stats: Dict): def save_to_file(self, stats: dict):
"""保存统计信息到文件""" """保存统计信息到文件"""
if not self.output_file: if not self.output_file:
return return
try: try:
# 保存文本 # 保存文本
with open(self.output_file, "a", encoding="utf-8") as f: with open(self.output_file, "a", encoding="utf-8") as f:
@@ -435,91 +433,91 @@ class ObjectMemoryProfiler:
f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n") f.write(f"时间: {time.strftime('%Y-%m-%d %H:%M:%S')}\n")
f.write(f"迭代: #{self.iteration}\n") f.write(f"迭代: #{self.iteration}\n")
f.write(f"{'=' * 80}\n\n") f.write(f"{'=' * 80}\n\n")
if "summary" in stats: if "summary" in stats:
f.write("对象统计:\n") f.write("对象统计:\n")
for obj_type, obj_count, obj_size in stats["summary"]: for obj_type, obj_count, obj_size in stats["summary"]:
f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n") f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
f.write("\n模块统计 (前 20 个):\n") f.write("\n模块统计 (前 20 个):\n")
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]: for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n") f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n")
f.write(f"\n总对象数: {stats.get('total_objects', 0):,}\n") f.write(f"\n总对象数: {stats.get('total_objects', 0):,}\n")
f.write(f"线程数: {len(stats.get('threads', []))}\n") f.write(f"线程数: {len(stats.get('threads', []))}\n")
# 保存 JSONL # 保存 JSONL
jsonl_path = str(self.output_file) + ".jsonl" jsonl_path = str(self.output_file) + ".jsonl"
record = { record = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"iteration": self.iteration, "iteration": self.iteration,
"total_objects": stats.get("total_objects", 0), "total_objects": stats.get("total_objects", 0),
"threads": stats.get("threads", []), "threads": stats.get("threads", []),
"gc_stats": stats.get("gc_stats", {}), "gc_stats": stats.get("gc_stats", {}),
"summary": [ "summary": [
{"type": t, "count": c, "size": s} {"type": t, "count": c, "size": s}
for (t, c, s) in stats.get("summary", []) for (t, c, s) in stats.get("summary", [])
], ],
"module_stats": stats.get("module_stats", {}), "module_stats": stats.get("module_stats", {}),
} }
with open(jsonl_path, "a", encoding="utf-8") as jf: with open(jsonl_path, "a", encoding="utf-8") as jf:
jf.write(json.dumps(record, ensure_ascii=False) + "\n") jf.write(json.dumps(record, ensure_ascii=False) + "\n")
if self.iteration == 1: if self.iteration == 1:
print(f"💾 数据保存到: {self.output_file}") print(f"💾 数据保存到: {self.output_file}")
print(f"💾 结构化数据: {jsonl_path}") print(f"💾 结构化数据: {jsonl_path}")
except Exception as e: except Exception as e:
print(f"⚠️ 保存文件失败: {e}") print(f"⚠️ 保存文件失败: {e}")
def start_monitoring(self): def start_monitoring(self):
"""启动监控线程""" """启动监控线程"""
self.running = True self.running = True
def monitor_loop(): def monitor_loop():
print(f"🚀 对象分析器已启动") print("🚀 对象分析器已启动")
print(f" 监控间隔: {self.interval}") print(f" 监控间隔: {self.interval}")
print(f" 对象类型限制: {self.object_limit}") print(f" 对象类型限制: {self.object_limit}")
print(f" 输出文件: {self.output_file or ''}") print(f" 输出文件: {self.output_file or ''}")
print() print()
while self.running: while self.running:
try: try:
self.iteration += 1 self.iteration += 1
stats = self.get_object_stats() stats = self.get_object_stats()
self.print_stats(stats, self.iteration) self.print_stats(stats, self.iteration)
if self.iteration % 3 == 0 and self.tracker: if self.iteration % 3 == 0 and self.tracker:
self.print_diff() self.print_diff()
if self.output_file: if self.output_file:
self.save_to_file(stats) self.save_to_file(stats)
time.sleep(self.interval) time.sleep(self.interval)
except Exception as e: except Exception as e:
print(f"❌ 监控出错: {e}") print(f"❌ 监控出错: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()
monitor_thread = threading.Thread(target=monitor_loop, daemon=True) monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
monitor_thread.start() monitor_thread.start()
print(f"✓ 监控线程已启动\n") print("✓ 监控线程已启动\n")
def stop(self): def stop(self):
"""停止监控""" """停止监控"""
self.running = False self.running = False
def run_objects_mode(interval: int, output: Optional[str], object_limit: int): def run_objects_mode(interval: int, output: str | None, object_limit: int):
"""对象分析模式主函数""" """对象分析模式主函数"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
print("❌ pympler 未安装,无法使用对象分析模式") print("❌ pympler 未安装,无法使用对象分析模式")
print(" 安装: pip install pympler") print(" 安装: pip install pympler")
return 1 return 1
print("=" * 80) print("=" * 80)
print("🔬 对象分析模式") print("🔬 对象分析模式")
print("=" * 80) print("=" * 80)
@@ -529,38 +527,38 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
print(" 3. 显示对象变化diff") print(" 3. 显示对象变化diff")
print(" 4. 保存 JSONL 数据用于可视化") print(" 4. 保存 JSONL 数据用于可视化")
print("=" * 80 + "\n") print("=" * 80 + "\n")
# 添加项目根目录到 Python 路径 # 添加项目根目录到 Python 路径
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
if str(project_root) not in sys.path: if str(project_root) not in sys.path:
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
print(f"✓ 已添加项目根目录到 Python 路径: {project_root}\n") print(f"✓ 已添加项目根目录到 Python 路径: {project_root}\n")
profiler = ObjectMemoryProfiler( profiler = ObjectMemoryProfiler(
interval=interval, interval=interval,
output_file=output, output_file=output,
object_limit=object_limit object_limit=object_limit
) )
profiler.start_monitoring() profiler.start_monitoring()
print("🤖 正在启动 Bot...\n") print("🤖 正在启动 Bot...\n")
try: try:
import bot import bot
if hasattr(bot, 'main_async'): if hasattr(bot, "main_async"):
asyncio.run(bot.main_async()) asyncio.run(bot.main_async())
elif hasattr(bot, 'main'): elif hasattr(bot, "main"):
bot.main() bot.main()
else: else:
print("⚠️ bot.py 未找到 main_async() 或 main() 函数") print("⚠️ bot.py 未找到 main_async() 或 main() 函数")
print(" Bot 模块已导入,监控线程在后台运行") print(" Bot 模块已导入,监控线程在后台运行")
print(" 按 Ctrl+C 停止\n") print(" 按 Ctrl+C 停止\n")
while profiler.running: while profiler.running:
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n⚠️ 用户中断") print("\n\n⚠️ 用户中断")
except Exception as e: except Exception as e:
@@ -569,7 +567,7 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
traceback.print_exc() traceback.print_exc()
finally: finally:
profiler.stop() profiler.stop()
return 0 return 0
@@ -577,10 +575,10 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
# 可视化模式 # 可视化模式
# ============================================================================ # ============================================================================
def load_jsonl(path: Path) -> List[Dict]: def load_jsonl(path: Path) -> list[dict]:
"""加载 JSONL 文件""" """加载 JSONL 文件"""
snapshots = [] snapshots = []
with open(path, "r", encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: if not line:
@@ -592,7 +590,7 @@ def load_jsonl(path: Path) -> List[Dict]:
return snapshots return snapshots
def aggregate_top_types(snapshots: List[Dict], top_n: int = 10): def aggregate_top_types(snapshots: list[dict], top_n: int = 10):
"""聚合前 N 个对象类型的时间序列""" """聚合前 N 个对象类型的时间序列"""
type_max = defaultdict(int) type_max = defaultdict(int)
for snap in snapshots: for snap in snapshots:
@@ -600,37 +598,37 @@ def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
t = item.get("type") t = item.get("type")
s = int(item.get("size", 0)) s = int(item.get("size", 0))
type_max[t] = max(type_max[t], s) type_max[t] = max(type_max[t], s)
top_types = sorted(type_max.items(), key=lambda kv: kv[1], reverse=True)[:top_n] top_types = sorted(type_max.items(), key=lambda kv: kv[1], reverse=True)[:top_n]
top_names = [t for t, _ in top_types] top_names = [t for t, _ in top_types]
times = [] times = []
series = {t: [] for t in top_names} series = {t: [] for t in top_names}
for snap in snapshots: for snap in snapshots:
ts = snap.get("timestamp") ts = snap.get("timestamp")
try: try:
times.append(datetime.strptime(ts, "%Y-%m-%d %H:%M:%S")) times.append(datetime.strptime(ts, "%Y-%m-%d %H:%M:%S"))
except Exception: except Exception:
times.append(None) times.append(None)
summary = {item.get("type"): int(item.get("size", 0)) summary = {item.get("type"): int(item.get("size", 0))
for item in snap.get("summary", [])} for item in snap.get("summary", [])}
for t in top_names: for t in top_names:
series[t].append(summary.get(t, 0) / 1024.0 / 1024.0) series[t].append(summary.get(t, 0) / 1024.0 / 1024.0)
return times, series return times, series
def plot_series(times: List, series: Dict, output: Path, top_n: int): def plot_series(times: list, series: dict, output: Path, top_n: int):
"""绘制时间序列图""" """绘制时间序列图"""
plt.figure(figsize=(14, 8)) plt.figure(figsize=(14, 8))
for name, values in series.items(): for name, values in series.items():
if all(v == 0 for v in values): if all(v == 0 for v in values):
continue continue
plt.plot(times, values, marker="o", label=name, linewidth=2) plt.plot(times, values, marker="o", label=name, linewidth=2)
plt.xlabel("时间", fontsize=12) plt.xlabel("时间", fontsize=12)
plt.ylabel("内存 (MB)", fontsize=12) plt.ylabel("内存 (MB)", fontsize=12)
plt.title(f"对象类型随时间的内存占用 (前 {top_n} 类型)", fontsize=14) plt.title(f"对象类型随时间的内存占用 (前 {top_n} 类型)", fontsize=14)
@@ -647,31 +645,31 @@ def run_visualize_mode(input_file: str, output_file: str, top: int):
print("❌ matplotlib 未安装,无法使用可视化模式") print("❌ matplotlib 未安装,无法使用可视化模式")
print(" 安装: pip install matplotlib") print(" 安装: pip install matplotlib")
return 1 return 1
print("=" * 80) print("=" * 80)
print("📊 可视化模式") print("📊 可视化模式")
print("=" * 80) print("=" * 80)
path = Path(input_file) path = Path(input_file)
if not path.exists(): if not path.exists():
print(f"❌ 找不到输入文件: {path}") print(f"❌ 找不到输入文件: {path}")
return 1 return 1
print(f"📂 读取数据: {path}") print(f"📂 读取数据: {path}")
snaps = load_jsonl(path) snaps = load_jsonl(path)
if not snaps: if not snaps:
print("❌ 未读取到任何快照数据") print("❌ 未读取到任何快照数据")
return 1 return 1
print(f"✓ 读取 {len(snaps)} 个快照") print(f"✓ 读取 {len(snaps)} 个快照")
times, series = aggregate_top_types(snaps, top_n=top) times, series = aggregate_top_types(snaps, top_n=top)
print(f"✓ 提取前 {top} 个对象类型") print(f"✓ 提取前 {top} 个对象类型")
output_path = Path(output_file) output_path = Path(output_file)
plot_series(times, series, output_path, top) plot_series(times, series, output_path, top)
return 0 return 0
@@ -693,10 +691,10 @@ def main():
使用示例: 使用示例:
# 进程监控(启动 bot 并监控) # 进程监控(启动 bot 并监控)
python scripts/memory_profiler.py --monitor --interval 10 python scripts/memory_profiler.py --monitor --interval 10
# 对象分析(深度对象统计) # 对象分析(深度对象统计)
python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt python scripts/memory_profiler.py --objects --interval 10 --output memory_data.txt
# 生成可视化图表 # 生成可视化图表
python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 --output plot.png python scripts/memory_profiler.py --visualize --input memory_data.txt.jsonl --top 15 --output plot.png
@@ -705,26 +703,26 @@ def main():
- 可视化模式需要: pip install matplotlib - 可视化模式需要: pip install matplotlib
""", """,
) )
# 模式选择 # 模式选择
mode_group = parser.add_mutually_exclusive_group(required=True) mode_group = parser.add_mutually_exclusive_group(required=True)
mode_group.add_argument("--monitor", "-m", action="store_true", mode_group.add_argument("--monitor", "-m", action="store_true",
help="进程监控模式(外部监控 bot 进程)") help="进程监控模式(外部监控 bot 进程)")
mode_group.add_argument("--objects", "-o", action="store_true", mode_group.add_argument("--objects", "-o", action="store_true",
help="对象分析模式(内部统计所有对象)") help="对象分析模式(内部统计所有对象)")
mode_group.add_argument("--visualize", "-v", action="store_true", mode_group.add_argument("--visualize", "-v", action="store_true",
help="可视化模式(绘制 JSONL 数据)") help="可视化模式(绘制 JSONL 数据)")
# 通用参数 # 通用参数
parser.add_argument("--interval", "-i", type=int, default=10, parser.add_argument("--interval", "-i", type=int, default=10,
help="监控间隔(秒),默认 10") help="监控间隔(秒),默认 10")
# 对象分析参数 # 对象分析参数
parser.add_argument("--output", type=str, parser.add_argument("--output", type=str,
help="输出文件路径(对象分析模式)") help="输出文件路径(对象分析模式)")
parser.add_argument("--object-limit", "-l", type=int, default=20, parser.add_argument("--object-limit", "-l", type=int, default=20,
help="对象类型显示数量,默认 20") help="对象类型显示数量,默认 20")
# 可视化参数 # 可视化参数
parser.add_argument("--input", type=str, parser.add_argument("--input", type=str,
help="输入 JSONL 文件(可视化模式)") help="输入 JSONL 文件(可视化模式)")
@@ -732,24 +730,24 @@ def main():
help="展示前 N 个类型(可视化模式),默认 10") help="展示前 N 个类型(可视化模式),默认 10")
parser.add_argument("--plot-output", type=str, default="memory_analysis_plot.png", parser.add_argument("--plot-output", type=str, default="memory_analysis_plot.png",
help="图表输出文件,默认 memory_analysis_plot.png") help="图表输出文件,默认 memory_analysis_plot.png")
args = parser.parse_args() args = parser.parse_args()
# 根据模式执行 # 根据模式执行
if args.monitor: if args.monitor:
return asyncio.run(run_monitor_mode(args.interval)) return asyncio.run(run_monitor_mode(args.interval))
elif args.objects: elif args.objects:
if not args.output: if not args.output:
print("⚠️ 建议使用 --output 指定输出文件以保存数据") print("⚠️ 建议使用 --output 指定输出文件以保存数据")
return run_objects_mode(args.interval, args.output, args.object_limit) return run_objects_mode(args.interval, args.output, args.object_limit)
elif args.visualize: elif args.visualize:
if not args.input: if not args.input:
print("❌ 可视化模式需要 --input 参数指定 JSONL 文件") print("❌ 可视化模式需要 --input 参数指定 JSONL 文件")
return 1 return 1
return run_visualize_mode(args.input, args.plot_output, args.top) return run_visualize_mode(args.input, args.plot_output, args.top)
return 0 return 0

View File

@@ -680,9 +680,9 @@ class EmojiManager:
try: try:
# 🔧 使用 QueryBuilder 以启用数据库缓存 # 🔧 使用 QueryBuilder 以启用数据库缓存
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
logger.debug("[数据库] 开始加载所有表情包记录 ...") logger.debug("[数据库] 开始加载所有表情包记录 ...")
emoji_instances = await QueryBuilder(Emoji).all() emoji_instances = await QueryBuilder(Emoji).all()
emoji_objects, load_errors = _to_emoji_objects(emoji_instances) emoji_objects, load_errors = _to_emoji_objects(emoji_instances)
@@ -802,7 +802,7 @@ class EmojiManager:
# 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存) # 如果内存中没有,从数据库查找(使用 QueryBuilder 启用数据库缓存)
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first() emoji_record = await QueryBuilder(Emoji).filter(emoji_hash=emoji_hash).first()
if emoji_record and emoji_record.description: if emoji_record and emoji_record.description:
logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...") logger.info(f"[缓存命中] 从数据库获取表情包描述: {emoji_record.description[:50]}...")
@@ -966,7 +966,7 @@ class EmojiManager:
existing_description = None existing_description = None
try: try:
from src.common.database.api.query import QueryBuilder from src.common.database.api.query import QueryBuilder
existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first() existing_image = await QueryBuilder(Images).filter(emoji_hash=image_hash, type="emoji").first()
if existing_image and existing_image.description: if existing_image and existing_image.description:
existing_description = existing_image.description existing_description = existing_image.description

View File

@@ -1,5 +1,4 @@
import os import os
import random
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
@@ -135,20 +134,20 @@ class ExpressionLearner:
async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int: async def cleanup_expired_expressions(self, expiration_days: int | None = None) -> int:
""" """
清理过期的表达方式 清理过期的表达方式
Args: Args:
expiration_days: 过期天数,超过此天数未激活的表达方式将被删除(不指定则从配置读取) expiration_days: 过期天数,超过此天数未激活的表达方式将被删除(不指定则从配置读取)
Returns: Returns:
int: 删除的表达方式数量 int: 删除的表达方式数量
""" """
# 从配置读取过期天数 # 从配置读取过期天数
if expiration_days is None: if expiration_days is None:
expiration_days = global_config.expression.expiration_days expiration_days = global_config.expression.expiration_days
current_time = time.time() current_time = time.time()
expiration_threshold = current_time - (expiration_days * 24 * 3600) expiration_threshold = current_time - (expiration_days * 24 * 3600)
try: try:
deleted_count = 0 deleted_count = 0
async with get_db_session() as session: async with get_db_session() as session:
@@ -160,15 +159,15 @@ class ExpressionLearner:
) )
) )
expired_expressions = list(query.scalars()) expired_expressions = list(query.scalars())
if expired_expressions: if expired_expressions:
for expr in expired_expressions: for expr in expired_expressions:
await session.delete(expr) await session.delete(expr)
deleted_count += 1 deleted_count += 1
await session.commit() await session.commit()
logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)") logger.info(f"清理了 {deleted_count} 个过期表达方式(超过 {expiration_days} 天未使用)")
# 清除缓存 # 清除缓存
from src.common.database.optimization.cache_manager import get_cache from src.common.database.optimization.cache_manager import get_cache
from src.common.database.utils.decorators import generate_cache_key from src.common.database.utils.decorators import generate_cache_key
@@ -176,7 +175,7 @@ class ExpressionLearner:
await cache.delete(generate_cache_key("chat_expressions", self.chat_id)) await cache.delete(generate_cache_key("chat_expressions", self.chat_id))
else: else:
logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)") logger.debug(f"没有发现过期的表达方式(阈值:{expiration_days} 天)")
return deleted_count return deleted_count
except Exception as e: except Exception as e:
logger.error(f"清理过期表达方式失败: {e}") logger.error(f"清理过期表达方式失败: {e}")
@@ -460,7 +459,7 @@ class ExpressionLearner:
) )
) )
same_situation_expr = query_same_situation.scalar() same_situation_expr = query_same_situation.scalar()
# 情况2相同 chat_id + type + style相同表达不同情景 # 情况2相同 chat_id + type + style相同表达不同情景
query_same_style = await session.execute( query_same_style = await session.execute(
select(Expression).where( select(Expression).where(
@@ -470,7 +469,7 @@ class ExpressionLearner:
) )
) )
same_style_expr = query_same_style.scalar() same_style_expr = query_same_style.scalar()
# 情况3完全相同相同情景+相同表达) # 情况3完全相同相同情景+相同表达)
query_exact_match = await session.execute( query_exact_match = await session.execute(
select(Expression).where( select(Expression).where(
@@ -481,7 +480,7 @@ class ExpressionLearner:
) )
) )
exact_match_expr = query_exact_match.scalar() exact_match_expr = query_exact_match.scalar()
# 优先处理完全匹配的情况 # 优先处理完全匹配的情况
if exact_match_expr: if exact_match_expr:
# 完全相同增加count更新时间 # 完全相同增加count更新时间

View File

@@ -72,21 +72,21 @@ class ExpressorModel:
是否删除成功 是否删除成功
""" """
removed = False removed = False
if cid in self._candidates: if cid in self._candidates:
del self._candidates[cid] del self._candidates[cid]
removed = True removed = True
if cid in self._situations: if cid in self._situations:
del self._situations[cid] del self._situations[cid]
# 从nb模型中删除 # 从nb模型中删除
if cid in self.nb.cls_counts: if cid in self.nb.cls_counts:
del self.nb.cls_counts[cid] del self.nb.cls_counts[cid]
if cid in self.nb.token_counts: if cid in self.nb.token_counts:
del self.nb.token_counts[cid] del self.nb.token_counts[cid]
return removed return removed
def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]: def predict(self, text: str, k: int | None = None) -> tuple[str | None, dict[str, float]]:

View File

@@ -72,7 +72,7 @@ class StyleLearner:
# 检查是否需要清理 # 检查是否需要清理
current_count = len(self.style_to_id) current_count = len(self.style_to_id)
cleanup_trigger = int(self.max_styles * self.cleanup_threshold) cleanup_trigger = int(self.max_styles * self.cleanup_threshold)
if current_count >= cleanup_trigger: if current_count >= cleanup_trigger:
if current_count >= self.max_styles: if current_count >= self.max_styles:
# 已经达到最大限制,必须清理 # 已经达到最大限制,必须清理
@@ -109,7 +109,7 @@ class StyleLearner:
def _cleanup_styles(self): def _cleanup_styles(self):
""" """
清理低价值的风格,为新风格腾出空间 清理低价值的风格,为新风格腾出空间
清理策略: 清理策略:
1. 综合考虑使用次数和最后使用时间 1. 综合考虑使用次数和最后使用时间
2. 删除得分最低的风格 2. 删除得分最低的风格
@@ -118,34 +118,34 @@ class StyleLearner:
try: try:
current_time = time.time() current_time = time.time()
cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio)) cleanup_count = max(1, int(len(self.style_to_id) * self.cleanup_ratio))
# 计算每个风格的价值分数 # 计算每个风格的价值分数
style_scores = [] style_scores = []
for style_id in self.style_to_id.values(): for style_id in self.style_to_id.values():
# 使用次数 # 使用次数
usage_count = self.learning_stats["style_counts"].get(style_id, 0) usage_count = self.learning_stats["style_counts"].get(style_id, 0)
# 最后使用时间(越近越好) # 最后使用时间(越近越好)
last_used = self.learning_stats["style_last_used"].get(style_id, 0) last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf') time_since_used = current_time - last_used if last_used > 0 else float("inf")
# 综合分数:使用次数越多越好,距离上次使用时间越短越好 # 综合分数:使用次数越多越好,距离上次使用时间越短越好
# 使用对数来平滑使用次数的影响 # 使用对数来平滑使用次数的影响
import math import math
usage_score = math.log1p(usage_count) # log(1 + count) usage_score = math.log1p(usage_count) # log(1 + count)
# 时间分数:转换为天数,使用指数衰减 # 时间分数:转换为天数,使用指数衰减
days_unused = time_since_used / 86400 # 转换为天 days_unused = time_since_used / 86400 # 转换为天
time_score = math.exp(-days_unused / 30) # 30天衰减因子 time_score = math.exp(-days_unused / 30) # 30天衰减因子
# 综合分数80%使用频率 + 20%时间新鲜度 # 综合分数80%使用频率 + 20%时间新鲜度
total_score = 0.8 * usage_score + 0.2 * time_score total_score = 0.8 * usage_score + 0.2 * time_score
style_scores.append((style_id, total_score, usage_count, days_unused)) style_scores.append((style_id, total_score, usage_count, days_unused))
# 按分数排序,分数低的先删除 # 按分数排序,分数低的先删除
style_scores.sort(key=lambda x: x[1]) style_scores.sort(key=lambda x: x[1])
# 删除分数最低的风格 # 删除分数最低的风格
deleted_styles = [] deleted_styles = []
for style_id, score, usage, days in style_scores[:cleanup_count]: for style_id, score, usage, days in style_scores[:cleanup_count]:
@@ -156,27 +156,27 @@ class StyleLearner:
del self.id_to_style[style_id] del self.id_to_style[style_id]
if style_id in self.id_to_situation: if style_id in self.id_to_situation:
del self.id_to_situation[style_id] del self.id_to_situation[style_id]
# 从统计中删除 # 从统计中删除
if style_id in self.learning_stats["style_counts"]: if style_id in self.learning_stats["style_counts"]:
del self.learning_stats["style_counts"][style_id] del self.learning_stats["style_counts"][style_id]
if style_id in self.learning_stats["style_last_used"]: if style_id in self.learning_stats["style_last_used"]:
del self.learning_stats["style_last_used"][style_id] del self.learning_stats["style_last_used"][style_id]
# 从expressor模型中删除 # 从expressor模型中删除
self.expressor.remove_candidate(style_id) self.expressor.remove_candidate(style_id)
deleted_styles.append((style_text[:30], usage, f"{days:.1f}")) deleted_styles.append((style_text[:30], usage, f"{days:.1f}"))
logger.info( logger.info(
f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格," f"风格清理完成: 删除了 {len(deleted_styles)}/{len(style_scores)} 个风格,"
f"剩余 {len(self.style_to_id)} 个风格" f"剩余 {len(self.style_to_id)} 个风格"
) )
# 记录前5个被删除的风格用于调试 # 记录前5个被删除的风格用于调试
if deleted_styles: if deleted_styles:
logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}") logger.debug(f"被删除的风格样例(前5): {deleted_styles[:5]}")
except Exception as e: except Exception as e:
logger.error(f"清理风格失败: {e}", exc_info=True) logger.error(f"清理风格失败: {e}", exc_info=True)
@@ -303,10 +303,10 @@ class StyleLearner:
def cleanup_old_styles(self, ratio: float | None = None) -> int: def cleanup_old_styles(self, ratio: float | None = None) -> int:
""" """
手动清理旧风格 手动清理旧风格
Args: Args:
ratio: 清理比例如果为None则使用默认的cleanup_ratio ratio: 清理比例如果为None则使用默认的cleanup_ratio
Returns: Returns:
清理的风格数量 清理的风格数量
""" """
@@ -318,7 +318,7 @@ class StyleLearner:
self.cleanup_ratio = old_cleanup_ratio self.cleanup_ratio = old_cleanup_ratio
else: else:
self._cleanup_styles() self._cleanup_styles()
new_count = len(self.style_to_id) new_count = len(self.style_to_id)
cleaned = old_count - new_count cleaned = old_count - new_count
logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格") logger.info(f"手动清理完成: chat_id={self.chat_id}, 清理了 {cleaned} 个风格")
@@ -357,11 +357,11 @@ class StyleLearner:
import pickle import pickle
meta_path = os.path.join(save_dir, "meta.pkl") meta_path = os.path.join(save_dir, "meta.pkl")
# 确保 learning_stats 包含所有必要字段 # 确保 learning_stats 包含所有必要字段
if "style_last_used" not in self.learning_stats: if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {} self.learning_stats["style_last_used"] = {}
meta_data = { meta_data = {
"style_to_id": self.style_to_id, "style_to_id": self.style_to_id,
"id_to_style": self.id_to_style, "id_to_style": self.id_to_style,
@@ -416,7 +416,7 @@ class StyleLearner:
self.id_to_situation = meta_data["id_to_situation"] self.id_to_situation = meta_data["id_to_situation"]
self.next_style_id = meta_data["next_style_id"] self.next_style_id = meta_data["next_style_id"]
self.learning_stats = meta_data["learning_stats"] self.learning_stats = meta_data["learning_stats"]
# 确保旧数据兼容:如果没有 style_last_used 字段,添加它 # 确保旧数据兼容:如果没有 style_last_used 字段,添加它
if "style_last_used" not in self.learning_stats: if "style_last_used" not in self.learning_stats:
self.learning_stats["style_last_used"] = {} self.learning_stats["style_last_used"] = {}
@@ -526,10 +526,10 @@ class StyleLearnerManager:
def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]: def cleanup_all_old_styles(self, ratio: float | None = None) -> dict[str, int]:
""" """
对所有学习器清理旧风格 对所有学习器清理旧风格
Args: Args:
ratio: 清理比例 ratio: 清理比例
Returns: Returns:
{chat_id: 清理数量} {chat_id: 清理数量}
""" """
@@ -538,7 +538,7 @@ class StyleLearnerManager:
cleaned = learner.cleanup_old_styles(ratio) cleaned = learner.cleanup_old_styles(ratio)
if cleaned > 0: if cleaned > 0:
cleanup_results[chat_id] = cleaned cleanup_results[chat_id] = cleaned
total_cleaned = sum(cleanup_results.values()) total_cleaned = sum(cleanup_results.values())
logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格") logger.info(f"清理所有StyleLearner完成: 总共清理了 {total_cleaned} 个风格")
return cleanup_results return cleanup_results

View File

@@ -8,7 +8,6 @@ from datetime import datetime
from typing import Any from typing import Any
import numpy as np import numpy as np
import orjson
from sqlalchemy import select from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension from src.common.config_helpers import resolve_embedding_dimension
@@ -124,7 +123,7 @@ class BotInterestManager:
tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()] tags_info = [f" - '{tag.tag_name}' (权重: {tag.weight:.2f})" for tag in loaded_interests.get_active_tags()]
tags_str = "\n".join(tags_info) tags_str = "\n".join(tags_info)
logger.info(f"当前兴趣标签:\n{tags_str}") logger.info(f"当前兴趣标签:\n{tags_str}")
# 为加载的标签生成embedding数据库不存储embedding启动时动态生成 # 为加载的标签生成embedding数据库不存储embedding启动时动态生成
logger.info("🧠 为加载的标签生成embedding向量...") logger.info("🧠 为加载的标签生成embedding向量...")
await self._generate_embeddings_for_tags(loaded_interests) await self._generate_embeddings_for_tags(loaded_interests)
@@ -326,13 +325,13 @@ class BotInterestManager:
raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding") raise RuntimeError("❌ Embedding客户端未初始化无法生成embedding")
total_tags = len(interests.interest_tags) total_tags = len(interests.interest_tags)
# 尝试从文件加载缓存 # 尝试从文件加载缓存
file_cache = await self._load_embedding_cache_from_file(interests.personality_id) file_cache = await self._load_embedding_cache_from_file(interests.personality_id)
if file_cache: if file_cache:
logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存") logger.info(f"📂 从文件加载 {len(file_cache)} 个embedding缓存")
self.embedding_cache.update(file_cache) self.embedding_cache.update(file_cache)
logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...") logger.info(f"🧠 开始为 {total_tags} 个兴趣标签生成embedding向量...")
memory_cached_count = 0 memory_cached_count = 0
@@ -477,14 +476,14 @@ class BotInterestManager:
self, message_text: str, keywords: list[str] | None = None self, message_text: str, keywords: list[str] | None = None
) -> InterestMatchResult: ) -> InterestMatchResult:
"""计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略) """计算消息与机器人兴趣的匹配度(优化版 - 标签扩展策略)
核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题 核心优化:将短标签扩展为完整的描述性句子,解决语义粒度不匹配问题
原问题: 原问题:
- 消息: "今天天气不错" (完整句子) - 消息: "今天天气不错" (完整句子)
- 标签: "蹭人治愈" (2-4字短语) - 标签: "蹭人治愈" (2-4字短语)
- 结果: 误匹配,因为短标签的 embedding 过于抽象 - 结果: 误匹配,因为短标签的 embedding 过于抽象
解决方案: 解决方案:
- 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容" - 标签扩展: "蹭人治愈" -> "表达亲近、寻求安慰、撒娇的内容"
- 现在是: 句子 vs 句子,匹配更准确 - 现在是: 句子 vs 句子,匹配更准确
@@ -527,18 +526,18 @@ class BotInterestManager:
if tag.embedding: if tag.embedding:
# 🔧 优化:获取扩展标签的 embedding带缓存 # 🔧 优化:获取扩展标签的 embedding带缓存
expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name) expanded_embedding = await self._get_expanded_tag_embedding(tag.tag_name)
if expanded_embedding: if expanded_embedding:
# 使用扩展标签的 embedding 进行匹配 # 使用扩展标签的 embedding 进行匹配
similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding) similarity = self._calculate_cosine_similarity(message_embedding, expanded_embedding)
# 同时计算原始标签的相似度作为参考 # 同时计算原始标签的相似度作为参考
original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding) original_similarity = self._calculate_cosine_similarity(message_embedding, tag.embedding)
# 混合策略扩展标签权重更高70%原始标签作为补充30% # 混合策略扩展标签权重更高70%原始标签作为补充30%
# 这样可以兼顾准确性(扩展)和灵活性(原始) # 这样可以兼顾准确性(扩展)和灵活性(原始)
final_similarity = similarity * 0.7 + original_similarity * 0.3 final_similarity = similarity * 0.7 + original_similarity * 0.3
logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}") logger.debug(f"标签'{tag.tag_name}': 原始={original_similarity:.3f}, 扩展={similarity:.3f}, 最终={final_similarity:.3f}")
else: else:
# 如果扩展 embedding 获取失败,使用原始 embedding # 如果扩展 embedding 获取失败,使用原始 embedding
@@ -603,27 +602,27 @@ class BotInterestManager:
logger.debug( logger.debug(
f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}" f"最终结果: 总分={result.overall_score:.3f}, 置信度={result.confidence:.3f}, 匹配标签数={len(result.matched_tags)}"
) )
# 如果有新生成的扩展embedding保存到缓存文件 # 如果有新生成的扩展embedding保存到缓存文件
if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated: if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
await self._save_embedding_cache_to_file(self.current_interests.personality_id) await self._save_embedding_cache_to_file(self.current_interests.personality_id)
self._new_expanded_embeddings_generated = False self._new_expanded_embeddings_generated = False
logger.debug("💾 已保存新生成的扩展embedding到缓存文件") logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
return result return result
async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None: async def _get_expanded_tag_embedding(self, tag_name: str) -> list[float] | None:
"""获取扩展标签的 embedding带缓存 """获取扩展标签的 embedding带缓存
优先使用缓存,如果没有则生成并缓存 优先使用缓存,如果没有则生成并缓存
""" """
# 检查缓存 # 检查缓存
if tag_name in self.expanded_embedding_cache: if tag_name in self.expanded_embedding_cache:
return self.expanded_embedding_cache[tag_name] return self.expanded_embedding_cache[tag_name]
# 扩展标签 # 扩展标签
expanded_tag = self._expand_tag_for_matching(tag_name) expanded_tag = self._expand_tag_for_matching(tag_name)
# 生成 embedding # 生成 embedding
try: try:
embedding = await self._get_embedding(expanded_tag) embedding = await self._get_embedding(expanded_tag)
@@ -636,19 +635,19 @@ class BotInterestManager:
return embedding return embedding
except Exception as e: except Exception as e:
logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}") logger.warning(f"为标签'{tag_name}'生成扩展embedding失败: {e}")
return None return None
def _expand_tag_for_matching(self, tag_name: str) -> str: def _expand_tag_for_matching(self, tag_name: str) -> str:
"""将短标签扩展为完整的描述性句子 """将短标签扩展为完整的描述性句子
这是解决"标签太短导致误匹配"的核心方法 这是解决"标签太短导致误匹配"的核心方法
策略: 策略:
1. 优先使用 LLM 生成的 expanded 字段(最准确) 1. 优先使用 LLM 生成的 expanded 字段(最准确)
2. 如果没有,使用基于规则的回退方案 2. 如果没有,使用基于规则的回退方案
3. 最后使用通用模板 3. 最后使用通用模板
示例: 示例:
- "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" - "Python" + expanded -> "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
- "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" - "蹭人治愈" + expanded -> "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
@@ -656,7 +655,7 @@ class BotInterestManager:
# 使用缓存 # 使用缓存
if tag_name in self.expanded_tag_cache: if tag_name in self.expanded_tag_cache:
return self.expanded_tag_cache[tag_name] return self.expanded_tag_cache[tag_name]
# 🎯 优先策略:使用 LLM 生成的 expanded 字段 # 🎯 优先策略:使用 LLM 生成的 expanded 字段
if self.current_interests: if self.current_interests:
for tag in self.current_interests.interest_tags: for tag in self.current_interests.interest_tags:
@@ -664,66 +663,66 @@ class BotInterestManager:
logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...") logger.debug(f"✅ 使用LLM生成的扩展描述: {tag_name} -> {tag.expanded[:50]}...")
self.expanded_tag_cache[tag_name] = tag.expanded self.expanded_tag_cache[tag_name] = tag.expanded
return tag.expanded return tag.expanded
# 🔧 回退策略基于规则的扩展用于兼容旧数据或LLM未生成扩展的情况 # 🔧 回退策略基于规则的扩展用于兼容旧数据或LLM未生成扩展的情况
logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案") logger.debug(f"⚠️ 标签'{tag_name}'没有LLM扩展描述使用规则回退方案")
tag_lower = tag_name.lower() tag_lower = tag_name.lower()
# 技术编程类标签(具体化描述) # 技术编程类标签(具体化描述)
if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']): if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
if 'python' in tag_lower: if "python" in tag_lower:
return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
elif '算法' in tag_lower: elif "算法" in tag_lower:
return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化" return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
elif '代码' in tag_lower or '被窝' in tag_lower: elif "代码" in tag_lower or "被窝" in tag_lower:
return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧" return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
else: else:
return f"讨论编程开发、软件技术、代码编写、技术实现" return "讨论编程开发、软件技术、代码编写、技术实现"
# 情感表达类标签(具体化为真实对话场景) # 情感表达类标签(具体化为真实对话场景)
elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '', '卖萌']): elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "", "卖萌"]):
return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
# 游戏娱乐类标签(具体游戏场景) # 游戏娱乐类标签(具体游戏场景)
elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '', '']): elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "", ""]):
return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得" return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
# 动漫影视类标签(具体观看行为) # 动漫影视类标签(具体观看行为)
elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']): elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
# 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西" # 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西"
if '' in tag_lower or '新番' in tag_lower: if "" in tag_lower or "新番" in tag_lower:
return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色" return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
else: else:
return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验" return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
# 社交平台类标签(具体平台行为) # 社交平台类标签(具体平台行为)
elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']): elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
if '吃瓜' in tag_lower: if "吃瓜" in tag_lower:
return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题" return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
else: else:
return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活" return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
# 生活日常类标签(具体萌宠场景) # 生活日常类标签(具体萌宠场景)
elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']): elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得" return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
# 状态心情类标签(具体情绪状态) # 状态心情类标签(具体情绪状态)
elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']): elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
if '社恐' in tag_lower: if "社恐" in tag_lower:
return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情" return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
elif '深夜' in tag_lower: elif "深夜" in tag_lower:
return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话" return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
else: else:
return f"表达当前心情状态、个人感受、生活状态" return "表达当前心情状态、个人感受、生活状态"
# 物品装备类标签(具体使用场景) # 物品装备类标签(具体使用场景)
elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']): elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测" return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
# 互动关系类标签 # 互动关系类标签
elif any(word in tag_lower for word in ['拾风', '互怼', '互动']): elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
return f"聊天互动、开玩笑、友好互怼、日常对话交流" return "聊天互动、开玩笑、友好互怼、日常对话交流"
# 默认:尽量具体化 # 默认:尽量具体化
else: else:
return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题" return f"明确讨论{tag_name}这个特定主题的具体内容和相关话题"
@@ -1011,56 +1010,58 @@ class BotInterestManager:
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None: async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
"""从文件加载embedding缓存""" """从文件加载embedding缓存"""
try: try:
import orjson
from pathlib import Path from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json" cache_file = cache_dir / f"{personality_id}_embeddings.json"
if not cache_file.exists(): if not cache_file.exists():
logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}") logger.debug(f"📂 Embedding缓存文件不存在: {cache_file}")
return None return None
# 读取缓存文件 # 读取缓存文件
with open(cache_file, "rb") as f: with open(cache_file, "rb") as f:
cache_data = orjson.loads(f.read()) cache_data = orjson.loads(f.read())
# 验证缓存版本和embedding模型 # 验证缓存版本和embedding模型
cache_version = cache_data.get("version", 1) cache_version = cache_data.get("version", 1)
cache_embedding_model = cache_data.get("embedding_model", "") cache_embedding_model = cache_data.get("embedding_model", "")
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else "" current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") else ""
if cache_embedding_model != current_embedding_model: if cache_embedding_model != current_embedding_model:
logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存") logger.warning(f"⚠️ Embedding模型已变更 ({cache_embedding_model}{current_embedding_model}),忽略旧缓存")
return None return None
embeddings = cache_data.get("embeddings", {}) embeddings = cache_data.get("embeddings", {})
# 同时加载扩展标签的embedding缓存 # 同时加载扩展标签的embedding缓存
expanded_embeddings = cache_data.get("expanded_embeddings", {}) expanded_embeddings = cache_data.get("expanded_embeddings", {})
if expanded_embeddings: if expanded_embeddings:
self.expanded_embedding_cache.update(expanded_embeddings) self.expanded_embedding_cache.update(expanded_embeddings)
logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存") logger.info(f"📂 加载 {len(expanded_embeddings)} 个扩展标签embedding缓存")
logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})") logger.info(f"✅ 成功从文件加载 {len(embeddings)} 个标签embedding缓存 (版本: {cache_version}, 模型: {cache_embedding_model})")
return embeddings return embeddings
except Exception as e: except Exception as e:
logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}") logger.warning(f"⚠️ 加载embedding缓存文件失败: {e}")
return None return None
async def _save_embedding_cache_to_file(self, personality_id: str): async def _save_embedding_cache_to_file(self, personality_id: str):
"""保存embedding缓存到文件包括扩展标签的embedding""" """保存embedding缓存到文件包括扩展标签的embedding"""
try: try:
import orjson
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json" cache_file = cache_dir / f"{personality_id}_embeddings.json"
# 准备缓存数据 # 准备缓存数据
current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else "" current_embedding_model = self.embedding_config.model_list[0] if hasattr(self.embedding_config, "model_list") and self.embedding_config.model_list else ""
cache_data = { cache_data = {
@@ -1071,13 +1072,13 @@ class BotInterestManager:
"embeddings": self.embedding_cache, "embeddings": self.embedding_cache,
"expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding "expanded_embeddings": self.expanded_embedding_cache, # 同时保存扩展标签的embedding
} }
# 写入文件 # 写入文件
with open(cache_file, "wb") as f: with open(cache_file, "wb") as f:
f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2)) f.write(orjson.dumps(cache_data, option=orjson.OPT_INDENT_2))
logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}") logger.debug(f"💾 已保存 {len(self.embedding_cache)} 个标签embedding和 {len(self.expanded_embedding_cache)} 个扩展embedding到缓存文件: {cache_file}")
except Exception as e: except Exception as e:
logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}") logger.warning(f"⚠️ 保存embedding缓存文件失败: {e}")

View File

@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
__all__ = [ __all__ = [
"MessageManager", "MessageManager",
"SingleStreamContextManager",
"SchedulerDispatcher", "SchedulerDispatcher",
"SingleStreamContextManager",
"message_manager", "message_manager",
"scheduler_dispatcher", "scheduler_dispatcher",
] ]

View File

@@ -73,7 +73,7 @@ class SingleStreamContextManager:
cache_enabled = global_config.chat.enable_message_cache cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled: if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用") logger.debug("消息缓存系统已在配置中禁用")
except Exception as e: except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}") logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False use_cache_system = False
@@ -129,13 +129,13 @@ class SingleStreamContextManager:
await self._calculate_message_interest(message) await self._calculate_message_interest(message)
self.total_messages += 1 self.total_messages += 1
self.last_access_time = time.time() self.last_access_time = time.time()
logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}") logger.debug(f"添加消息{message.processed_plain_text}到单流上下文: {self.stream_id}")
return True return True
# 不应该到达这里,但为了类型检查添加返回值 # 不应该到达这里,但为了类型检查添加返回值
return True return True
except Exception as e: except Exception as e:
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True) logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
return False return False

View File

@@ -4,13 +4,11 @@
""" """
import asyncio import asyncio
import random
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
@@ -77,7 +75,7 @@ class MessageManager:
# 启动基于 scheduler 的消息分发器 # 启动基于 scheduler 的消息分发器
await scheduler_dispatcher.start() await scheduler_dispatcher.start()
scheduler_dispatcher.set_chatter_manager(self.chatter_manager) scheduler_dispatcher.set_chatter_manager(self.chatter_manager)
# 保留旧的流循环管理器(暂时)以便平滑过渡 # 保留旧的流循环管理器(暂时)以便平滑过渡
# TODO: 在确认新机制稳定后移除 # TODO: 在确认新机制稳定后移除
# await stream_loop_manager.start() # await stream_loop_manager.start()
@@ -108,7 +106,7 @@ class MessageManager:
# 停止基于 scheduler 的消息分发器 # 停止基于 scheduler 的消息分发器
await scheduler_dispatcher.stop() await scheduler_dispatcher.stop()
# 停止旧的流循环管理器(如果启用) # 停止旧的流循环管理器(如果启用)
# await stream_loop_manager.stop() # await stream_loop_manager.stop()
@@ -116,7 +114,7 @@ class MessageManager:
async def add_message(self, stream_id: str, message: DatabaseMessages): async def add_message(self, stream_id: str, message: DatabaseMessages):
"""添加消息到指定聊天流 """添加消息到指定聊天流
新的流程: 新的流程:
1. 检查 notice 消息 1. 检查 notice 消息
2. 将消息添加到上下文(缓存) 2. 将消息添加到上下文(缓存)
@@ -149,10 +147,10 @@ class MessageManager:
if not chat_stream: if not chat_stream:
logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在") logger.warning(f"MessageManager.add_message: 聊天流 {stream_id} 不存在")
return return
# 将消息添加到上下文 # 将消息添加到上下文
await chat_stream.context_manager.add_message(message) await chat_stream.context_manager.add_message(message)
# 通知 scheduler_dispatcher 处理消息接收事件 # 通知 scheduler_dispatcher 处理消息接收事件
# dispatcher 会检查是否需要打断、创建或更新 schedule # dispatcher 会检查是否需要打断、创建或更新 schedule
await scheduler_dispatcher.on_message_received(stream_id) await scheduler_dispatcher.on_message_received(stream_id)

View File

@@ -20,7 +20,7 @@ logger = get_logger("scheduler_dispatcher")
class SchedulerDispatcher: class SchedulerDispatcher:
"""基于 scheduler 的消息分发器 """基于 scheduler 的消息分发器
工作流程: 工作流程:
1. 接收消息时,将消息添加到聊天流上下文 1. 接收消息时,将消息添加到聊天流上下文
2. 检查是否有活跃的 schedule如果没有则创建 2. 检查是否有活跃的 schedule如果没有则创建
@@ -32,13 +32,13 @@ class SchedulerDispatcher:
def __init__(self): def __init__(self):
# 追踪每个流的 schedule_id # 追踪每个流的 schedule_id
self.stream_schedules: dict[str, str] = {} # stream_id -> schedule_id self.stream_schedules: dict[str, str] = {} # stream_id -> schedule_id
# 用于保护 schedule 创建/删除的锁,避免竞态条件 # 用于保护 schedule 创建/删除的锁,避免竞态条件
self.schedule_locks: dict[str, asyncio.Lock] = {} # stream_id -> Lock self.schedule_locks: dict[str, asyncio.Lock] = {} # stream_id -> Lock
# Chatter 管理器 # Chatter 管理器
self.chatter_manager: ChatterManager | None = None self.chatter_manager: ChatterManager | None = None
# 统计信息 # 统计信息
self.stats = { self.stats = {
"total_schedules_created": 0, "total_schedules_created": 0,
@@ -48,9 +48,9 @@ class SchedulerDispatcher:
"total_failures": 0, "total_failures": 0,
"start_time": time.time(), "start_time": time.time(),
} }
self.is_running = False self.is_running = False
logger.info("基于 Scheduler 的消息分发器初始化完成") logger.info("基于 Scheduler 的消息分发器初始化完成")
async def start(self) -> None: async def start(self) -> None:
@@ -58,7 +58,7 @@ class SchedulerDispatcher:
if self.is_running: if self.is_running:
logger.warning("分发器已在运行") logger.warning("分发器已在运行")
return return
self.is_running = True self.is_running = True
logger.info("基于 Scheduler 的消息分发器已启动") logger.info("基于 Scheduler 的消息分发器已启动")
@@ -66,9 +66,9 @@ class SchedulerDispatcher:
"""停止分发器""" """停止分发器"""
if not self.is_running: if not self.is_running:
return return
self.is_running = False self.is_running = False
# 取消所有活跃的 schedule # 取消所有活跃的 schedule
schedule_ids = list(self.stream_schedules.values()) schedule_ids = list(self.stream_schedules.values())
for schedule_id in schedule_ids: for schedule_id in schedule_ids:
@@ -76,7 +76,7 @@ class SchedulerDispatcher:
await unified_scheduler.remove_schedule(schedule_id) await unified_scheduler.remove_schedule(schedule_id)
except Exception as e: except Exception as e:
logger.error(f"移除 schedule {schedule_id} 失败: {e}") logger.error(f"移除 schedule {schedule_id} 失败: {e}")
self.stream_schedules.clear() self.stream_schedules.clear()
logger.info("基于 Scheduler 的消息分发器已停止") logger.info("基于 Scheduler 的消息分发器已停止")
@@ -84,7 +84,7 @@ class SchedulerDispatcher:
"""设置 Chatter 管理器""" """设置 Chatter 管理器"""
self.chatter_manager = chatter_manager self.chatter_manager = chatter_manager
logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}") logger.debug(f"设置 Chatter 管理器: {chatter_manager.__class__.__name__}")
def _get_schedule_lock(self, stream_id: str) -> asyncio.Lock: def _get_schedule_lock(self, stream_id: str) -> asyncio.Lock:
"""获取流的 schedule 锁""" """获取流的 schedule 锁"""
if stream_id not in self.schedule_locks: if stream_id not in self.schedule_locks:
@@ -93,40 +93,40 @@ class SchedulerDispatcher:
async def on_message_received(self, stream_id: str) -> None: async def on_message_received(self, stream_id: str) -> None:
"""消息接收时的处理逻辑 """消息接收时的处理逻辑
Args: Args:
stream_id: 聊天流ID stream_id: 聊天流ID
""" """
if not self.is_running: if not self.is_running:
logger.warning("分发器未运行,忽略消息") logger.warning("分发器未运行,忽略消息")
return return
try: try:
# 1. 获取流上下文 # 1. 获取流上下文
context = await self._get_stream_context(stream_id) context = await self._get_stream_context(stream_id)
if not context: if not context:
logger.warning(f"无法获取流上下文: {stream_id}") logger.warning(f"无法获取流上下文: {stream_id}")
return return
# 2. 检查是否有活跃的 schedule # 2. 检查是否有活跃的 schedule
has_active_schedule = stream_id in self.stream_schedules has_active_schedule = stream_id in self.stream_schedules
if not has_active_schedule: if not has_active_schedule:
# 4. 创建新的 schedule在锁内避免重复创建 # 4. 创建新的 schedule在锁内避免重复创建
await self._create_schedule(stream_id, context) await self._create_schedule(stream_id, context)
return return
# 3. 检查打断判定 # 3. 检查打断判定
if has_active_schedule: if has_active_schedule:
should_interrupt = await self._check_interruption(stream_id, context) should_interrupt = await self._check_interruption(stream_id, context)
if should_interrupt: if should_interrupt:
# 移除旧 schedule 并创建新的(内部有锁保护) # 移除旧 schedule 并创建新的(内部有锁保护)
await self._cancel_and_recreate_schedule(stream_id, context) await self._cancel_and_recreate_schedule(stream_id, context)
logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule") logger.debug(f"⚡ 打断成功: 流={stream_id[:8]}..., 已重新创建 schedule")
else: else:
logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...") logger.debug(f"打断判定失败,保持原有 schedule: 流={stream_id[:8]}...")
except Exception as e: except Exception as e:
logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True) logger.error(f"处理消息接收事件失败 {stream_id}: {e}", exc_info=True)
@@ -144,18 +144,18 @@ class SchedulerDispatcher:
async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool: async def _check_interruption(self, stream_id: str, context: StreamContext) -> bool:
"""检查是否应该打断当前处理 """检查是否应该打断当前处理
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
bool: 是否应该打断 bool: 是否应该打断
""" """
# 检查是否启用打断 # 检查是否启用打断
if not global_config.chat.interruption_enabled: if not global_config.chat.interruption_enabled:
return False return False
# 检查是否正在回复,以及是否允许在回复时打断 # 检查是否正在回复,以及是否允许在回复时打断
if context.is_replying: if context.is_replying:
if not global_config.chat.allow_reply_interruption: if not global_config.chat.allow_reply_interruption:
@@ -163,49 +163,49 @@ class SchedulerDispatcher:
return False return False
else: else:
logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断") logger.debug(f"聊天流 {stream_id} 正在回复中,但配置允许回复时打断")
# 只有当 Chatter 真正在处理时才检查打断 # 只有当 Chatter 真正在处理时才检查打断
if not context.is_chatter_processing: if not context.is_chatter_processing:
logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断") logger.debug(f"聊天流 {stream_id} Chatter 未在处理,无需打断")
return False return False
# 检查最后一条消息 # 检查最后一条消息
last_message = context.get_last_message() last_message = context.get_last_message()
if not last_message: if not last_message:
return False return False
# 检查是否为表情包消息 # 检查是否为表情包消息
if last_message.is_picid or last_message.is_emoji: if last_message.is_picid or last_message.is_emoji:
logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查") logger.info(f"消息 {last_message.message_id} 是表情包或Emoji跳过打断检查")
return False return False
# 检查触发用户ID # 检查触发用户ID
triggering_user_id = context.triggering_user_id triggering_user_id = context.triggering_user_id
if triggering_user_id and last_message.user_info.user_id != triggering_user_id: if triggering_user_id and last_message.user_info.user_id != triggering_user_id:
logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查") logger.info(f"消息来自非触发用户 {last_message.user_info.user_id},实际触发用户为 {triggering_user_id},跳过打断检查")
return False return False
# 检查是否已达到最大打断次数 # 检查是否已达到最大打断次数
if context.interruption_count >= global_config.chat.interruption_max_limit: if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.debug( logger.debug(
f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}" f"聊天流 {stream_id} 已达到最大打断次数 {context.interruption_count}/{global_config.chat.interruption_max_limit}"
) )
return False return False
# 计算打断概率 # 计算打断概率
interruption_probability = context.calculate_interruption_probability( interruption_probability = context.calculate_interruption_probability(
global_config.chat.interruption_max_limit global_config.chat.interruption_max_limit
) )
# 根据概率决定是否打断 # 根据概率决定是否打断
import random import random
if random.random() < interruption_probability: if random.random() < interruption_probability:
logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}") logger.debug(f"聊天流 {stream_id} 触发消息打断,打断概率: {interruption_probability:.2f}")
# 增加打断计数 # 增加打断计数
await context.increment_interruption_count() await context.increment_interruption_count()
self.stats["total_interruptions"] += 1 self.stats["total_interruptions"] += 1
# 检查是否已达到最大次数 # 检查是否已达到最大次数
if context.interruption_count >= global_config.chat.interruption_max_limit: if context.interruption_count >= global_config.chat.interruption_max_limit:
logger.warning( logger.warning(
@@ -215,7 +215,7 @@ class SchedulerDispatcher:
logger.info( logger.info(
f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}" f"聊天流 {stream_id} 已打断,当前打断次数: {context.interruption_count}/{global_config.chat.interruption_max_limit}"
) )
return True return True
else: else:
logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}") logger.debug(f"聊天流 {stream_id} 未触发打断,打断概率: {interruption_probability:.2f}")
@@ -223,7 +223,7 @@ class SchedulerDispatcher:
async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None: async def _cancel_and_recreate_schedule(self, stream_id: str, context: StreamContext) -> None:
"""取消旧的 schedule 并创建新的(打断模式,使用极短延迟) """取消旧的 schedule 并创建新的(打断模式,使用极短延迟)
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
@@ -244,13 +244,13 @@ class SchedulerDispatcher:
) )
# 移除失败,不创建新 schedule避免重复 # 移除失败,不创建新 schedule避免重复
return return
# 创建新的 schedule使用即时处理模式极短延迟 # 创建新的 schedule使用即时处理模式极短延迟
await self._create_schedule(stream_id, context, immediate_mode=True) await self._create_schedule(stream_id, context, immediate_mode=True)
async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None: async def _create_schedule(self, stream_id: str, context: StreamContext, immediate_mode: bool = False) -> None:
"""为聊天流创建新的 schedule """为聊天流创建新的 schedule
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
@@ -266,7 +266,7 @@ class SchedulerDispatcher:
) )
await unified_scheduler.remove_schedule(old_schedule_id) await unified_scheduler.remove_schedule(old_schedule_id)
del self.stream_schedules[stream_id] del self.stream_schedules[stream_id]
# 如果是即时处理模式打断时使用固定的1秒延迟立即重新处理 # 如果是即时处理模式打断时使用固定的1秒延迟立即重新处理
if immediate_mode: if immediate_mode:
delay = 1.0 # 硬编码1秒延迟确保打断后能快速重新处理 delay = 1.0 # 硬编码1秒延迟确保打断后能快速重新处理
@@ -277,10 +277,10 @@ class SchedulerDispatcher:
else: else:
# 常规模式:计算初始延迟 # 常规模式:计算初始延迟
delay = await self._calculate_initial_delay(stream_id, context) delay = await self._calculate_initial_delay(stream_id, context)
# 获取未读消息数量用于日志 # 获取未读消息数量用于日志
unread_count = len(context.unread_messages) if context.unread_messages else 0 unread_count = len(context.unread_messages) if context.unread_messages else 0
# 创建 schedule # 创建 schedule
schedule_id = await unified_scheduler.create_schedule( schedule_id = await unified_scheduler.create_schedule(
callback=self._on_schedule_triggered, callback=self._on_schedule_triggered,
@@ -290,41 +290,41 @@ class SchedulerDispatcher:
task_name=f"dispatch_{stream_id[:8]}", task_name=f"dispatch_{stream_id[:8]}",
callback_args=(stream_id,), callback_args=(stream_id,),
) )
# 追踪 schedule # 追踪 schedule
self.stream_schedules[stream_id] = schedule_id self.stream_schedules[stream_id] = schedule_id
self.stats["total_schedules_created"] += 1 self.stats["total_schedules_created"] += 1
mode_indicator = "⚡打断" if immediate_mode else "📅常规" mode_indicator = "⚡打断" if immediate_mode else "📅常规"
logger.info( logger.info(
f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., " f"{mode_indicator} 创建 schedule: 流={stream_id[:8]}..., "
f"延迟={delay:.3f}s, 未读={unread_count}, " f"延迟={delay:.3f}s, 未读={unread_count}, "
f"ID={schedule_id[:8]}..." f"ID={schedule_id[:8]}..."
) )
except Exception as e: except Exception as e:
logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True) logger.error(f"创建 schedule 失败 {stream_id}: {e}", exc_info=True)
async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float: async def _calculate_initial_delay(self, stream_id: str, context: StreamContext) -> float:
"""计算初始延迟时间 """计算初始延迟时间
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
float: 延迟时间(秒) float: 延迟时间(秒)
""" """
# 基础间隔 # 基础间隔
base_interval = getattr(global_config.chat, "distribution_interval", 5.0) base_interval = getattr(global_config.chat, "distribution_interval", 5.0)
# 检查是否有未读消息 # 检查是否有未读消息
unread_count = len(context.unread_messages) if context.unread_messages else 0 unread_count = len(context.unread_messages) if context.unread_messages else 0
# 强制分发阈值 # 强制分发阈值
force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20) force_dispatch_threshold = getattr(global_config.chat, "force_dispatch_unread_threshold", 20)
# 如果未读消息过多,使用最小间隔 # 如果未读消息过多,使用最小间隔
if force_dispatch_threshold and unread_count > force_dispatch_threshold: if force_dispatch_threshold and unread_count > force_dispatch_threshold:
min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1) min_interval = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
@@ -334,24 +334,24 @@ class SchedulerDispatcher:
f"使用最小间隔={min_interval}s" f"使用最小间隔={min_interval}s"
) )
return min_interval return min_interval
# 尝试使用能量管理器计算间隔 # 尝试使用能量管理器计算间隔
try: try:
# 更新能量值 # 更新能量值
await self._update_stream_energy(stream_id, context) await self._update_stream_energy(stream_id, context)
# 获取当前 focus_energy # 获取当前 focus_energy
focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0] focus_energy = energy_manager.energy_cache.get(stream_id, (0.5, 0))[0]
# 使用能量管理器计算间隔 # 使用能量管理器计算间隔
interval = energy_manager.get_distribution_interval(focus_energy) interval = energy_manager.get_distribution_interval(focus_energy)
logger.info( logger.info(
f"📊 动态间隔计算: 流={stream_id[:8]}..., " f"📊 动态间隔计算: 流={stream_id[:8]}..., "
f"能量={focus_energy:.3f}, 间隔={interval:.2f}s" f"能量={focus_energy:.3f}, 间隔={interval:.2f}s"
) )
return interval return interval
except Exception as e: except Exception as e:
logger.info( logger.info(
f"📊 使用默认间隔: 流={stream_id[:8]}..., " f"📊 使用默认间隔: 流={stream_id[:8]}..., "
@@ -361,96 +361,96 @@ class SchedulerDispatcher:
async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None: async def _update_stream_energy(self, stream_id: str, context: StreamContext) -> None:
"""更新流的能量值 """更新流的能量值
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
""" """
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
# 获取聊天流 # 获取聊天流
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
chat_stream = await chat_manager.get_stream(stream_id) chat_stream = await chat_manager.get_stream(stream_id)
if not chat_stream: if not chat_stream:
logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新") logger.debug(f"无法找到聊天流 {stream_id},跳过能量更新")
return return
# 合并未读消息和历史消息 # 合并未读消息和历史消息
all_messages = [] all_messages = []
# 添加历史消息 # 添加历史消息
history_messages = context.get_history_messages(limit=global_config.chat.max_context_size) history_messages = context.get_history_messages(limit=global_config.chat.max_context_size)
all_messages.extend(history_messages) all_messages.extend(history_messages)
# 添加未读消息 # 添加未读消息
unread_messages = context.get_unread_messages() unread_messages = context.get_unread_messages()
all_messages.extend(unread_messages) all_messages.extend(unread_messages)
# 按时间排序并限制数量 # 按时间排序并限制数量
all_messages.sort(key=lambda m: m.time) all_messages.sort(key=lambda m: m.time)
messages = all_messages[-global_config.chat.max_context_size:] messages = all_messages[-global_config.chat.max_context_size:]
# 获取用户ID # 获取用户ID
user_id = context.triggering_user_id user_id = context.triggering_user_id
# 使用能量管理器计算并缓存能量值 # 使用能量管理器计算并缓存能量值
energy = await energy_manager.calculate_focus_energy( energy = await energy_manager.calculate_focus_energy(
stream_id=stream_id, stream_id=stream_id,
messages=messages, messages=messages,
user_id=user_id user_id=user_id
) )
# 同步更新到 ChatStream # 同步更新到 ChatStream
chat_stream._focus_energy = energy chat_stream._focus_energy = energy
logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}") logger.debug(f"已更新流 {stream_id} 的能量值: {energy:.3f}")
except Exception as e: except Exception as e:
logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False) logger.warning(f"更新流能量失败 {stream_id}: {e}", exc_info=False)
async def _on_schedule_triggered(self, stream_id: str) -> None: async def _on_schedule_triggered(self, stream_id: str) -> None:
"""schedule 触发时的回调 """schedule 触发时的回调
Args: Args:
stream_id: 流ID stream_id: 流ID
""" """
try: try:
old_schedule_id = self.stream_schedules.get(stream_id) old_schedule_id = self.stream_schedules.get(stream_id)
logger.info( logger.info(
f"⏰ Schedule 触发: 流={stream_id[:8]}..., " f"⏰ Schedule 触发: 流={stream_id[:8]}..., "
f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., " f"ID={old_schedule_id[:8] if old_schedule_id else 'None'}..., "
f"开始处理消息" f"开始处理消息"
) )
# 获取流上下文 # 获取流上下文
context = await self._get_stream_context(stream_id) context = await self._get_stream_context(stream_id)
if not context: if not context:
logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}") logger.warning(f"Schedule 触发时无法获取流上下文: {stream_id}")
return return
# 检查是否有未读消息 # 检查是否有未读消息
if not context.unread_messages: if not context.unread_messages:
logger.debug(f"{stream_id} 没有未读消息,跳过处理") logger.debug(f"{stream_id} 没有未读消息,跳过处理")
return return
# 激活 chatter 处理(不需要锁,允许并发处理) # 激活 chatter 处理(不需要锁,允许并发处理)
success = await self._process_stream(stream_id, context) success = await self._process_stream(stream_id, context)
# 更新统计 # 更新统计
self.stats["total_process_cycles"] += 1 self.stats["total_process_cycles"] += 1
if not success: if not success:
self.stats["total_failures"] += 1 self.stats["total_failures"] += 1
self.stream_schedules.pop(stream_id, None) self.stream_schedules.pop(stream_id, None)
# 检查缓存中是否有待处理的消息 # 检查缓存中是否有待处理的消息
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
has_cached = message_manager.has_cached_messages(stream_id) has_cached = message_manager.has_cached_messages(stream_id)
if has_cached: if has_cached:
# 有缓存消息,立即创建新 schedule 继续处理 # 有缓存消息,立即创建新 schedule 继续处理
logger.info( logger.info(
@@ -464,60 +464,60 @@ class SchedulerDispatcher:
f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., " f"✅ 处理完成且无缓存消息: 流={stream_id[:8]}..., "
f"等待新消息到达" f"等待新消息到达"
) )
except Exception as e: except Exception as e:
logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True) logger.error(f"Schedule 回调执行失败 {stream_id}: {e}", exc_info=True)
async def _process_stream(self, stream_id: str, context: StreamContext) -> bool: async def _process_stream(self, stream_id: str, context: StreamContext) -> bool:
"""处理流消息 """处理流消息
Args: Args:
stream_id: 流ID stream_id: 流ID
context: 流上下文 context: 流上下文
Returns: Returns:
bool: 是否处理成功 bool: 是否处理成功
""" """
if not self.chatter_manager: if not self.chatter_manager:
logger.warning(f"Chatter 管理器未设置: {stream_id}") logger.warning(f"Chatter 管理器未设置: {stream_id}")
return False return False
# 设置处理状态 # 设置处理状态
self._set_stream_processing_status(stream_id, True) self._set_stream_processing_status(stream_id, True)
try: try:
start_time = time.time() start_time = time.time()
# 设置触发用户ID # 设置触发用户ID
last_message = context.get_last_message() last_message = context.get_last_message()
if last_message: if last_message:
context.triggering_user_id = last_message.user_info.user_id context.triggering_user_id = last_message.user_info.user_id
# 创建异步任务刷新能量(不阻塞主流程) # 创建异步任务刷新能量(不阻塞主流程)
energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id)) energy_task = asyncio.create_task(self._refresh_focus_energy(stream_id))
# 设置 Chatter 正在处理的标志 # 设置 Chatter 正在处理的标志
context.is_chatter_processing = True context.is_chatter_processing = True
logger.debug(f"设置 Chatter 处理标志: {stream_id}") logger.debug(f"设置 Chatter 处理标志: {stream_id}")
try: try:
# 调用 chatter_manager 处理流上下文 # 调用 chatter_manager 处理流上下文
results = await self.chatter_manager.process_stream_context(stream_id, context) results = await self.chatter_manager.process_stream_context(stream_id, context)
success = results.get("success", False) success = results.get("success", False)
if success: if success:
process_time = time.time() - start_time process_time = time.time() - start_time
logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)") logger.debug(f"流处理成功: {stream_id} (耗时: {process_time:.2f}s)")
else: else:
logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}") logger.warning(f"流处理失败: {stream_id} - {results.get('error_message', '未知错误')}")
return success return success
finally: finally:
# 清除 Chatter 处理标志 # 清除 Chatter 处理标志
context.is_chatter_processing = False context.is_chatter_processing = False
logger.debug(f"清除 Chatter 处理标志: {stream_id}") logger.debug(f"清除 Chatter 处理标志: {stream_id}")
# 等待能量刷新任务完成 # 等待能量刷新任务完成
try: try:
await asyncio.wait_for(energy_task, timeout=5.0) await asyncio.wait_for(energy_task, timeout=5.0)
@@ -525,11 +525,11 @@ class SchedulerDispatcher:
logger.warning(f"等待能量刷新超时: {stream_id}") logger.warning(f"等待能量刷新超时: {stream_id}")
except Exception as e: except Exception as e:
logger.debug(f"能量刷新任务异常: {e}") logger.debug(f"能量刷新任务异常: {e}")
except Exception as e: except Exception as e:
logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True) logger.error(f"流处理异常: {stream_id} - {e}", exc_info=True)
return False return False
finally: finally:
# 设置处理状态为未处理 # 设置处理状态为未处理
self._set_stream_processing_status(stream_id, False) self._set_stream_processing_status(stream_id, False)
@@ -538,11 +538,11 @@ class SchedulerDispatcher:
"""设置流的处理状态""" """设置流的处理状态"""
try: try:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
if message_manager.is_running: if message_manager.is_running:
message_manager.set_stream_processing_status(stream_id, is_processing) message_manager.set_stream_processing_status(stream_id, is_processing)
logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}") logger.debug(f"设置流处理状态: stream={stream_id}, processing={is_processing}")
except ImportError: except ImportError:
logger.debug("MessageManager 不可用,跳过状态设置") logger.debug("MessageManager 不可用,跳过状态设置")
except Exception as e: except Exception as e:
@@ -556,7 +556,7 @@ class SchedulerDispatcher:
if not chat_stream: if not chat_stream:
logger.debug(f"刷新能量时未找到聊天流: {stream_id}") logger.debug(f"刷新能量时未找到聊天流: {stream_id}")
return return
await chat_stream.context_manager.refresh_focus_energy_from_history() await chat_stream.context_manager.refresh_focus_energy_from_history()
logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量") logger.debug(f"已刷新聊天流 {stream_id} 的聚焦能量")
except Exception as e: except Exception as e:

View File

@@ -367,7 +367,7 @@ class ChatBot:
message_segment = message_data.get("message_segment") message_segment = message_data.get("message_segment")
if message_segment and isinstance(message_segment, dict): if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response": if message_segment.get("type") == "adapter_response":
logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理") logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data")) await self._handle_adapter_response_from_dict(message_segment.get("data"))
return return

View File

@@ -205,7 +205,7 @@ async def _process_single_segment(segment: Seg, state: dict, message_info: BaseM
return result return result
else: else:
logger.warning(f"[at处理] 无法解析格式: '{segment.data}'") logger.warning(f"[at处理] 无法解析格式: '{segment.data}'")
return f"@{segment.data}" return f"@{segment.data}"
logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}") logger.warning(f"[at处理] 数据类型异常: {type(segment.data)}")
return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户" return f"@{segment.data}" if isinstance(segment.data, str) else "@未知用户"

View File

@@ -542,7 +542,7 @@ class DefaultReplyer:
all_memories = [] all_memories = []
try: try:
from src.memory_graph.manager_singleton import get_memory_manager, is_initialized from src.memory_graph.manager_singleton import get_memory_manager, is_initialized
if is_initialized(): if is_initialized():
manager = get_memory_manager() manager = get_memory_manager()
if manager: if manager:
@@ -552,12 +552,12 @@ class DefaultReplyer:
sender_name = "" sender_name = ""
if user_info_obj: if user_info_obj:
sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "") sender_name = getattr(user_info_obj, "user_nickname", "") or getattr(user_info_obj, "user_cardname", "")
# 获取参与者信息 # 获取参与者信息
participants = [] participants = []
try: try:
# 尝试从聊天流中获取参与者信息 # 尝试从聊天流中获取参与者信息
if hasattr(stream, 'chat_history_manager'): if hasattr(stream, "chat_history_manager"):
history_manager = stream.chat_history_manager history_manager = stream.chat_history_manager
# 获取最近的参与者列表 # 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history( recent_records = history_manager.get_memory_chat_history(
@@ -586,16 +586,16 @@ class DefaultReplyer:
formatted_history = "" formatted_history = ""
if chat_history: if chat_history:
# 移除过长的历史记录,只保留最近部分 # 移除过长的历史记录,只保留最近部分
lines = chat_history.strip().split('\n') lines = chat_history.strip().split("\n")
recent_lines = lines[-10:] if len(lines) > 10 else lines recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines) formatted_history = "\n".join(recent_lines)
query_context = { query_context = {
"chat_history": formatted_history, "chat_history": formatted_history,
"sender": sender_name, "sender": sender_name,
"participants": participants, "participants": participants,
} }
# 使用记忆管理器的智能检索(多查询策略) # 使用记忆管理器的智能检索(多查询策略)
memories = await manager.search_memories( memories = await manager.search_memories(
query=target, query=target,
@@ -605,23 +605,23 @@ class DefaultReplyer:
use_multi_query=True, use_multi_query=True,
context=query_context, context=query_context,
) )
if memories: if memories:
logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆") logger.info(f"[记忆图] 检索到 {len(memories)} 条相关记忆")
# 使用新的格式化工具构建完整的记忆描述 # 使用新的格式化工具构建完整的记忆描述
from src.memory_graph.utils.memory_formatter import ( from src.memory_graph.utils.memory_formatter import (
format_memory_for_prompt, format_memory_for_prompt,
get_memory_type_label, get_memory_type_label,
) )
for memory in memories: for memory in memories:
# 使用格式化工具生成完整的主谓宾描述 # 使用格式化工具生成完整的主谓宾描述
content = format_memory_for_prompt(memory, include_metadata=False) content = format_memory_for_prompt(memory, include_metadata=False)
# 获取记忆类型 # 获取记忆类型
mem_type = memory.memory_type.value if memory.memory_type else "未知" mem_type = memory.memory_type.value if memory.memory_type else "未知"
if content: if content:
all_memories.append({ all_memories.append({
"content": content, "content": content,
@@ -636,7 +636,7 @@ class DefaultReplyer:
except Exception as e: except Exception as e:
logger.debug(f"[记忆图] 检索失败: {e}") logger.debug(f"[记忆图] 检索失败: {e}")
all_memories = [] all_memories = []
# 构建记忆字符串,使用方括号格式 # 构建记忆字符串,使用方括号格式
memory_str = "" memory_str = ""
has_any_memory = False has_any_memory = False
@@ -725,7 +725,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") tool_result.get("type", "tool_result")
# 不进行截断,让工具自己处理结果长度 # 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}") current_results_parts.append(f"- **{tool_name}**: {content}")
@@ -744,7 +744,7 @@ class DefaultReplyer:
logger.error(f"工具信息获取失败: {e}") logger.error(f"工具信息获取失败: {e}")
return "" return ""
def _parse_reply_target(self, target_message: str) -> tuple[str, str]: def _parse_reply_target(self, target_message: str) -> tuple[str, str]:
"""解析回复目标消息 - 使用共享工具""" """解析回复目标消息 - 使用共享工具"""
from src.chat.utils.prompt import Prompt from src.chat.utils.prompt import Prompt
@@ -1897,7 +1897,7 @@ class DefaultReplyer:
async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None): async def _store_chat_memory_async(self, reply_to: str, reply_message: DatabaseMessages | dict[str, Any] | None = None):
""" """
[已废弃] 异步存储聊天记忆从build_memory_block迁移而来 [已废弃] 异步存储聊天记忆从build_memory_block迁移而来
此函数已被记忆图系统的工具调用方式替代。 此函数已被记忆图系统的工具调用方式替代。
记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。 记忆现在由LLM在对话过程中通过CreateMemoryTool主动创建。
@@ -1906,14 +1906,13 @@ class DefaultReplyer:
reply_message: 回复的原始消息 reply_message: 回复的原始消息
""" """
return # 已禁用,保留函数签名以防其他地方有引用 return # 已禁用,保留函数签名以防其他地方有引用
# 以下代码已废弃,不再执行 # 以下代码已废弃,不再执行
try: try:
if not global_config.memory.enable_memory: if not global_config.memory.enable_memory:
return return
# 使用统一记忆系统存储记忆 # 使用统一记忆系统存储记忆
from src.chat.memory_system import get_memory_system
stream = self.chat_stream stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None) user_info_obj = getattr(stream, "user_info", None)
@@ -2036,7 +2035,7 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size), limit=int(global_config.chat.max_context_size),
) )
chat_history = await build_readable_messages( await build_readable_messages(
message_list_before_short, message_list_before_short,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,

View File

@@ -400,7 +400,7 @@ class Prompt:
# 初始化预构建参数字典 # 初始化预构建参数字典
pre_built_params = {} pre_built_params = {}
try: try:
# --- 步骤 1: 准备构建任务 --- # --- 步骤 1: 准备构建任务 ---
tasks = [] tasks = []

View File

@@ -87,20 +87,18 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
) )
processed_text = message.processed_plain_text or "" processed_text = message.processed_plain_text or ""
# 1. 判断是否为私聊(强提及) # 1. 判断是否为私聊(强提及)
group_info = getattr(message, "group_info", None) group_info = getattr(message, "group_info", None)
if not group_info or not getattr(group_info, "group_id", None): if not group_info or not getattr(group_info, "group_id", None):
is_private = True
mention_type = 2 mention_type = 2
logger.debug("检测到私聊消息 - 强提及") logger.debug("检测到私聊消息 - 强提及")
# 2. 判断是否被@(强提及) # 2. 判断是否被@(强提及)
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text): if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
is_at = True
mention_type = 2 mention_type = 2
logger.debug("检测到@提及 - 强提及") logger.debug("检测到@提及 - 强提及")
# 3. 判断是否被回复(强提及) # 3. 判断是否被回复(强提及)
if re.match( if re.match(
rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text rf"\[回复 (.+?)\({global_config.bot.qq_account!s}\)(.+?)\],说:", processed_text
@@ -108,10 +106,9 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:", rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
processed_text, processed_text,
): ):
is_replied = True
mention_type = 2 mention_type = 2
logger.debug("检测到回复消息 - 强提及") logger.debug("检测到回复消息 - 强提及")
# 4. 判断文本中是否提及bot名字或别名弱提及 # 4. 判断文本中是否提及bot名字或别名弱提及
if mention_type == 0: # 只有在没有强提及时才检查弱提及 if mention_type == 0: # 只有在没有强提及时才检查弱提及
# 移除@和回复标记后再检查 # 移除@和回复标记后再检查
@@ -119,21 +116,19 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content) message_content = re.sub(r"@<(.+?)(?=:(\d+))\:(\d+)>", "", message_content)
message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content) message_content = re.sub(r"\[回复 (.+?)\(((\d+)|未知id)\)(.+?)\],说:", "", message_content)
message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content) message_content = re.sub(r"\[回复<(.+?)(?=:(\d+))\:(\d+)>(.+?)\],说:", "", message_content)
# 检查bot主名字 # 检查bot主名字
if global_config.bot.nickname in message_content: if global_config.bot.nickname in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及") logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
# 如果主名字没匹配,再检查别名 # 如果主名字没匹配,再检查别名
elif nicknames: elif nicknames:
for alias_name in nicknames: for alias_name in nicknames:
if alias_name in message_content: if alias_name in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及") logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
break break
# 返回结果 # 返回结果
is_mentioned = mention_type > 0 is_mentioned = mention_type > 0
return is_mentioned, float(mention_type) return is_mentioned, float(mention_type)

View File

@@ -368,13 +368,13 @@ class CacheManager:
if expired_keys: if expired_keys:
logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目") logger.info(f"清理了 {len(expired_keys)} 个过期的L1缓存条目")
def get_health_stats(self) -> dict[str, Any]: def get_health_stats(self) -> dict[str, Any]:
"""获取缓存健康统计信息""" """获取缓存健康统计信息"""
# 简化的健康统计,不包含内存监控(因为相关属性未定义) # 简化的健康统计,不包含内存监控(因为相关属性未定义)
return { return {
"l1_count": len(self.l1_kv_cache), "l1_count": len(self.l1_kv_cache),
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0, "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
"tool_stats": { "tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0), "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})), "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
@@ -397,7 +397,7 @@ class CacheManager:
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}") warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查向量索引大小 # 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0 vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
if isinstance(vector_count, int) and vector_count > 500: if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}") warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -66,7 +66,7 @@ class BatchStats:
last_batch_duration: float = 0.0 last_batch_duration: float = 0.0
last_batch_size: int = 0 last_batch_size: int = 0
congestion_score: float = 0.0 # 拥塞评分 (0-1) congestion_score: float = 0.0 # 拥塞评分 (0-1)
# 🔧 新增:缓存统计 # 🔧 新增:缓存统计
cache_size: int = 0 # 缓存条目数 cache_size: int = 0 # 缓存条目数
cache_memory_mb: float = 0.0 # 缓存内存占用MB cache_memory_mb: float = 0.0 # 缓存内存占用MB
@@ -539,8 +539,7 @@ class AdaptiveBatchScheduler:
def _set_cache(self, cache_key: str, result: Any) -> None: def _set_cache(self, cache_key: str, result: Any) -> None:
"""设置缓存(改进版,带大小限制和内存统计)""" """设置缓存(改进版,带大小限制和内存统计)"""
import sys
# 🔧 检查缓存大小限制 # 🔧 检查缓存大小限制
if len(self._result_cache) >= self._cache_max_size: if len(self._result_cache) >= self._cache_max_size:
# 首先清理过期条目 # 首先清理过期条目
@@ -549,18 +548,18 @@ class AdaptiveBatchScheduler:
k for k, (_, ts) in self._result_cache.items() k for k, (_, ts) in self._result_cache.items()
if current_time - ts >= self.cache_ttl if current_time - ts >= self.cache_ttl
] ]
for k in expired_keys: for k in expired_keys:
# 更新内存统计 # 更新内存统计
if k in self._cache_size_map: if k in self._cache_size_map:
self._cache_memory_estimate -= self._cache_size_map[k] self._cache_memory_estimate -= self._cache_size_map[k]
del self._cache_size_map[k] del self._cache_size_map[k]
del self._result_cache[k] del self._result_cache[k]
# 如果还是太大清理最老的条目LRU # 如果还是太大清理最老的条目LRU
if len(self._result_cache) >= self._cache_max_size: if len(self._result_cache) >= self._cache_max_size:
oldest_key = min( oldest_key = min(
self._result_cache.keys(), self._result_cache.keys(),
key=lambda k: self._result_cache[k][1] key=lambda k: self._result_cache[k][1]
) )
# 更新内存统计 # 更新内存统计
@@ -569,7 +568,7 @@ class AdaptiveBatchScheduler:
del self._cache_size_map[oldest_key] del self._cache_size_map[oldest_key]
del self._result_cache[oldest_key] del self._result_cache[oldest_key]
logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}") logger.debug(f"缓存已满,淘汰最老条目: {oldest_key}")
# 🔧 使用准确的内存估算方法 # 🔧 使用准确的内存估算方法
try: try:
total_size = estimate_size_smart(cache_key) + estimate_size_smart(result) total_size = estimate_size_smart(cache_key) + estimate_size_smart(result)
@@ -580,7 +579,7 @@ class AdaptiveBatchScheduler:
# 使用默认值 # 使用默认值
self._cache_size_map[cache_key] = 1024 self._cache_size_map[cache_key] = 1024
self._cache_memory_estimate += 1024 self._cache_memory_estimate += 1024
self._result_cache[cache_key] = (result, time.time()) self._result_cache[cache_key] = (result, time.time())
async def get_stats(self) -> BatchStats: async def get_stats(self) -> BatchStats:

View File

@@ -171,7 +171,7 @@ class LRUCache(Generic[T]):
) )
else: else:
adjusted_created_at = now adjusted_created_at = now
entry = CacheEntry( entry = CacheEntry(
value=value, value=value,
created_at=adjusted_created_at, created_at=adjusted_created_at,
@@ -345,7 +345,7 @@ class MultiLevelCache:
# 估算数据大小(如果未提供) # 估算数据大小(如果未提供)
if size is None: if size is None:
size = estimate_size_smart(value) size = estimate_size_smart(value)
# 检查单个条目大小是否超过限制 # 检查单个条目大小是否超过限制
if size > self.max_item_size_bytes: if size > self.max_item_size_bytes:
logger.warning( logger.warning(
@@ -354,7 +354,7 @@ class MultiLevelCache:
f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB" f"limit={self.max_item_size_bytes / (1024 * 1024):.2f}MB"
) )
return return
# 根据TTL决定写入哪个缓存层 # 根据TTL决定写入哪个缓存层
if ttl is not None: if ttl is not None:
# 有自定义TTL根据TTL大小决定写入层级 # 有自定义TTL根据TTL大小决定写入层级
@@ -394,37 +394,37 @@ class MultiLevelCache:
"""获取所有缓存层的统计信息(修正版,避免重复计数)""" """获取所有缓存层的统计信息(修正版,避免重复计数)"""
l1_stats = await self.l1_cache.get_stats() l1_stats = await self.l1_cache.get_stats()
l2_stats = await self.l2_cache.get_stats() l2_stats = await self.l2_cache.get_stats()
# 🔧 修复计算实际独占的内存避免L1和L2共享数据的重复计数 # 🔧 修复计算实际独占的内存避免L1和L2共享数据的重复计数
l1_keys = set(self.l1_cache._cache.keys()) l1_keys = set(self.l1_cache._cache.keys())
l2_keys = set(self.l2_cache._cache.keys()) l2_keys = set(self.l2_cache._cache.keys())
shared_keys = l1_keys & l2_keys shared_keys = l1_keys & l2_keys
l1_only_keys = l1_keys - l2_keys l1_only_keys = l1_keys - l2_keys
l2_only_keys = l2_keys - l1_keys l2_only_keys = l2_keys - l1_keys
# 计算实际总内存(避免重复计数) # 计算实际总内存(避免重复计数)
# L1独占内存 # L1独占内存
l1_only_size = sum( l1_only_size = sum(
self.l1_cache._cache[k].size self.l1_cache._cache[k].size
for k in l1_only_keys for k in l1_only_keys
if k in self.l1_cache._cache if k in self.l1_cache._cache
) )
# L2独占内存 # L2独占内存
l2_only_size = sum( l2_only_size = sum(
self.l2_cache._cache[k].size self.l2_cache._cache[k].size
for k in l2_only_keys for k in l2_only_keys
if k in self.l2_cache._cache if k in self.l2_cache._cache
) )
# 共享内存只计算一次使用L1的数据 # 共享内存只计算一次使用L1的数据
shared_size = sum( shared_size = sum(
self.l1_cache._cache[k].size self.l1_cache._cache[k].size
for k in shared_keys for k in shared_keys
if k in self.l1_cache._cache if k in self.l1_cache._cache
) )
actual_total_size = l1_only_size + l2_only_size + shared_size actual_total_size = l1_only_size + l2_only_size + shared_size
return { return {
"l1": l1_stats, "l1": l1_stats,
"l2": l2_stats, "l2": l2_stats,
@@ -442,7 +442,7 @@ class MultiLevelCache:
"""检查并强制清理超出内存限制的缓存""" """检查并强制清理超出内存限制的缓存"""
stats = await self.get_stats() stats = await self.get_stats()
total_size = stats["l1"].total_size + stats["l2"].total_size total_size = stats["l1"].total_size + stats["l2"].total_size
if total_size > self.max_memory_bytes: if total_size > self.max_memory_bytes:
memory_mb = total_size / (1024 * 1024) memory_mb = total_size / (1024 * 1024)
max_mb = self.max_memory_bytes / (1024 * 1024) max_mb = self.max_memory_bytes / (1024 * 1024)
@@ -452,14 +452,14 @@ class MultiLevelCache:
) )
# 优先清理L2缓存温数据 # 优先清理L2缓存温数据
await self.l2_cache.clear() await self.l2_cache.clear()
# 如果清理L2后仍超限清理L1 # 如果清理L2后仍超限清理L1
stats_after_l2 = await self.get_stats() stats_after_l2 = await self.get_stats()
total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size total_after_l2 = stats_after_l2["l1"].total_size + stats_after_l2["l2"].total_size
if total_after_l2 > self.max_memory_bytes: if total_after_l2 > self.max_memory_bytes:
logger.warning("清理L2后仍超限继续清理L1缓存") logger.warning("清理L2后仍超限继续清理L1缓存")
await self.l1_cache.clear() await self.l1_cache.clear()
logger.info("缓存强制清理完成") logger.info("缓存强制清理完成")
async def start_cleanup_task(self, interval: float = 60) -> None: async def start_cleanup_task(self, interval: float = 60) -> None:
@@ -476,10 +476,10 @@ class MultiLevelCache:
while not self._is_closing: while not self._is_closing:
try: try:
await asyncio.sleep(interval) await asyncio.sleep(interval)
if self._is_closing: if self._is_closing:
break break
stats = await self.get_stats() stats = await self.get_stats()
l1_stats = stats["l1"] l1_stats = stats["l1"]
l2_stats = stats["l2"] l2_stats = stats["l2"]
@@ -493,13 +493,13 @@ class MultiLevelCache:
f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB " f"共享: {stats['shared_keys_count']}键/{stats['shared_mb']:.2f}MB "
f"(去重节省{stats['dedup_savings_mb']:.2f}MB)" f"(去重节省{stats['dedup_savings_mb']:.2f}MB)"
) )
# 🔧 清理过期条目 # 🔧 清理过期条目
await self._clean_expired_entries() await self._clean_expired_entries()
# 检查内存限制 # 检查内存限制
await self.check_memory_limit() await self.check_memory_limit()
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
@@ -511,7 +511,7 @@ class MultiLevelCache:
async def stop_cleanup_task(self) -> None: async def stop_cleanup_task(self) -> None:
"""停止清理任务""" """停止清理任务"""
self._is_closing = True self._is_closing = True
if self._cleanup_task is not None: if self._cleanup_task is not None:
self._cleanup_task.cancel() self._cleanup_task.cancel()
try: try:
@@ -520,43 +520,43 @@ class MultiLevelCache:
pass pass
self._cleanup_task = None self._cleanup_task = None
logger.info("缓存清理任务已停止") logger.info("缓存清理任务已停止")
async def _clean_expired_entries(self) -> None: async def _clean_expired_entries(self) -> None:
"""清理过期的缓存条目""" """清理过期的缓存条目"""
try: try:
current_time = time.time() current_time = time.time()
# 清理 L1 过期条目 # 清理 L1 过期条目
async with self.l1_cache._lock: async with self.l1_cache._lock:
expired_keys = [ expired_keys = [
key for key, entry in self.l1_cache._cache.items() key for key, entry in self.l1_cache._cache.items()
if current_time - entry.created_at > self.l1_cache.ttl if current_time - entry.created_at > self.l1_cache.ttl
] ]
for key in expired_keys: for key in expired_keys:
entry = self.l1_cache._cache.pop(key, None) entry = self.l1_cache._cache.pop(key, None)
if entry: if entry:
self.l1_cache._stats.evictions += 1 self.l1_cache._stats.evictions += 1
self.l1_cache._stats.item_count -= 1 self.l1_cache._stats.item_count -= 1
self.l1_cache._stats.total_size -= entry.size self.l1_cache._stats.total_size -= entry.size
# 清理 L2 过期条目 # 清理 L2 过期条目
async with self.l2_cache._lock: async with self.l2_cache._lock:
expired_keys = [ expired_keys = [
key for key, entry in self.l2_cache._cache.items() key for key, entry in self.l2_cache._cache.items()
if current_time - entry.created_at > self.l2_cache.ttl if current_time - entry.created_at > self.l2_cache.ttl
] ]
for key in expired_keys: for key in expired_keys:
entry = self.l2_cache._cache.pop(key, None) entry = self.l2_cache._cache.pop(key, None)
if entry: if entry:
self.l2_cache._stats.evictions += 1 self.l2_cache._stats.evictions += 1
self.l2_cache._stats.item_count -= 1 self.l2_cache._stats.item_count -= 1
self.l2_cache._stats.total_size -= entry.size self.l2_cache._stats.total_size -= entry.size
if expired_keys: if expired_keys:
logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目") logger.debug(f"清理了 {len(expired_keys)} 个过期缓存条目")
except Exception as e: except Exception as e:
logger.error(f"清理过期条目失败: {e}", exc_info=True) logger.error(f"清理过期条目失败: {e}", exc_info=True)
@@ -568,7 +568,7 @@ _cache_lock = asyncio.Lock()
async def get_cache() -> MultiLevelCache: async def get_cache() -> MultiLevelCache:
"""获取全局缓存实例(单例) """获取全局缓存实例(单例)
从配置文件读取缓存参数,如果配置未加载则使用默认值 从配置文件读取缓存参数,如果配置未加载则使用默认值
如果配置中禁用了缓存返回一个最小化的缓存实例容量为1 如果配置中禁用了缓存返回一个最小化的缓存实例容量为1
""" """
@@ -580,9 +580,9 @@ async def get_cache() -> MultiLevelCache:
# 尝试从配置读取参数 # 尝试从配置读取参数
try: try:
from src.config.config import global_config from src.config.config import global_config
db_config = global_config.database db_config = global_config.database
# 检查是否启用缓存 # 检查是否启用缓存
if not db_config.enable_database_cache: if not db_config.enable_database_cache:
logger.info("数据库缓存已禁用,使用最小化缓存实例") logger.info("数据库缓存已禁用,使用最小化缓存实例")
@@ -594,7 +594,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb=1, max_memory_mb=1,
) )
return _global_cache return _global_cache
l1_max_size = db_config.cache_l1_max_size l1_max_size = db_config.cache_l1_max_size
l1_ttl = db_config.cache_l1_ttl l1_ttl = db_config.cache_l1_ttl
l2_max_size = db_config.cache_l2_max_size l2_max_size = db_config.cache_l2_max_size
@@ -602,7 +602,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = db_config.cache_max_memory_mb max_memory_mb = db_config.cache_max_memory_mb
max_item_size_mb = db_config.cache_max_item_size_mb max_item_size_mb = db_config.cache_max_item_size_mb
cleanup_interval = db_config.cache_cleanup_interval cleanup_interval = db_config.cache_cleanup_interval
logger.info( logger.info(
f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), " f"从配置加载缓存参数: L1({l1_max_size}/{l1_ttl}s), "
f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), " f"L2({l2_max_size}/{l2_ttl}s), 内存限制({max_memory_mb}MB), "
@@ -618,7 +618,7 @@ async def get_cache() -> MultiLevelCache:
max_memory_mb = 100 max_memory_mb = 100
max_item_size_mb = 1 max_item_size_mb = 1
cleanup_interval = 60 cleanup_interval = 60
_global_cache = MultiLevelCache( _global_cache = MultiLevelCache(
l1_max_size=l1_max_size, l1_max_size=l1_max_size,
l1_ttl=l1_ttl, l1_ttl=l1_ttl,

View File

@@ -4,73 +4,74 @@
提供比 sys.getsizeof() 更准确的内存占用估算方法 提供比 sys.getsizeof() 更准确的内存占用估算方法
""" """
import sys
import pickle import pickle
import sys
from typing import Any from typing import Any
import numpy as np import numpy as np
def get_accurate_size(obj: Any, seen: set | None = None) -> int: def get_accurate_size(obj: Any, seen: set | None = None) -> int:
""" """
准确估算对象的内存大小(递归计算所有引用对象) 准确估算对象的内存大小(递归计算所有引用对象)
比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。 比 sys.getsizeof() 准确得多,特别是对于复杂嵌套对象。
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
seen: 已访问对象的集合(用于避免循环引用) seen: 已访问对象的集合(用于避免循环引用)
Returns: Returns:
估算的字节数 估算的字节数
""" """
if seen is None: if seen is None:
seen = set() seen = set()
obj_id = id(obj) obj_id = id(obj)
if obj_id in seen: if obj_id in seen:
return 0 return 0
seen.add(obj_id) seen.add(obj_id)
size = sys.getsizeof(obj) size = sys.getsizeof(obj)
# NumPy 数组特殊处理 # NumPy 数组特殊处理
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
size += obj.nbytes size += obj.nbytes
return size return size
# 字典:递归计算所有键值对 # 字典:递归计算所有键值对
if isinstance(obj, dict): if isinstance(obj, dict):
size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen) size += sum(get_accurate_size(k, seen) + get_accurate_size(v, seen)
for k, v in obj.items()) for k, v in obj.items())
# 列表、元组、集合:递归计算所有元素 # 列表、元组、集合:递归计算所有元素
elif isinstance(obj, (list, tuple, set, frozenset)): elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
# 有 __dict__ 的对象:递归计算属性 # 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, '__dict__'): elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen) size += get_accurate_size(obj.__dict__, seen)
# 其他可迭代对象 # 其他可迭代对象
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try: try:
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
except: except:
pass pass
return size return size
def get_pickle_size(obj: Any) -> int: def get_pickle_size(obj: Any) -> int:
""" """
使用 pickle 序列化大小作为参考 使用 pickle 序列化大小作为参考
通常比 sys.getsizeof() 更接近实际内存占用, 通常比 sys.getsizeof() 更接近实际内存占用,
但可能略小于真实内存占用(不包括 Python 对象开销) 但可能略小于真实内存占用(不包括 Python 对象开销)
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
Returns: Returns:
pickle 序列化后的字节数,失败返回 0 pickle 序列化后的字节数,失败返回 0
""" """
@@ -83,17 +84,17 @@ def get_pickle_size(obj: Any) -> int:
def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int: def estimate_size_smart(obj: Any, max_depth: int = 5, sample_large: bool = True) -> int:
""" """
智能估算对象大小(平衡准确性和性能) 智能估算对象大小(平衡准确性和性能)
使用深度受限的递归估算+采样策略,平衡准确性和性能: 使用深度受限的递归估算+采样策略,平衡准确性和性能:
- 深度5层足以覆盖99%的缓存数据结构 - 深度5层足以覆盖99%的缓存数据结构
- 对大型容器(>100项进行采样估算 - 对大型容器(>100项进行采样估算
- 性能开销约60倍于sys.getsizeof但准确度提升1000+倍 - 性能开销约60倍于sys.getsizeof但准确度提升1000+倍
Args: Args:
obj: 要估算大小的对象 obj: 要估算大小的对象
max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构 max_depth: 最大递归深度默认5层可覆盖大多数嵌套结构
sample_large: 对大型容器是否采样默认True提升性能 sample_large: 对大型容器是否采样默认True提升性能
Returns: Returns:
估算的字节数 估算的字节数
""" """
@@ -105,24 +106,24 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 检查深度限制 # 检查深度限制
if depth <= 0: if depth <= 0:
return sys.getsizeof(obj) return sys.getsizeof(obj)
# 检查循环引用 # 检查循环引用
obj_id = id(obj) obj_id = id(obj)
if obj_id in seen: if obj_id in seen:
return 0 return 0
seen.add(obj_id) seen.add(obj_id)
# 基本大小 # 基本大小
size = sys.getsizeof(obj) size = sys.getsizeof(obj)
# 简单类型直接返回 # 简单类型直接返回
if isinstance(obj, (int, float, bool, type(None), str, bytes, bytearray)): if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
return size return size
# NumPy 数组特殊处理 # NumPy 数组特殊处理
if isinstance(obj, np.ndarray): if isinstance(obj, np.ndarray):
return size + obj.nbytes return size + obj.nbytes
# 字典递归 # 字典递归
if isinstance(obj, dict): if isinstance(obj, dict):
items = list(obj.items()) items = list(obj.items())
@@ -130,7 +131,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
# 大字典采样前50 + 中间50 + 最后50 # 大字典采样前50 + 中间50 + 最后50
sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:] sample_items = items[:50] + items[len(items)//2-25:len(items)//2+25] + items[-50:]
sampled_size = sum( sampled_size = sum(
_estimate_recursive(k, depth - 1, seen, sample_large) + _estimate_recursive(k, depth - 1, seen, sample_large) +
_estimate_recursive(v, depth - 1, seen, sample_large) _estimate_recursive(v, depth - 1, seen, sample_large)
for k, v in sample_items for k, v in sample_items
) )
@@ -142,9 +143,9 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
size += _estimate_recursive(k, depth - 1, seen, sample_large) size += _estimate_recursive(k, depth - 1, seen, sample_large)
size += _estimate_recursive(v, depth - 1, seen, sample_large) size += _estimate_recursive(v, depth - 1, seen, sample_large)
return size return size
# 列表、元组、集合递归 # 列表、元组、集合递归
if isinstance(obj, (list, tuple, set, frozenset)): if isinstance(obj, list | tuple | set | frozenset):
items = list(obj) items = list(obj)
if sample_large and len(items) > 100: if sample_large and len(items) > 100:
# 大容器采样前50 + 中间50 + 最后50 # 大容器采样前50 + 中间50 + 最后50
@@ -160,21 +161,21 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
for item in items: for item in items:
size += _estimate_recursive(item, depth - 1, seen, sample_large) size += _estimate_recursive(item, depth - 1, seen, sample_large)
return size return size
# 有 __dict__ 的对象 # 有 __dict__ 的对象
if hasattr(obj, '__dict__'): if hasattr(obj, "__dict__"):
size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large) size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
return size return size
def format_size(size_bytes: int) -> str: def format_size(size_bytes: int) -> str:
""" """
格式化字节数为人类可读的格式 格式化字节数为人类可读的格式
Args: Args:
size_bytes: 字节数 size_bytes: 字节数
Returns: Returns:
格式化后的字符串,如 "1.23 MB" 格式化后的字符串,如 "1.23 MB"
""" """

View File

@@ -2,7 +2,6 @@ import os
import shutil import shutil
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Optional
import tomlkit import tomlkit
from pydantic import Field from pydantic import Field
@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置") notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置") emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置") expression: ExpressionConfig = Field(..., description="表达配置")
memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置") memory: MemoryConfig | None = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置") mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置") reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置") chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")

View File

@@ -401,16 +401,16 @@ class MemoryConfig(ValidatedConfigBase):
memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡") memory_system_load_balancing: bool = Field(default=True, description="启用记忆系统负载均衡")
memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流") memory_build_throttling: bool = Field(default=True, description="启用记忆构建节流")
memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列") memory_priority_queue_enabled: bool = Field(default=True, description="启用记忆优先级队列")
# === 记忆图系统配置 (Memory Graph System) === # === 记忆图系统配置 (Memory Graph System) ===
# 新一代记忆系统的配置项 # 新一代记忆系统的配置项
enable: bool = Field(default=True, description="启用记忆图系统") enable: bool = Field(default=True, description="启用记忆图系统")
data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录") data_dir: str = Field(default="data/memory_graph", description="记忆数据存储目录")
# 向量存储配置 # 向量存储配置
vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称") vector_collection_name: str = Field(default="memory_nodes", description="向量集合名称")
vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径") vector_db_path: str = Field(default="data/memory_graph/chroma_db", description="向量数据库路径")
# 检索配置 # 检索配置
search_top_k: int = Field(default=10, description="默认检索返回数量") search_top_k: int = Field(default=10, description="默认检索返回数量")
search_min_importance: float = Field(default=0.3, description="最小重要性阈值") search_min_importance: float = Field(default=0.3, description="最小重要性阈值")
@@ -418,13 +418,13 @@ class MemoryConfig(ValidatedConfigBase):
search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3") search_max_expand_depth: int = Field(default=2, description="检索时图扩展深度0-3")
search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)") search_expand_semantic_threshold: float = Field(default=0.3, description="图扩展时语义相似度阈值建议0.3-0.5,过低可能引入无关记忆,过高无法扩展)")
enable_query_optimization: bool = Field(default=True, description="启用查询优化") enable_query_optimization: bool = Field(default=True, description="启用查询优化")
# 检索权重配置 (记忆图系统) # 检索权重配置 (记忆图系统)
search_vector_weight: float = Field(default=0.4, description="向量相似度权重") search_vector_weight: float = Field(default=0.4, description="向量相似度权重")
search_graph_distance_weight: float = Field(default=0.2, description="图距离权重") search_graph_distance_weight: float = Field(default=0.2, description="图距离权重")
search_importance_weight: float = Field(default=0.2, description="重要性权重") search_importance_weight: float = Field(default=0.2, description="重要性权重")
search_recency_weight: float = Field(default=0.2, description="时效性权重") search_recency_weight: float = Field(default=0.2, description="时效性权重")
# 记忆整合配置 # 记忆整合配置
consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合") consolidation_enabled: bool = Field(default=False, description="是否启用记忆整合")
consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)") consolidation_interval_hours: float = Field(default=2.0, description="整合任务执行间隔(小时)")
@@ -442,21 +442,21 @@ class MemoryConfig(ValidatedConfigBase):
consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值") consolidation_linking_min_confidence: float = Field(default=0.7, description="LLM分析最低置信度阈值")
consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数") consolidation_linking_llm_temperature: float = Field(default=0.2, description="LLM分析温度参数")
consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度") consolidation_linking_llm_max_tokens: int = Field(default=1500, description="LLM分析最大输出长度")
# 遗忘配置 (记忆图系统) # 遗忘配置 (记忆图系统)
forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘") forgetting_enabled: bool = Field(default=True, description="是否启用自动遗忘")
forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值") forgetting_activation_threshold: float = Field(default=0.1, description="激活度阈值")
forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性") forgetting_min_importance: float = Field(default=0.8, description="最小保护重要性")
# 激活配置 # 激活配置
activation_decay_rate: float = Field(default=0.9, description="激活度衰减率") activation_decay_rate: float = Field(default=0.9, description="激活度衰减率")
activation_propagation_strength: float = Field(default=0.5, description="激活传播强度") activation_propagation_strength: float = Field(default=0.5, description="激活传播强度")
activation_propagation_depth: int = Field(default=2, description="激活传播深度") activation_propagation_depth: int = Field(default=2, description="激活传播深度")
# 性能配置 # 性能配置
max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数") max_memory_nodes_per_memory: int = Field(default=10, description="每个记忆最多包含的节点数")
max_related_memories: int = Field(default=5, description="相关记忆最大数量") max_related_memories: int = Field(default=5, description="相关记忆最大数量")
# 节点去重合并配置 # 节点去重合并配置
node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值") node_merger_similarity_threshold: float = Field(default=0.85, description="节点去重相似度阈值")
node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配") node_merger_context_match_required: bool = Field(default=True, description="节点合并是否要求上下文匹配")

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
model_name = model_info.name model_name = model_info.name
retry_interval = api_provider.retry_interval retry_interval = api_provider.retry_interval
if isinstance(e, (NetworkConnectionError, ReqAbortException)): if isinstance(e, NetworkConnectionError | ReqAbortException):
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name) return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException): elif isinstance(e, RespNotOkException):
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)

View File

@@ -100,10 +100,10 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串 # 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items(): for key, value in node.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
import orjson import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value metadata[key] = value
else: else:
metadata[key] = str(value) metadata[key] = str(value)
@@ -149,9 +149,9 @@ class VectorStore:
"created_at": n.created_at.isoformat(), "created_at": n.created_at.isoformat(),
} }
for key, value in n.metadata.items(): for key, value in n.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value # type: ignore metadata[key] = value # type: ignore
else: else:
metadata[key] = str(value) metadata[key] = str(value)

View File

@@ -4,8 +4,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import numpy as np import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -72,7 +70,7 @@ class EmbeddingGenerator:
logger.warning(f"⚠️ Embedding API 初始化失败: {e}") logger.warning(f"⚠️ Embedding API 初始化失败: {e}")
self._api_available = False self._api_available = False
async def generate(self, text: str) -> np.ndarray | None: async def generate(self, text: str) -> np.ndarray | None:
""" """
生成单个文本的嵌入向量 生成单个文本的嵌入向量
@@ -130,7 +128,7 @@ class EmbeddingGenerator:
logger.debug(f"API 嵌入生成失败: {e}") logger.debug(f"API 嵌入生成失败: {e}")
return None return None
def _get_dimension(self) -> int: def _get_dimension(self) -> int:
"""获取嵌入维度""" """获取嵌入维度"""
# 优先使用 API 维度 # 优先使用 API 维度

View File

@@ -7,11 +7,12 @@
""" """
import atexit import atexit
import orjson
import os import os
import threading import threading
from typing import Any, ClassVar from typing import Any, ClassVar
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
# 获取日志记录器 # 获取日志记录器
@@ -125,7 +126,7 @@ class PluginStorage:
try: try:
with open(self.file_path, "w", encoding="utf-8") as f: with open(self.file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')) f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
self._dirty = False # 保存后重置标志 self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。") logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e: except Exception as e:

View File

@@ -5,12 +5,12 @@ MCP Client Manager
""" """
import asyncio import asyncio
import orjson
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import mcp.types import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
""" """
import time import time
from typing import Any, Optional from dataclasses import dataclass, field
from dataclasses import dataclass, asdict, field from typing import Any
import orjson import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history") logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录""" """工具调用记录"""
tool_name: str tool_name: str
args: dict[str, Any] args: dict[str, Any]
result: Optional[dict[str, Any]] = None result: dict[str, Any] | None = None
status: str = "success" # success, error, pending status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒) execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存 cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览 result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息 error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
content = self.result.get("content", "") content = self.result.get("content", "")
if isinstance(content, str): if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "") self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)): elif isinstance(content, list | dict):
try: try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..." self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
except Exception: except Exception:
self.result_preview = str(content)[:500] + "..." self.result_preview = str(content)[:500] + "..."
else: else:
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}") logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果 """从缓存或历史记录中获取结果
Args: Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
return None return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any], async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None, execution_time: float | None = None,
tool_file_path: Optional[str] = None, tool_file_path: str | None = None,
ttl: Optional[int] = None) -> None: ttl: int | None = None) -> None:
"""缓存工具调用结果 """缓存工具调用结果
Args: Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
except Exception as e: except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}") logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]: async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
"""获取最近的历史记录 """获取最近的历史记录
Args: Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
self._history.clear() self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除") logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""在内存历史记录中搜索缓存 """在内存历史记录中搜索缓存
Args: Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py") return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]: def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
"""提取语义查询参数 """提取语义查询参数
Args: Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return "" return ""
try: try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8') args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length: if len(args_str) > max_length:
args_str = args_str[:max_length] + "..." args_str = args_str[:max_length] + "..."
return args_str return args_str
@@ -411,4 +413,4 @@ def cleanup_stream_manager(chat_id: str) -> None:
""" """
if chat_id in _stream_managers: if chat_id in _stream_managers:
del _stream_managers[chat_id] del _stream_managers[chat_id]
logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器") logger.info(f"已清理聊天 {chat_id} 的工具历史记录管理器")

View File

@@ -1,5 +1,6 @@
import inspect import inspect
import time import time
from dataclasses import asdict
from typing import Any from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
from dataclasses import asdict
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -140,7 +140,7 @@ class ToolExecutor:
# 构建工具调用历史文本 # 构建工具调用历史文本
tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True) tool_history = self.history_manager.format_for_prompt(max_records=5, include_results=True)
# 获取人设信息 # 获取人设信息
personality_core = global_config.personality.personality_core personality_core = global_config.personality.personality_core
personality_side = global_config.personality.personality_side personality_side = global_config.personality.personality_side
@@ -197,7 +197,7 @@ class ToolExecutor:
return tool_definitions return tool_definitions
async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]: async def execute_tool_calls(self, tool_calls: list[ToolCall] | None) -> tuple[list[dict[str, Any]], list[str]]:
"""执行工具调用 """执行工具调用
@@ -338,9 +338,8 @@ class ToolExecutor:
if tool_instance and result and tool_instance.enable_cache: if tool_instance and result and tool_instance.enable_cache:
try: try:
tool_file_path = inspect.getfile(tool_instance.__class__) tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key: if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key) function_args.get(tool_instance.semantic_cache_query_key)
await self.history_manager.cache_result( await self.history_manager.cache_result(
tool_name=tool_call.func_name, tool_name=tool_call.func_name,

View File

@@ -122,7 +122,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
+ relationship_score * self.score_weights["relationship"] + relationship_score * self.score_weights["relationship"]
+ mentioned_score * self.score_weights["mentioned"] + mentioned_score * self.score_weights["mentioned"]
) )
# 限制总分上限为1.0,确保分数在合理范围内 # 限制总分上限为1.0,确保分数在合理范围内
total_score = min(raw_total_score, 1.0) total_score = min(raw_total_score, 1.0)
@@ -131,7 +131,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"{relationship_score:.3f}*{self.score_weights['relationship']} + " f"{relationship_score:.3f}*{self.score_weights['relationship']} + "
f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}" f"{mentioned_score:.3f}*{self.score_weights['mentioned']} = {raw_total_score:.3f}"
) )
if raw_total_score > 1.0: if raw_total_score > 1.0:
logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}") logger.debug(f"[Affinity兴趣计算] 原始分数 {raw_total_score:.3f} 超过1.0,已限制为 {total_score:.3f}")
@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return 0.0 return 0.0
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数") logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分 return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e: except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}") logger.warning(f"智能兴趣匹配失败: {e}")
@@ -251,19 +251,19 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float: def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
"""计算提及分 - 区分强提及和弱提及 """计算提及分 - 区分强提及和弱提及
强提及(被@、被回复、私聊): 使用 strong_mention_interest_score 强提及(被@、被回复、私聊): 使用 strong_mention_interest_score
弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score 弱提及(文本匹配名字/别名): 使用 weak_mention_interest_score
""" """
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
# 使用统一的提及检测函数 # 使用统一的提及检测函数
is_mentioned, mention_type = is_mentioned_bot_in_message(message) is_mentioned, mention_type = is_mentioned_bot_in_message(message)
if not is_mentioned: if not is_mentioned:
logger.debug("[提及分计算] 未提及机器人返回0.0") logger.debug("[提及分计算] 未提及机器人返回0.0")
return 0.0 return 0.0
# mention_type: 0=未提及, 1=弱提及, 2=强提及 # mention_type: 0=未提及, 1=弱提及, 2=强提及
if mention_type >= 2: if mention_type >= 2:
# 强提及:被@、被回复、私聊 # 强提及:被@、被回复、私聊
@@ -281,22 +281,22 @@ class AffinityInterestCalculator(BaseInterestCalculator):
def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]: def _apply_no_reply_threshold_adjustment(self) -> tuple[float, float]:
"""应用阈值调整(包括连续不回复和回复后降低机制) """应用阈值调整(包括连续不回复和回复后降低机制)
Returns: Returns:
tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值) tuple[float, float]: (调整后的回复阈值, 调整后的动作阈值)
""" """
# 基础阈值 # 基础阈值
base_reply_threshold = self.reply_threshold base_reply_threshold = self.reply_threshold
base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold base_action_threshold = global_config.affinity_flow.non_reply_action_interest_threshold
total_reduction = 0.0 total_reduction = 0.0
# 1. 连续不回复的阈值降低 # 1. 连续不回复的阈值降低
if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count: if self.no_reply_count > 0 and self.no_reply_count < self.max_no_reply_count:
no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply no_reply_reduction = self.no_reply_count * self.probability_boost_per_no_reply
total_reduction += no_reply_reduction total_reduction += no_reply_reduction
logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})") logger.debug(f"[阈值调整] 连续不回复降低: {no_reply_reduction:.3f} (计数: {self.no_reply_count})")
# 2. 回复后的阈值降低使bot更容易连续对话 # 2. 回复后的阈值降低使bot更容易连续对话
if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0: if self.enable_post_reply_boost and self.post_reply_boost_remaining > 0:
# 计算衰减后的降低值 # 计算衰减后的降低值
@@ -309,16 +309,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} " f"[阈值调整] 回复后降低: {post_reply_reduction:.3f} "
f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})" f"(剩余次数: {self.post_reply_boost_remaining}, 衰减: {decay_factor:.2f})"
) )
# 应用总降低量 # 应用总降低量
adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction) adjusted_reply_threshold = max(0.0, base_reply_threshold - total_reduction)
adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction) adjusted_action_threshold = max(0.0, base_action_threshold - total_reduction)
return adjusted_reply_threshold, adjusted_action_threshold return adjusted_reply_threshold, adjusted_action_threshold
def _apply_no_reply_boost(self, base_score: float) -> float: def _apply_no_reply_boost(self, base_score: float) -> float:
"""【已弃用】应用连续不回复的概率提升 """【已弃用】应用连续不回复的概率提升
注意:此方法已被 _apply_no_reply_threshold_adjustment 替代 注意:此方法已被 _apply_no_reply_threshold_adjustment 替代
保留用于向后兼容 保留用于向后兼容
""" """
@@ -388,7 +388,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
self.no_reply_count = 0 self.no_reply_count = 0
else: else:
self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count) self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)
def on_reply_sent(self): def on_reply_sent(self):
"""当机器人发送回复后调用,激活回复后阈值降低机制""" """当机器人发送回复后调用,激活回复后阈值降低机制"""
if self.enable_post_reply_boost: if self.enable_post_reply_boost:
@@ -399,16 +399,16 @@ class AffinityInterestCalculator(BaseInterestCalculator):
) )
# 同时重置不回复计数 # 同时重置不回复计数
self.no_reply_count = 0 self.no_reply_count = 0
def on_message_processed(self, replied: bool): def on_message_processed(self, replied: bool):
"""消息处理完成后调用,更新各种计数器 """消息处理完成后调用,更新各种计数器
Args: Args:
replied: 是否回复了此消息 replied: 是否回复了此消息
""" """
# 更新不回复计数 # 更新不回复计数
self.update_no_reply_count(replied) self.update_no_reply_count(replied)
# 如果已回复,激活回复后降低机制 # 如果已回复,激活回复后降低机制
if replied: if replied:
self.on_reply_sent() self.on_reply_sent()

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
包含计划生成、过滤、执行等规划相关功能 包含计划生成、过滤、执行等规划相关功能
""" """
from . import planner_prompts
from .plan_executor import ChatterPlanExecutor from .plan_executor import ChatterPlanExecutor
from .plan_filter import ChatterPlanFilter from .plan_filter import ChatterPlanFilter
from .plan_generator import ChatterPlanGenerator from .plan_generator import ChatterPlanGenerator
from .planner import ChatterActionPlanner from .planner import ChatterActionPlanner
from . import planner_prompts
__all__ = ["ChatterActionPlanner", "planner_prompts", "ChatterPlanGenerator", "ChatterPlanFilter", "ChatterPlanExecutor"] __all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统 # 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id, build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
) )
from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.prompt import global_prompt_manager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.data_models.info_data_model import ActionPlannerInfo, Plan
@@ -646,7 +644,7 @@ class ChatterPlanFilter:
memory_manager = get_memory_manager() memory_manager = get_memory_manager()
if not memory_manager: if not memory_manager:
return "记忆系统未初始化。" return "记忆系统未初始化。"
# 将关键词转换为查询字符串 # 将关键词转换为查询字符串
query = " ".join(keywords) query = " ".join(keywords)
enhanced_memories = await memory_manager.search_memories( enhanced_memories = await memory_manager.search_memories(

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化 # 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts
logger = get_logger("planner") logger = get_logger("planner")
@@ -159,10 +158,10 @@ class ChatterActionPlanner:
action_data={}, action_data={},
action_message=None, action_message=None,
) )
# 更新连续不回复计数 # 更新连续不回复计数
await self._update_interest_calculator_state(replied=False) await self._update_interest_calculator_state(replied=False)
initial_plan = await self.generator.generate(chat_mode) initial_plan = await self.generator.generate(chat_mode)
filtered_plan = initial_plan filtered_plan = initial_plan
filtered_plan.decided_actions = [no_action] filtered_plan.decided_actions = [no_action]
@@ -270,7 +269,7 @@ class ChatterActionPlanner:
try: try:
# Normal模式开始时刷新缓存消息到未读列表 # Normal模式开始时刷新缓存消息到未读列表
await self._flush_cached_messages_to_unread(context) await self._flush_cached_messages_to_unread(context)
unread_messages = context.get_unread_messages() if context else [] unread_messages = context.get_unread_messages() if context else []
if not unread_messages: if not unread_messages:
@@ -347,7 +346,7 @@ class ChatterActionPlanner:
self._update_stats_from_execution_result(execution_result) self._update_stats_from_execution_result(execution_result)
logger.info("Normal模式: 执行reply动作完成") logger.info("Normal模式: 执行reply动作完成")
# 更新兴趣计算器状态(回复成功,重置不回复计数) # 更新兴趣计算器状态(回复成功,重置不回复计数)
await self._update_interest_calculator_state(replied=True) await self._update_interest_calculator_state(replied=True)
@@ -465,7 +464,7 @@ class ChatterActionPlanner:
async def _update_interest_calculator_state(self, replied: bool) -> None: async def _update_interest_calculator_state(self, replied: bool) -> None:
"""更新兴趣计算器状态(连续不回复计数和回复后降低机制) """更新兴趣计算器状态(连续不回复计数和回复后降低机制)
Args: Args:
replied: 是否回复了消息 replied: 是否回复了消息
""" """
@@ -504,36 +503,36 @@ class ChatterActionPlanner:
async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list: async def _flush_cached_messages_to_unread(self, context: "StreamContext | None") -> list:
"""在planner开始时将缓存消息刷新到未读消息列表 """在planner开始时将缓存消息刷新到未读消息列表
此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。 此方法在动作修改器执行后、生成初始计划前调用,确保计划阶段能看到所有积累的消息。
Args: Args:
context: 流上下文 context: 流上下文
Returns: Returns:
list: 刷新的消息列表 list: 刷新的消息列表
""" """
if not context: if not context:
return [] return []
try: try:
from src.chat.message_manager.message_manager import message_manager from src.chat.message_manager.message_manager import message_manager
stream_id = context.stream_id stream_id = context.stream_id
if message_manager.is_running and message_manager.has_cached_messages(stream_id): if message_manager.is_running and message_manager.has_cached_messages(stream_id):
# 获取缓存消息 # 获取缓存消息
cached_messages = message_manager.flush_cached_messages(stream_id) cached_messages = message_manager.flush_cached_messages(stream_id)
if cached_messages: if cached_messages:
# 直接添加到上下文的未读消息列表 # 直接添加到上下文的未读消息列表
for message in cached_messages: for message in cached_messages:
context.unread_messages.append(message) context.unread_messages.append(message)
logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}") logger.info(f"Planner开始前刷新缓存消息到未读列表: stream={stream_id}, 数量={len(cached_messages)}")
return cached_messages return cached_messages
return [] return []
except ImportError: except ImportError:
logger.debug("MessageManager不可用跳过缓存刷新") logger.debug("MessageManager不可用跳过缓存刷新")
return [] return []

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler
__all__ = [ __all__ = [
"ProactiveThinkingReplyHandler",
"ProactiveThinkingMessageHandler", "ProactiveThinkingMessageHandler",
"execute_proactive_thinking", "ProactiveThinkingReplyHandler",
"ProactiveThinkingScheduler", "ProactiveThinkingScheduler",
"execute_proactive_thinking",
"proactive_thinking_scheduler", "proactive_thinking_scheduler",
] ]

View File

@@ -3,7 +3,6 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复 当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
""" """
import orjson
from datetime import datetime from datetime import datetime
from typing import Any, Literal from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api from src.plugin_system.apis import config_api, generator_api, llm_api
@@ -320,7 +319,7 @@ class ContentService:
- 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。 - 禁止在说说中直接、完整地提及当前的年月日,除非日期有特殊含义,但也尽量用节日名/节气名字代替。
2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。 2. **严禁重复**:下方会提供你最近发过的说说历史,你必须创作一条全新的、与历史记录内容和主题都不同的说说。
**其他的禁止的内容以及说明** **其他的禁止的内容以及说明**
- 绝对禁止提及当下具体几点几分的时间戳。 - 绝对禁止提及当下具体几点几分的时间戳。
- 绝对禁止攻击性内容和过度的负面情绪。 - 绝对禁止攻击性内容和过度的负面情绪。

View File

@@ -136,10 +136,10 @@ class QZoneService:
logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}") logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
api_client = await self._get_api_client(qq_account, stream_id) api_client = await self._get_api_client(qq_account, stream_id)
if not api_client: if not api_client:
logger.error(f"[DEBUG] API客户端获取失败返回错误") logger.error("[DEBUG] API客户端获取失败返回错误")
return {"success": False, "message": "获取QZone API客户端失败"} return {"success": False, "message": "获取QZone API客户端失败"}
logger.info(f"[DEBUG] API客户端获取成功准备读取说说") logger.info("[DEBUG] API客户端获取成功准备读取说说")
num_to_read = self.get_config("read.read_number", 5) num_to_read = self.get_config("read.read_number", 5)
# 尝试执行如果Cookie失效则自动重试一次 # 尝试执行如果Cookie失效则自动重试一次
@@ -186,7 +186,7 @@ class QZoneService:
# 检查是否是Cookie失效-3000错误 # 检查是否是Cookie失效-3000错误
if "错误码: -3000" in error_msg and retry_count == 0: if "错误码: -3000" in error_msg and retry_count == 0:
logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...") logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")
# 删除Cookie缓存文件 # 删除Cookie缓存文件
cookie_file = self.cookie_service._get_cookie_file_path(qq_account) cookie_file = self.cookie_service._get_cookie_file_path(qq_account)
@@ -623,7 +623,7 @@ class QZoneService:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}") logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None return None
logger.info(f"[DEBUG] p_skey获取成功") logger.info("[DEBUG] p_skey获取成功")
gtk = self._generate_gtk(p_skey) gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o") uin = cookies.get("uin", "").lstrip("o")
@@ -1230,7 +1230,7 @@ class QZoneService:
logger.error(f"监控好友动态失败: {e}", exc_info=True) logger.error(f"监控好友动态失败: {e}", exc_info=True)
return [] return []
logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典") logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
return { return {
"publish": _publish, "publish": _publish,
"list_feeds": _list_feeds, "list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
负责记录和管理已回复过的评论ID避免重复回复 负责记录和管理已回复过的评论ID避免重复回复
""" """
import orjson
import time import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService") logger = get_logger("MaiZone.ReplyTrackerService")
@@ -117,8 +118,8 @@ class ReplyTrackerService:
temp_file = self.reply_record_file.with_suffix(".tmp") temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件 # 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f: with open(temp_file, "w", encoding="utf-8"):
orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8') orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8")
# 如果写入成功,重命名为正式文件 # 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功 if temp_file.stat().st_size > 0: # 确保写入成功

View File

@@ -1,7 +1,6 @@
import orjson import orjson
import random import random
import time import time
import random
import websockets as Server import websockets as Server
import uuid import uuid
from maim_message import ( from maim_message import (
@@ -205,7 +204,7 @@ class SendHandler:
# 发送响应回MoFox-Bot # 发送响应回MoFox-Bot
logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}") logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
await self.send_adapter_command_response(raw_message_base, response, request_id) await self.send_adapter_command_response(raw_message_base, response, request_id)
logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成") logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
if response.get("status") == "ok": if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功") logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
""" """
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import orjson
from typing import Any from typing import Any
import httpx import httpx
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
Google Search via Serper.dev API Google Search via Serper.dev API
""" """
import aiohttp
from typing import Any from typing import Any
import aiohttp
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
""" """
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
search_tasks.append(engine.answer_search(custom_args)) search_tasks.append(engine.answer_search(custom_args))
else: else:
search_tasks.append(engine.search(custom_args)) search_tasks.append(engine.search(custom_args))
@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索fallback策略") logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else:
@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索") logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else:

View File

@@ -266,13 +266,13 @@ class UnifiedScheduler:
name=f"execute_{task.task_name}" name=f"execute_{task.task_name}"
) )
execution_tasks.append(execution_task) execution_tasks.append(execution_task)
# 追踪正在执行的任务,以便在 remove_schedule 时可以取消 # 追踪正在执行的任务,以便在 remove_schedule 时可以取消
self._executing_tasks[task.schedule_id] = execution_task self._executing_tasks[task.schedule_id] = execution_task
# 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务) # 等待所有任务完成(使用 return_exceptions=True 避免单个任务失败影响其他任务)
results = await asyncio.gather(*execution_tasks, return_exceptions=True) results = await asyncio.gather(*execution_tasks, return_exceptions=True)
# 清理执行追踪 # 清理执行追踪
for task in tasks_to_trigger: for task in tasks_to_trigger:
self._executing_tasks.pop(task.schedule_id, None) self._executing_tasks.pop(task.schedule_id, None)
@@ -515,7 +515,7 @@ class UnifiedScheduler:
async def remove_schedule(self, schedule_id: str) -> bool: async def remove_schedule(self, schedule_id: str) -> bool:
"""移除调度任务 """移除调度任务
如果任务正在执行,会取消执行中的任务 如果任务正在执行,会取消执行中的任务
""" """
async with self._lock: async with self._lock:
@@ -524,7 +524,7 @@ class UnifiedScheduler:
return False return False
task = self._tasks[schedule_id] task = self._tasks[schedule_id]
# 检查是否有正在执行的任务 # 检查是否有正在执行的任务
executing_task = self._executing_tasks.get(schedule_id) executing_task = self._executing_tasks.get(schedule_id)
if executing_task and not executing_task.done(): if executing_task and not executing_task.done():

View File

@@ -19,42 +19,42 @@ logger = get_logger(__name__)
def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None: def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str, Any] | list | None:
""" """
从 LLM 响应中提取并解析 JSON 从 LLM 响应中提取并解析 JSON
处理策略: 处理策略:
1. 清理 Markdown 代码块标记(```json 和 ``` 1. 清理 Markdown 代码块标记(```json 和 ```
2. 提取 JSON 对象或数组 2. 提取 JSON 对象或数组
3. 使用 json_repair 修复格式问题 3. 使用 json_repair 修复格式问题
4. 解析为 Python 对象 4. 解析为 Python 对象
Args: Args:
response: LLM 响应字符串 response: LLM 响应字符串
strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理 strict: 严格模式,如果为 True 则解析失败时返回 None否则尝试容错处理
Returns: Returns:
解析后的 dict 或 list失败时返回 None 解析后的 dict 或 list失败时返回 None
Examples: Examples:
>>> extract_and_parse_json('```json\\n{"key": "value"}\\n```') >>> extract_and_parse_json('```json\\n{"key": "value"}\\n```')
{'key': 'value'} {'key': 'value'}
>>> extract_and_parse_json('Some text {"key": "value"} more text') >>> extract_and_parse_json('Some text {"key": "value"} more text')
{'key': 'value'} {'key': 'value'}
>>> extract_and_parse_json('[{"a": 1}, {"b": 2}]') >>> extract_and_parse_json('[{"a": 1}, {"b": 2}]')
[{'a': 1}, {'b': 2}] [{'a': 1}, {'b': 2}]
""" """
if not response: if not response:
logger.debug("空响应,无法解析 JSON") logger.debug("空响应,无法解析 JSON")
return None return None
try: try:
# 步骤 1: 清理响应 # 步骤 1: 清理响应
cleaned = _clean_llm_response(response) cleaned = _clean_llm_response(response)
if not cleaned: if not cleaned:
logger.warning("清理后的响应为空") logger.warning("清理后的响应为空")
return None return None
# 步骤 2: 尝试直接解析 # 步骤 2: 尝试直接解析
try: try:
result = orjson.loads(cleaned) result = orjson.loads(cleaned)
@@ -62,11 +62,11 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
return result return result
except Exception as direct_error: except Exception as direct_error:
logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}") logger.debug(f"直接解析失败: {type(direct_error).__name__}: {direct_error}")
# 步骤 3: 使用 json_repair 修复并解析 # 步骤 3: 使用 json_repair 修复并解析
try: try:
repaired = repair_json(cleaned) repaired = repair_json(cleaned)
# repair_json 可能返回字符串或已解析的对象 # repair_json 可能返回字符串或已解析的对象
if isinstance(repaired, str): if isinstance(repaired, str):
result = orjson.loads(repaired) result = orjson.loads(repaired)
@@ -74,16 +74,16 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else: else:
result = repaired result = repaired
logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}") logger.debug(f"✅ JSON 修复后解析成功(对象模式),类型: {type(result).__name__}")
return result return result
except Exception as repair_error: except Exception as repair_error:
logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}") logger.warning(f"JSON 修复失败: {type(repair_error).__name__}: {repair_error}")
if strict: if strict:
logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}") logger.error(f"严格模式下解析失败,响应片段: {cleaned[:200]}")
return None return None
# 最后的容错尝试:返回空字典或空列表 # 最后的容错尝试:返回空字典或空列表
if cleaned.strip().startswith("["): if cleaned.strip().startswith("["):
logger.warning("返回空列表作为容错") logger.warning("返回空列表作为容错")
@@ -91,7 +91,7 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
else: else:
logger.warning("返回空字典作为容错") logger.warning("返回空字典作为容错")
return {} return {}
except Exception as e: except Exception as e:
logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}") logger.error(f"❌ JSON 解析过程出现异常: {type(e).__name__}: {e}")
if strict: if strict:
@@ -102,37 +102,37 @@ def extract_and_parse_json(response: str, *, strict: bool = False) -> dict[str,
def _clean_llm_response(response: str) -> str: def _clean_llm_response(response: str) -> str:
""" """
清理 LLM 响应,提取 JSON 部分 清理 LLM 响应,提取 JSON 部分
处理步骤: 处理步骤:
1. 移除 Markdown 代码块标记(```json 和 ``` 1. 移除 Markdown 代码块标记(```json 和 ```
2. 提取第一个完整的 JSON 对象 {...} 或数组 [...] 2. 提取第一个完整的 JSON 对象 {...} 或数组 [...]
3. 清理多余的空格和换行 3. 清理多余的空格和换行
Args: Args:
response: 原始 LLM 响应 response: 原始 LLM 响应
Returns: Returns:
清理后的 JSON 字符串 清理后的 JSON 字符串
""" """
if not response: if not response:
return "" return ""
cleaned = response.strip() cleaned = response.strip()
# 移除 Markdown 代码块标记 # 移除 Markdown 代码块标记
# 匹配 ```json ... ``` 或 ``` ... ``` # 匹配 ```json ... ``` 或 ``` ... ```
code_block_patterns = [ code_block_patterns = [
r"```json\s*(.*?)```", # ```json ... ``` r"```json\s*(.*?)```", # ```json ... ```
r"```\s*(.*?)```", # ``` ... ``` r"```\s*(.*?)```", # ``` ... ```
] ]
for pattern in code_block_patterns: for pattern in code_block_patterns:
match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL) match = re.search(pattern, cleaned, re.IGNORECASE | re.DOTALL)
if match: if match:
cleaned = match.group(1).strip() cleaned = match.group(1).strip()
logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}") logger.debug(f"从 Markdown 代码块中提取内容,长度: {len(cleaned)}")
break break
# 提取 JSON 对象或数组 # 提取 JSON 对象或数组
# 优先查找对象 {...},其次查找数组 [...] # 优先查找对象 {...},其次查找数组 [...]
for start_char, end_char in [("{", "}"), ("[", "]")]: for start_char, end_char in [("{", "}"), ("[", "]")]:
@@ -143,7 +143,7 @@ def _clean_llm_response(response: str) -> str:
if extracted: if extracted:
logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}") logger.debug(f"提取到 {start_char}...{end_char} 结构,长度: {len(extracted)}")
return extracted return extracted
# 如果没有找到明确的 JSON 结构,返回清理后的原始内容 # 如果没有找到明确的 JSON 结构,返回清理后的原始内容
logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容") logger.debug("未找到明确的 JSON 结构,返回清理后的原始内容")
return cleaned return cleaned
@@ -152,39 +152,39 @@ def _clean_llm_response(response: str) -> str:
def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None: def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char: str) -> str | None:
""" """
从指定位置提取平衡的 JSON 结构 从指定位置提取平衡的 JSON 结构
使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符 使用栈匹配算法找到对应的结束符,处理嵌套和字符串中的特殊字符
Args: Args:
text: 源文本 text: 源文本
start_idx: 起始字符的索引 start_idx: 起始字符的索引
start_char: 起始字符({ 或 [ start_char: 起始字符({ 或 [
end_char: 结束字符(} 或 ] end_char: 结束字符(} 或 ]
Returns: Returns:
提取的 JSON 字符串,失败时返回 None 提取的 JSON 字符串,失败时返回 None
""" """
depth = 0 depth = 0
in_string = False in_string = False
escape_next = False escape_next = False
for i in range(start_idx, len(text)): for i in range(start_idx, len(text)):
char = text[i] char = text[i]
# 处理转义字符 # 处理转义字符
if escape_next: if escape_next:
escape_next = False escape_next = False
continue continue
if char == "\\": if char == "\\":
escape_next = True escape_next = True
continue continue
# 处理字符串 # 处理字符串
if char == '"': if char == '"':
in_string = not in_string in_string = not in_string
continue continue
# 只在非字符串内处理括号 # 只在非字符串内处理括号
if not in_string: if not in_string:
if char == start_char: if char == start_char:
@@ -194,7 +194,7 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
if depth == 0: if depth == 0:
# 找到匹配的结束符 # 找到匹配的结束符
return text[start_idx : i + 1].strip() return text[start_idx : i + 1].strip()
# 没有找到匹配的结束符 # 没有找到匹配的结束符
logger.debug(f"未找到匹配的 {end_char},深度: {depth}") logger.debug(f"未找到匹配的 {end_char},深度: {depth}")
return None return None
@@ -203,11 +203,11 @@ def _extract_balanced_json(text: str, start_idx: int, start_char: str, end_char:
def safe_parse_json(json_str: str, default: Any = None) -> Any: def safe_parse_json(json_str: str, default: Any = None) -> Any:
""" """
安全解析 JSON失败时返回默认值 安全解析 JSON失败时返回默认值
Args: Args:
json_str: JSON 字符串 json_str: JSON 字符串
default: 解析失败时返回的默认值 default: 解析失败时返回的默认值
Returns: Returns:
解析结果或默认值 解析结果或默认值
""" """
@@ -222,19 +222,19 @@ def safe_parse_json(json_str: str, default: Any = None) -> Any:
def extract_json_field(response: str, field_name: str, default: Any = None) -> Any: def extract_json_field(response: str, field_name: str, default: Any = None) -> Any:
""" """
从 LLM 响应中提取特定字段的值 从 LLM 响应中提取特定字段的值
Args: Args:
response: LLM 响应 response: LLM 响应
field_name: 字段名 field_name: 字段名
default: 字段不存在时的默认值 default: 字段不存在时的默认值
Returns: Returns:
字段值或默认值 字段值或默认值
""" """
parsed = extract_and_parse_json(response, strict=False) parsed = extract_and_parse_json(response, strict=False)
if isinstance(parsed, dict): if isinstance(parsed, dict):
return parsed.get(field_name, default) return parsed.get(field_name, default)
logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'") logger.warning(f"解析结果不是字典,无法提取字段 '{field_name}'")
return default return default

View File

@@ -14,7 +14,7 @@ sys.path.insert(0, str(project_root))
from tools.memory_visualizer.visualizer_server import run_server from tools.memory_visualizer.visualizer_server import run_server
if __name__ == '__main__': if __name__ == "__main__":
print("=" * 60) print("=" * 60)
print("🦊 MoFox Bot - 记忆图可视化工具") print("🦊 MoFox Bot - 记忆图可视化工具")
print("=" * 60) print("=" * 60)
@@ -24,10 +24,10 @@ if __name__ == '__main__':
print("⏹️ 按 Ctrl+C 停止服务器") print("⏹️ 按 Ctrl+C 停止服务器")
print() print()
print("=" * 60) print("=" * 60)
try: try:
run_server( run_server(
host='127.0.0.1', host="127.0.0.1",
port=5000, port=5000,
debug=True debug=True
) )

View File

@@ -15,7 +15,7 @@ from pathlib import Path
project_root = Path(__file__).parent project_root = Path(__file__).parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
if __name__ == '__main__': if __name__ == "__main__":
print("=" * 70) print("=" * 70)
print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)") print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
print("=" * 70) print("=" * 70)
@@ -26,10 +26,10 @@ if __name__ == '__main__':
print(" • 快速启动,无需完整初始化") print(" • 快速启动,无需完整初始化")
print() print()
print("=" * 70) print("=" * 70)
try: try:
from tools.memory_visualizer.visualizer_simple import run_server from tools.memory_visualizer.visualizer_simple import run_server
run_server(host='127.0.0.1', port=5001, debug=True) run_server(host="127.0.0.1", port=5001, debug=True)
except KeyboardInterrupt: except KeyboardInterrupt:
print("\n\n👋 服务器已停止") print("\n\n👋 服务器已停止")
except Exception as e: except Exception as e:

View File

@@ -11,7 +11,6 @@ import logging
import sys import sys
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Optional
from flask import Flask, jsonify, render_template, request from flask import Flask, jsonify, render_template, request
from flask_cors import CORS from flask_cors import CORS
@@ -28,7 +27,7 @@ app = Flask(__name__)
CORS(app) # 允许跨域请求 CORS(app) # 允许跨域请求
# 全局记忆管理器 # 全局记忆管理器
memory_manager: Optional[MemoryManager] = None memory_manager: MemoryManager | None = None
def init_memory_manager(): def init_memory_manager():
@@ -189,7 +188,7 @@ def search_memories():
init_memory_manager() init_memory_manager()
query = request.args.get("q", "") query = request.args.get("q", "")
memory_type = request.args.get("type", None) request.args.get("type", None)
limit = int(request.args.get("limit", 50)) limit = int(request.args.get("limit", 50))
loop = asyncio.new_event_loop() loop = asyncio.new_event_loop()

View File

@@ -4,20 +4,18 @@
直接从存储的数据文件生成可视化,无需启动完整的记忆管理器 直接从存储的数据文件生成可视化,无需启动完整的记忆管理器
""" """
import orjson
import sys import sys
from pathlib import Path
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Set
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional, Set from typing import Any
import orjson
# 添加项目根目录 # 添加项目根目录
project_root = Path(__file__).parent.parent.parent project_root = Path(__file__).parent.parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from flask import Flask, jsonify, render_template_string, request, send_from_directory from flask import Flask, jsonify, render_template_string, request
from flask_cors import CORS from flask_cors import CORS
app = Flask(__name__) app = Flask(__name__)
@@ -29,38 +27,38 @@ data_dir = project_root / "data" / "memory_graph"
current_data_file = None # 当前选择的数据文件 current_data_file = None # 当前选择的数据文件
def find_available_data_files() -> List[Path]: def find_available_data_files() -> list[Path]:
"""查找所有可用的记忆图数据文件""" """查找所有可用的记忆图数据文件"""
files = [] files = []
if not data_dir.exists(): if not data_dir.exists():
return files return files
# 查找多种可能的文件名 # 查找多种可能的文件名
possible_files = [ possible_files = [
"graph_store.json", "graph_store.json",
"memory_graph.json", "memory_graph.json",
"graph_data.json", "graph_data.json",
] ]
for filename in possible_files: for filename in possible_files:
file_path = data_dir / filename file_path = data_dir / filename
if file_path.exists(): if file_path.exists():
files.append(file_path) files.append(file_path)
# 查找所有备份文件 # 查找所有备份文件
for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]: for pattern in ["graph_store_*.json", "memory_graph_*.json", "graph_data_*.json"]:
for backup_file in data_dir.glob(pattern): for backup_file in data_dir.glob(pattern):
if backup_file not in files: if backup_file not in files:
files.append(backup_file) files.append(backup_file)
# 查找backups子目录 # 查找backups子目录
backups_dir = data_dir / "backups" backups_dir = data_dir / "backups"
if backups_dir.exists(): if backups_dir.exists():
for backup_file in backups_dir.glob("**/*.json"): for backup_file in backups_dir.glob("**/*.json"):
if backup_file not in files: if backup_file not in files:
files.append(backup_file) files.append(backup_file)
# 查找data/backup目录 # 查找data/backup目录
backup_dir = data_dir.parent / "backup" backup_dir = data_dir.parent / "backup"
if backup_dir.exists(): if backup_dir.exists():
@@ -70,22 +68,22 @@ def find_available_data_files() -> List[Path]:
for backup_file in backup_dir.glob("**/memory_*.json"): for backup_file in backup_dir.glob("**/memory_*.json"):
if backup_file not in files: if backup_file not in files:
files.append(backup_file) files.append(backup_file)
return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True) return sorted(files, key=lambda f: f.stat().st_mtime, reverse=True)
def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]: def load_graph_data(file_path: Path | None = None) -> dict[str, Any]:
"""从磁盘加载图数据""" """从磁盘加载图数据"""
global graph_data_cache, current_data_file global graph_data_cache, current_data_file
# 如果指定了新文件,清除缓存 # 如果指定了新文件,清除缓存
if file_path is not None and file_path != current_data_file: if file_path is not None and file_path != current_data_file:
graph_data_cache = None graph_data_cache = None
current_data_file = file_path current_data_file = file_path
if graph_data_cache is not None: if graph_data_cache is not None:
return graph_data_cache return graph_data_cache
try: try:
# 确定要加载的文件 # 确定要加载的文件
if current_data_file is not None: if current_data_file is not None:
@@ -94,115 +92,115 @@ def load_graph_data(file_path: Optional[Path] = None) -> Dict[str, Any]:
# 尝试查找可用的数据文件 # 尝试查找可用的数据文件
available_files = find_available_data_files() available_files = find_available_data_files()
if not available_files: if not available_files:
print(f"⚠️ 未找到任何图数据文件") print("⚠️ 未找到任何图数据文件")
print(f"📂 搜索目录: {data_dir}") print(f"📂 搜索目录: {data_dir}")
return { return {
"nodes": [], "nodes": [],
"edges": [], "edges": [],
"memories": [], "memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0}, "stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": "未找到数据文件", "error": "未找到数据文件",
"available_files": [] "available_files": []
} }
# 使用最新的文件 # 使用最新的文件
graph_file = available_files[0] graph_file = available_files[0]
current_data_file = graph_file current_data_file = graph_file
print(f"📂 自动选择最新文件: {graph_file}") print(f"📂 自动选择最新文件: {graph_file}")
if not graph_file.exists(): if not graph_file.exists():
print(f"⚠️ 图数据文件不存在: {graph_file}") print(f"⚠️ 图数据文件不存在: {graph_file}")
return { return {
"nodes": [], "nodes": [],
"edges": [], "edges": [],
"memories": [], "memories": [],
"stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0}, "stats": {"total_nodes": 0, "total_edges": 0, "total_memories": 0},
"error": f"文件不存在: {graph_file}" "error": f"文件不存在: {graph_file}"
} }
print(f"📂 加载图数据: {graph_file}") print(f"📂 加载图数据: {graph_file}")
with open(graph_file, 'r', encoding='utf-8') as f: with open(graph_file, encoding="utf-8") as f:
data = orjson.loads(f.read()) data = orjson.loads(f.read())
# 解析数据 # 解析数据
nodes_dict = {} nodes_dict = {}
edges_list = [] edges_list = []
memory_info = [] memory_info = []
# 实际文件格式是 {nodes: [], edges: [], metadata: {}} # 实际文件格式是 {nodes: [], edges: [], metadata: {}}
# 不是 {memories: [{nodes: [], edges: []}]} # 不是 {memories: [{nodes: [], edges: []}]}
nodes = data.get("nodes", []) nodes = data.get("nodes", [])
edges = data.get("edges", []) edges = data.get("edges", [])
metadata = data.get("metadata", {}) metadata = data.get("metadata", {})
print(f"✅ 找到 {len(nodes)} 个节点, {len(edges)} 条边") print(f"✅ 找到 {len(nodes)} 个节点, {len(edges)} 条边")
# 处理节点 # 处理节点
for node in nodes: for node in nodes:
node_id = node.get('id', '') node_id = node.get("id", "")
if node_id and node_id not in nodes_dict: if node_id and node_id not in nodes_dict:
memory_ids = node.get('metadata', {}).get('memory_ids', []) memory_ids = node.get("metadata", {}).get("memory_ids", [])
nodes_dict[node_id] = { nodes_dict[node_id] = {
'id': node_id, "id": node_id,
'label': node.get('content', ''), "label": node.get("content", ""),
'type': node.get('node_type', ''), "type": node.get("node_type", ""),
'group': extract_group_from_type(node.get('node_type', '')), "group": extract_group_from_type(node.get("node_type", "")),
'title': f"{node.get('node_type', '')}: {node.get('content', '')}", "title": f"{node.get('node_type', '')}: {node.get('content', '')}",
'metadata': node.get('metadata', {}), "metadata": node.get("metadata", {}),
'created_at': node.get('created_at', ''), "created_at": node.get("created_at", ""),
'memory_ids': memory_ids, "memory_ids": memory_ids,
} }
# 处理边 - 使用集合去重避免重复的边ID # 处理边 - 使用集合去重避免重复的边ID
existing_edge_ids = set() existing_edge_ids = set()
for edge in edges: for edge in edges:
# 边的ID字段可能是 'id' 或 'edge_id' # 边的ID字段可能是 'id' 或 'edge_id'
edge_id = edge.get('edge_id') or edge.get('id', '') edge_id = edge.get("edge_id") or edge.get("id", "")
# 如果ID为空或已存在跳过这条边 # 如果ID为空或已存在跳过这条边
if not edge_id or edge_id in existing_edge_ids: if not edge_id or edge_id in existing_edge_ids:
continue continue
existing_edge_ids.add(edge_id) existing_edge_ids.add(edge_id)
memory_id = edge.get('metadata', {}).get('memory_id', '') memory_id = edge.get("metadata", {}).get("memory_id", "")
# 注意: GraphStore 保存的格式使用 'source'/'target', 不是 'source_id'/'target_id' # 注意: GraphStore 保存的格式使用 'source'/'target', 不是 'source_id'/'target_id'
edges_list.append({ edges_list.append({
'id': edge_id, "id": edge_id,
'from': edge.get('source', edge.get('source_id', '')), "from": edge.get("source", edge.get("source_id", "")),
'to': edge.get('target', edge.get('target_id', '')), "to": edge.get("target", edge.get("target_id", "")),
'label': edge.get('relation', ''), "label": edge.get("relation", ""),
'type': edge.get('edge_type', ''), "type": edge.get("edge_type", ""),
'importance': edge.get('importance', 0.5), "importance": edge.get("importance", 0.5),
'title': f"{edge.get('edge_type', '')}: {edge.get('relation', '')}", "title": f"{edge.get('edge_type', '')}: {edge.get('relation', '')}",
'arrows': 'to', "arrows": "to",
'memory_id': memory_id, "memory_id": memory_id,
}) })
# 从元数据中获取统计信息 # 从元数据中获取统计信息
stats = metadata.get('statistics', {}) stats = metadata.get("statistics", {})
total_memories = stats.get('total_memories', 0) total_memories = stats.get("total_memories", 0)
# TODO: 如果需要记忆详细信息,需要从其他地方加载 # TODO: 如果需要记忆详细信息,需要从其他地方加载
# 目前只有节点和边的数据 # 目前只有节点和边的数据
graph_data_cache = { graph_data_cache = {
'nodes': list(nodes_dict.values()), "nodes": list(nodes_dict.values()),
'edges': edges_list, "edges": edges_list,
'memories': memory_info, # 空列表,因为文件中没有记忆详情 "memories": memory_info, # 空列表,因为文件中没有记忆详情
'stats': { "stats": {
'total_nodes': len(nodes_dict), "total_nodes": len(nodes_dict),
'total_edges': len(edges_list), "total_edges": len(edges_list),
'total_memories': total_memories, "total_memories": total_memories,
}, },
'current_file': str(graph_file), "current_file": str(graph_file),
'file_size': graph_file.stat().st_size, "file_size": graph_file.stat().st_size,
'file_modified': datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(), "file_modified": datetime.fromtimestamp(graph_file.stat().st_mtime).isoformat(),
} }
print(f"📊 统计: {len(nodes_dict)} 个节点, {len(edges_list)} 条边, {total_memories} 条记忆") print(f"📊 统计: {len(nodes_dict)} 个节点, {len(edges_list)} 条边, {total_memories} 条记忆")
print(f"📄 数据文件: {graph_file} ({graph_file.stat().st_size / 1024:.2f} KB)") print(f"📄 数据文件: {graph_file} ({graph_file.stat().st_size / 1024:.2f} KB)")
return graph_data_cache return graph_data_cache
except Exception as e: except Exception as e:
print(f"❌ 加载失败: {e}") print(f"❌ 加载失败: {e}")
import traceback import traceback
@@ -214,246 +212,246 @@ def extract_group_from_type(node_type: str) -> str:
"""从节点类型提取分组名""" """从节点类型提取分组名"""
# 假设类型格式为 "主体" 或 "SUBJECT" # 假设类型格式为 "主体" 或 "SUBJECT"
type_mapping = { type_mapping = {
'主体': 'SUBJECT', "主体": "SUBJECT",
'主题': 'TOPIC', "主题": "TOPIC",
'客体': 'OBJECT', "客体": "OBJECT",
'属性': 'ATTRIBUTE', "属性": "ATTRIBUTE",
'': 'VALUE', "": "VALUE",
} }
return type_mapping.get(node_type, node_type) return type_mapping.get(node_type, node_type)
def generate_memory_text(memory: Dict[str, Any]) -> str: def generate_memory_text(memory: dict[str, Any]) -> str:
"""生成记忆的文本描述""" """生成记忆的文本描述"""
try: try:
nodes = {n['id']: n for n in memory.get('nodes', [])} nodes = {n["id"]: n for n in memory.get("nodes", [])}
edges = memory.get('edges', []) edges = memory.get("edges", [])
subject_id = memory.get('subject_id', '') subject_id = memory.get("subject_id", "")
if not subject_id or subject_id not in nodes: if not subject_id or subject_id not in nodes:
return f"[记忆 {memory.get('id', '')[:8]}]" return f"[记忆 {memory.get('id', '')[:8]}]"
parts = [nodes[subject_id]['content']] parts = [nodes[subject_id]["content"]]
# 找主题节点 # 找主题节点
for edge in edges: for edge in edges:
if edge.get('edge_type') == '记忆类型' and edge.get('source_id') == subject_id: if edge.get("edge_type") == "记忆类型" and edge.get("source_id") == subject_id:
topic_id = edge.get('target_id', '') topic_id = edge.get("target_id", "")
if topic_id in nodes: if topic_id in nodes:
parts.append(nodes[topic_id]['content']) parts.append(nodes[topic_id]["content"])
# 找客体 # 找客体
for e2 in edges: for e2 in edges:
if e2.get('edge_type') == '核心关系' and e2.get('source_id') == topic_id: if e2.get("edge_type") == "核心关系" and e2.get("source_id") == topic_id:
obj_id = e2.get('target_id', '') obj_id = e2.get("target_id", "")
if obj_id in nodes: if obj_id in nodes:
parts.append(f"{e2.get('relation', '')} {nodes[obj_id]['content']}") parts.append(f"{e2.get('relation', '')} {nodes[obj_id]['content']}")
break break
break break
return " ".join(parts) return " ".join(parts)
except Exception: except Exception:
return f"[记忆 {memory.get('id', '')[:8]}]" return f"[记忆 {memory.get('id', '')[:8]}]"
# 使用内嵌的HTML模板(与之前相同) # 使用内嵌的HTML模板(与之前相同)
HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", 'r', encoding='utf-8').read() HTML_TEMPLATE = open(project_root / "tools" / "memory_visualizer" / "templates" / "visualizer.html", encoding="utf-8").read()
@app.route('/') @app.route("/")
def index(): def index():
"""主页面""" """主页面"""
return render_template_string(HTML_TEMPLATE) return render_template_string(HTML_TEMPLATE)
@app.route('/api/graph/full') @app.route("/api/graph/full")
def get_full_graph(): def get_full_graph():
"""获取完整记忆图数据""" """获取完整记忆图数据"""
try: try:
data = load_graph_data() data = load_graph_data()
return jsonify({ return jsonify({
'success': True, "success": True,
'data': data "data": data
}) })
except Exception as e: except Exception as e:
return jsonify({ return jsonify({
'success': False, "success": False,
'error': str(e) "error": str(e)
}), 500 }), 500
@app.route('/api/memory/<memory_id>') @app.route("/api/memory/<memory_id>")
def get_memory_detail(memory_id: str): def get_memory_detail(memory_id: str):
"""获取记忆详情""" """获取记忆详情"""
try: try:
data = load_graph_data() data = load_graph_data()
memory = next((m for m in data['memories'] if m['id'] == memory_id), None) memory = next((m for m in data["memories"] if m["id"] == memory_id), None)
if memory is None: if memory is None:
return jsonify({ return jsonify({
'success': False, "success": False,
'error': '记忆不存在' "error": "记忆不存在"
}), 404 }), 404
return jsonify({ return jsonify({
'success': True, "success": True,
'data': memory "data": memory
}) })
except Exception as e: except Exception as e:
return jsonify({ return jsonify({
'success': False, "success": False,
'error': str(e) "error": str(e)
}), 500 }), 500
@app.route('/api/search') @app.route("/api/search")
def search_memories(): def search_memories():
"""搜索记忆""" """搜索记忆"""
try: try:
query = request.args.get('q', '').lower() query = request.args.get("q", "").lower()
limit = int(request.args.get('limit', 50)) limit = int(request.args.get("limit", 50))
data = load_graph_data() data = load_graph_data()
# 简单的文本匹配搜索 # 简单的文本匹配搜索
results = [] results = []
for memory in data['memories']: for memory in data["memories"]:
text = memory.get('text', '').lower() text = memory.get("text", "").lower()
if query in text: if query in text:
results.append(memory) results.append(memory)
return jsonify({ return jsonify({
'success': True, "success": True,
'data': { "data": {
'results': results[:limit], "results": results[:limit],
'count': len(results), "count": len(results),
} }
}) })
except Exception as e: except Exception as e:
return jsonify({ return jsonify({
'success': False, "success": False,
'error': str(e) "error": str(e)
}), 500 }), 500
@app.route('/api/stats') @app.route("/api/stats")
def get_statistics(): def get_statistics():
"""获取统计信息""" """获取统计信息"""
try: try:
data = load_graph_data() data = load_graph_data()
# 扩展统计信息 # 扩展统计信息
node_types = {} node_types = {}
memory_types = {} memory_types = {}
for node in data['nodes']: for node in data["nodes"]:
node_type = node.get('type', 'Unknown') node_type = node.get("type", "Unknown")
node_types[node_type] = node_types.get(node_type, 0) + 1 node_types[node_type] = node_types.get(node_type, 0) + 1
for memory in data['memories']: for memory in data["memories"]:
mem_type = memory.get('type', 'Unknown') mem_type = memory.get("type", "Unknown")
memory_types[mem_type] = memory_types.get(mem_type, 0) + 1 memory_types[mem_type] = memory_types.get(mem_type, 0) + 1
stats = data.get('stats', {}) stats = data.get("stats", {})
stats['node_types'] = node_types stats["node_types"] = node_types
stats['memory_types'] = memory_types stats["memory_types"] = memory_types
return jsonify({ return jsonify({
'success': True, "success": True,
'data': stats "data": stats
}) })
except Exception as e: except Exception as e:
return jsonify({ return jsonify({
'success': False, "success": False,
'error': str(e) "error": str(e)
}), 500 }), 500
@app.route('/api/reload') @app.route("/api/reload")
def reload_data(): def reload_data():
"""重新加载数据""" """重新加载数据"""
global graph_data_cache global graph_data_cache
graph_data_cache = None graph_data_cache = None
data = load_graph_data() data = load_graph_data()
return jsonify({ return jsonify({
'success': True, "success": True,
'message': '数据已重新加载', "message": "数据已重新加载",
'stats': data.get('stats', {}) "stats": data.get("stats", {})
}) })
@app.route("/api/files")
def list_files():
    """List all discoverable data files with size and mtime metadata.

    Each entry reports the path, size in bytes and KB, ISO and
    human-readable modification times, and whether it is the file
    currently loaded (``current_data_file``).

    Returns:
        200 with the file list on success,
        500 with ``{"success": False, "error": message}`` on failure.
    """
    try:
        files = find_available_data_files()
        file_list = []
        for f in files:
            stat = f.stat()
            mtime = datetime.fromtimestamp(stat.st_mtime)
            file_list.append({
                "path": str(f),
                "name": f.name,
                "size": stat.st_size,
                "size_kb": round(stat.st_size / 1024, 2),
                "modified": mtime.isoformat(),
                "modified_readable": mtime.strftime("%Y-%m-%d %H:%M:%S"),
                # Compare as strings so Path vs str differences don't matter.
                "is_current": str(f) == str(current_data_file) if current_data_file else False
            })
        return jsonify({
            "success": True,
            "files": file_list,
            "count": len(file_list),
            "current_file": str(current_data_file) if current_data_file else None
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
@app.route("/api/select_file", methods=["POST"])
def select_file():
    """Switch the server to a different data file (POST JSON body).

    Expects ``{"file_path": "..."}``; validates presence and existence,
    then clears the cache and loads the new file.

    Returns:
        200 with the new file's stats on success,
        400 if no path was provided,
        404 if the path does not exist,
        500 on any other failure.
    """
    global graph_data_cache, current_data_file
    try:
        data = request.get_json()
        file_path = data.get("file_path")

        if not file_path:
            return jsonify({
                "success": False,
                "error": "未提供文件路径"
            }), 400

        file_path = Path(file_path)
        if not file_path.exists():
            return jsonify({
                "success": False,
                "error": f"文件不存在: {file_path}"
            }), 404

        # 清除缓存并加载新文件
        graph_data_cache = None
        current_data_file = file_path
        graph_data = load_graph_data(file_path)

        return jsonify({
            "success": True,
            "message": f"已切换到文件: {file_path.name}",
            "stats": graph_data.get("stats", {})
        })
    except Exception as e:
        return jsonify({
            "success": False,
            "error": str(e)
        }), 500
def run_server(host: str = "127.0.0.1", port: int = 5001, debug: bool = False):
    """Start the Flask visualization server.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
        debug: Enable Flask debug mode.
    """
    print("=" * 60)
    print("🦊 MoFox Bot - 记忆图可视化工具 (独立版)")
    # NOTE(review): the diff this was reconstructed from omitted several
    # banner lines here (hunk header "@@ -463,14 +461,14 @@") — restore
    # them from the full file before merging.
    print("⏹️ 按 Ctrl+C 停止服务器")
    print("=" * 60)
    print()

    # 预加载数据: warm the cache so the first request is served fast.
    load_graph_data()

    app.run(host=host, port=port, debug=debug)
if __name__ == '__main__': if __name__ == "__main__":
try: try:
run_server(debug=True) run_server(debug=True)
except KeyboardInterrupt: except KeyboardInterrupt: