Merge branch 'dev' of https://github.com/MoFox-Studio/MoFox_Bot into dev(我顺带再提一嘴把API放两个文件的简直是天才)

This commit is contained in:
minecraft1024a
2025-11-07 21:12:51 +08:00
45 changed files with 668 additions and 683 deletions

View File

@@ -17,19 +17,19 @@
import argparse import argparse
import json import json
from pathlib import Path
from typing import Dict, Any, List, Tuple
import logging import logging
from pathlib import Path
from typing import Any
import orjson import orjson
# 配置日志 # 配置日志
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s', format="%(asctime)s - %(levelname)s - %(message)s",
handlers=[ handlers=[
logging.StreamHandler(), logging.StreamHandler(),
logging.FileHandler('embedding_cleanup.log', encoding='utf-8') logging.FileHandler("embedding_cleanup.log", encoding="utf-8")
] ]
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -49,13 +49,13 @@ class EmbeddingCleaner:
self.cleaned_files = [] self.cleaned_files = []
self.errors = [] self.errors = []
self.stats = { self.stats = {
'files_processed': 0, "files_processed": 0,
'embedings_removed': 0, "embedings_removed": 0,
'bytes_saved': 0, "bytes_saved": 0,
'nodes_processed': 0 "nodes_processed": 0
} }
def find_json_files(self) -> List[Path]: def find_json_files(self) -> list[Path]:
"""查找可能包含向量数据的 JSON 文件""" """查找可能包含向量数据的 JSON 文件"""
json_files = [] json_files = []
@@ -65,7 +65,7 @@ class EmbeddingCleaner:
json_files.append(memory_graph_file) json_files.append(memory_graph_file)
# 测试数据文件 # 测试数据文件
test_dir = self.data_dir / "test_*" self.data_dir / "test_*"
for test_path in self.data_dir.glob("test_*/memory_graph.json"): for test_path in self.data_dir.glob("test_*/memory_graph.json"):
if test_path.exists(): if test_path.exists():
json_files.append(test_path) json_files.append(test_path)
@@ -82,7 +82,7 @@ class EmbeddingCleaner:
logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件") logger.info(f"找到 {len(json_files)} 个需要处理的 JSON 文件")
return json_files return json_files
def analyze_embedding_in_data(self, data: Dict[str, Any]) -> int: def analyze_embedding_in_data(self, data: dict[str, Any]) -> int:
""" """
分析数据中的 embedding 字段数量 分析数据中的 embedding 字段数量
@@ -97,7 +97,7 @@ class EmbeddingCleaner:
def count_embeddings(obj): def count_embeddings(obj):
nonlocal embedding_count nonlocal embedding_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
embedding_count += 1 embedding_count += 1
for value in obj.values(): for value in obj.values():
count_embeddings(value) count_embeddings(value)
@@ -108,7 +108,7 @@ class EmbeddingCleaner:
count_embeddings(data) count_embeddings(data)
return embedding_count return embedding_count
def clean_embedding_from_data(self, data: Dict[str, Any]) -> Tuple[Dict[str, Any], int]: def clean_embedding_from_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], int]:
""" """
从数据中移除 embedding 字段 从数据中移除 embedding 字段
@@ -123,8 +123,8 @@ class EmbeddingCleaner:
def remove_embeddings(obj): def remove_embeddings(obj):
nonlocal removed_count nonlocal removed_count
if isinstance(obj, dict): if isinstance(obj, dict):
if 'embedding' in obj: if "embedding" in obj:
del obj['embedding'] del obj["embedding"]
removed_count += 1 removed_count += 1
for value in obj.values(): for value in obj.values():
remove_embeddings(value) remove_embeddings(value)
@@ -162,14 +162,14 @@ class EmbeddingCleaner:
data = orjson.loads(original_content) data = orjson.loads(original_content)
except orjson.JSONDecodeError: except orjson.JSONDecodeError:
# 回退到标准 json # 回退到标准 json
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
# 分析 embedding 数据 # 分析 embedding 数据
embedding_count = self.analyze_embedding_in_data(data) embedding_count = self.analyze_embedding_in_data(data)
if embedding_count == 0: if embedding_count == 0:
logger.info(f" ✓ 文件中没有 embedding 数据,跳过") logger.info(" ✓ 文件中没有 embedding 数据,跳过")
return True return True
logger.info(f" 发现 {embedding_count} 个 embedding 字段") logger.info(f" 发现 {embedding_count} 个 embedding 字段")
@@ -193,30 +193,30 @@ class EmbeddingCleaner:
cleaned_data, cleaned_data,
indent=2, indent=2,
ensure_ascii=False ensure_ascii=False
).encode('utf-8') ).encode("utf-8")
cleaned_size = len(cleaned_content) cleaned_size = len(cleaned_content)
bytes_saved = original_size - cleaned_size bytes_saved = original_size - cleaned_size
# 原子写入 # 原子写入
temp_file = file_path.with_suffix('.tmp') temp_file = file_path.with_suffix(".tmp")
temp_file.write_bytes(cleaned_content) temp_file.write_bytes(cleaned_content)
temp_file.replace(file_path) temp_file.replace(file_path)
logger.info(f" ✓ 清理完成:") logger.info(" ✓ 清理完成:")
logger.info(f" - 移除 embedding 字段: {removed_count}") logger.info(f" - 移除 embedding 字段: {removed_count}")
logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)") logger.info(f" - 节省空间: {bytes_saved:,} 字节 ({bytes_saved/original_size*100:.1f}%)")
logger.info(f" - 新文件大小: {cleaned_size:,} 字节") logger.info(f" - 新文件大小: {cleaned_size:,} 字节")
# 更新统计 # 更新统计
self.stats['embedings_removed'] += removed_count self.stats["embedings_removed"] += removed_count
self.stats['bytes_saved'] += bytes_saved self.stats["bytes_saved"] += bytes_saved
else: else:
logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段") logger.info(f" [试运行] 将移除 {embedding_count} 个 embedding 字段")
self.stats['embedings_removed'] += embedding_count self.stats["embedings_removed"] += embedding_count
self.stats['files_processed'] += 1 self.stats["files_processed"] += 1
self.cleaned_files.append(file_path) self.cleaned_files.append(file_path)
return True return True
@@ -236,12 +236,12 @@ class EmbeddingCleaner:
节点数量 节点数量
""" """
try: try:
with open(file_path, 'r', encoding='utf-8') as f: with open(file_path, encoding="utf-8") as f:
data = json.load(f) data = json.load(f)
node_count = 0 node_count = 0
if 'nodes' in data and isinstance(data['nodes'], list): if "nodes" in data and isinstance(data["nodes"], list):
node_count = len(data['nodes']) node_count = len(data["nodes"])
return node_count return node_count
@@ -268,7 +268,7 @@ class EmbeddingCleaner:
# 统计总节点数 # 统计总节点数
total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files) total_nodes = sum(self.analyze_nodes_in_file(f) for f in json_files)
self.stats['nodes_processed'] = total_nodes self.stats["nodes_processed"] = total_nodes
logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点") logger.info(f"总计 {len(json_files)} 个文件,{total_nodes} 个节点")
@@ -295,8 +295,8 @@ class EmbeddingCleaner:
if not dry_run: if not dry_run:
logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节") logger.info(f"节省空间: {self.stats['bytes_saved']:,} 字节")
if self.stats['bytes_saved'] > 0: if self.stats["bytes_saved"] > 0:
mb_saved = self.stats['bytes_saved'] / 1024 / 1024 mb_saved = self.stats["bytes_saved"] / 1024 / 1024
logger.info(f"节省空间: {mb_saved:.2f} MB") logger.info(f"节省空间: {mb_saved:.2f} MB")
if self.errors: if self.errors:
@@ -342,7 +342,7 @@ def main():
print(" 请确保向量数据库正在正常工作。") print(" 请确保向量数据库正在正常工作。")
print() print()
response = input("确认继续?(yes/no): ") response = input("确认继续?(yes/no): ")
if response.lower() not in ['yes', 'y', '']: if response.lower() not in ["yes", "y", ""]:
print("操作已取消") print("操作已取消")
return return

View File

@@ -22,7 +22,6 @@ import argparse
import asyncio import asyncio
import gc import gc
import json import json
import os
import subprocess import subprocess
import sys import sys
import threading import threading
@@ -30,7 +29,6 @@ import time
from collections import defaultdict from collections import defaultdict
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
from typing import Dict, List, Optional
import psutil import psutil
@@ -101,10 +99,10 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
if children: if children:
print(f" 子进程: {info['children_count']}") print(f" 子进程: {info['children_count']}")
print(f" 子进程内存: {info['children_mem_mb']:.2f} MB") print(f" 子进程内存: {info['children_mem_mb']:.2f} MB")
total_mem = info['rss_mb'] + info['children_mem_mb'] total_mem = info["rss_mb"] + info["children_mem_mb"]
print(f" 总内存: {total_mem:.2f} MB") print(f" 总内存: {total_mem:.2f} MB")
print(f"\n 📋 子进程详情:") print("\n 📋 子进程详情:")
for idx, child in enumerate(children, 1): for idx, child in enumerate(children, 1):
try: try:
child_mem = child.memory_info().rss / 1024 / 1024 child_mem = child.memory_info().rss / 1024 / 1024
@@ -119,13 +117,13 @@ async def monitor_bot_process(bot_process: subprocess.Popen, interval: int = 5):
if len(history) > 1: if len(history) > 1:
prev = history[-2] prev = history[-2]
rss_diff = info['rss_mb'] - prev['rss_mb'] rss_diff = info["rss_mb"] - prev["rss_mb"]
print(f"\n变化:") print("\n变化:")
print(f" RSS: {rss_diff:+.2f} MB") print(f" RSS: {rss_diff:+.2f} MB")
if rss_diff > 10: if rss_diff > 10:
print(f" ⚠️ 内存增长较快!") print(" ⚠️ 内存增长较快!")
if info['rss_mb'] > 1000: if info["rss_mb"] > 1000:
print(f" ⚠️ 内存使用超过 1GB") print(" ⚠️ 内存使用超过 1GB")
print(f"{'=' * 80}\n") print(f"{'=' * 80}\n")
await asyncio.sleep(interval) await asyncio.sleep(interval)
@@ -163,7 +161,7 @@ def save_process_history(history: list, pid: int):
f.write(f"RSS: {info['rss_mb']:.2f} MB\n") f.write(f"RSS: {info['rss_mb']:.2f} MB\n")
f.write(f"VMS: {info['vms_mb']:.2f} MB\n") f.write(f"VMS: {info['vms_mb']:.2f} MB\n")
f.write(f"占比: {info['percent']:.2f}%\n") f.write(f"占比: {info['percent']:.2f}%\n")
if info['children_count'] > 0: if info["children_count"] > 0:
f.write(f"子进程: {info['children_count']}\n") f.write(f"子进程: {info['children_count']}\n")
f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n") f.write(f"子进程内存: {info['children_mem_mb']:.2f} MB\n")
f.write("\n") f.write("\n")
@@ -264,7 +262,7 @@ async def run_monitor_mode(interval: int):
class ObjectMemoryProfiler: class ObjectMemoryProfiler:
"""对象级内存分析器""" """对象级内存分析器"""
def __init__(self, interval: int = 10, output_file: Optional[str] = None, object_limit: int = 20): def __init__(self, interval: int = 10, output_file: str | None = None, object_limit: int = 20):
self.interval = interval self.interval = interval
self.output_file = output_file self.output_file = output_file
self.object_limit = object_limit self.object_limit = object_limit
@@ -274,7 +272,7 @@ class ObjectMemoryProfiler:
self.tracker = tracker.SummaryTracker() self.tracker = tracker.SummaryTracker()
self.iteration = 0 self.iteration = 0
def get_object_stats(self) -> Dict: def get_object_stats(self) -> dict:
"""获取当前进程的对象统计(所有线程)""" """获取当前进程的对象统计(所有线程)"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
return {} return {}
@@ -317,7 +315,7 @@ class ObjectMemoryProfiler:
print(f"❌ 获取对象统计失败: {e}") print(f"❌ 获取对象统计失败: {e}")
return {} return {}
def _get_module_stats(self, all_objects: list) -> Dict: def _get_module_stats(self, all_objects: list) -> dict:
"""统计各模块的内存占用""" """统计各模块的内存占用"""
module_mem = defaultdict(lambda: {"count": 0, "size": 0}) module_mem = defaultdict(lambda: {"count": 0, "size": 0})
@@ -329,7 +327,7 @@ class ObjectMemoryProfiler:
if module_name: if module_name:
# 获取顶级模块名(例如 src.chat.xxx -> src # 获取顶级模块名(例如 src.chat.xxx -> src
top_module = module_name.split('.')[0] top_module = module_name.split(".")[0]
obj_size = sys.getsizeof(obj) obj_size = sys.getsizeof(obj)
module_mem[top_module]["count"] += 1 module_mem[top_module]["count"] += 1
@@ -351,7 +349,7 @@ class ObjectMemoryProfiler:
"total_modules": len(module_mem) "total_modules": len(module_mem)
} }
def print_stats(self, stats: Dict, iteration: int): def print_stats(self, stats: dict, iteration: int):
"""打印统计信息""" """打印统计信息"""
print("\n" + "=" * 80) print("\n" + "=" * 80)
print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}") print(f"🔍 对象级内存分析 #{iteration} - {time.strftime('%H:%M:%S')}")
@@ -374,8 +372,8 @@ class ObjectMemoryProfiler:
print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}") print(f"{obj_type:<50} {obj_count:>12,} {size_str:>15}")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
print(f"\n📚 模块内存占用 (前 20 个模块):\n") print("\n📚 模块内存占用 (前 20 个模块):\n")
print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}") print(f"{'模块名':<40} {'对象数':>12} {'总内存':>15}")
print("-" * 80) print("-" * 80)
@@ -402,7 +400,7 @@ class ObjectMemoryProfiler:
if "gc_stats" in stats: if "gc_stats" in stats:
gc_stats = stats["gc_stats"] gc_stats = stats["gc_stats"]
print(f"\n🗑️ 垃圾回收:") print("\n🗑️ 垃圾回收:")
print(f" 代 0: {gc_stats['collections'][0]:,}") print(f" 代 0: {gc_stats['collections'][0]:,}")
print(f" 代 1: {gc_stats['collections'][1]:,}") print(f" 代 1: {gc_stats['collections'][1]:,}")
print(f" 代 2: {gc_stats['collections'][2]:,}") print(f" 代 2: {gc_stats['collections'][2]:,}")
@@ -423,7 +421,7 @@ class ObjectMemoryProfiler:
self.tracker.print_diff() self.tracker.print_diff()
print("-" * 80) print("-" * 80)
def save_to_file(self, stats: Dict): def save_to_file(self, stats: dict):
"""保存统计信息到文件""" """保存统计信息到文件"""
if not self.output_file: if not self.output_file:
return return
@@ -441,7 +439,7 @@ class ObjectMemoryProfiler:
for obj_type, obj_count, obj_size in stats["summary"]: for obj_type, obj_count, obj_size in stats["summary"]:
f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n") f.write(f" {obj_type}: {obj_count:,} 个, {obj_size:,} 字节\n")
if "module_stats" in stats and stats["module_stats"]: if stats.get("module_stats"):
f.write("\n模块统计 (前 20 个):\n") f.write("\n模块统计 (前 20 个):\n")
for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]: for module_name, obj_count, obj_size in stats["module_stats"]["top_modules"]:
f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n") f.write(f" {module_name}: {obj_count:,} 个对象, {obj_size:,} 字节\n")
@@ -452,7 +450,7 @@ class ObjectMemoryProfiler:
# 保存 JSONL # 保存 JSONL
jsonl_path = str(self.output_file) + ".jsonl" jsonl_path = str(self.output_file) + ".jsonl"
record = { record = {
"timestamp": time.strftime('%Y-%m-%d %H:%M:%S'), "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"iteration": self.iteration, "iteration": self.iteration,
"total_objects": stats.get("total_objects", 0), "total_objects": stats.get("total_objects", 0),
"threads": stats.get("threads", []), "threads": stats.get("threads", []),
@@ -479,7 +477,7 @@ class ObjectMemoryProfiler:
self.running = True self.running = True
def monitor_loop(): def monitor_loop():
print(f"🚀 对象分析器已启动") print("🚀 对象分析器已启动")
print(f" 监控间隔: {self.interval}") print(f" 监控间隔: {self.interval}")
print(f" 对象类型限制: {self.object_limit}") print(f" 对象类型限制: {self.object_limit}")
print(f" 输出文件: {self.output_file or ''}") print(f" 输出文件: {self.output_file or ''}")
@@ -506,14 +504,14 @@ class ObjectMemoryProfiler:
monitor_thread = threading.Thread(target=monitor_loop, daemon=True) monitor_thread = threading.Thread(target=monitor_loop, daemon=True)
monitor_thread.start() monitor_thread.start()
print(f"✓ 监控线程已启动\n") print("✓ 监控线程已启动\n")
def stop(self): def stop(self):
"""停止监控""" """停止监控"""
self.running = False self.running = False
def run_objects_mode(interval: int, output: Optional[str], object_limit: int): def run_objects_mode(interval: int, output: str | None, object_limit: int):
"""对象分析模式主函数""" """对象分析模式主函数"""
if not PYMPLER_AVAILABLE: if not PYMPLER_AVAILABLE:
print("❌ pympler 未安装,无法使用对象分析模式") print("❌ pympler 未安装,无法使用对象分析模式")
@@ -549,9 +547,9 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
try: try:
import bot import bot
if hasattr(bot, 'main_async'): if hasattr(bot, "main_async"):
asyncio.run(bot.main_async()) asyncio.run(bot.main_async())
elif hasattr(bot, 'main'): elif hasattr(bot, "main"):
bot.main() bot.main()
else: else:
print("⚠️ bot.py 未找到 main_async() 或 main() 函数") print("⚠️ bot.py 未找到 main_async() 或 main() 函数")
@@ -577,10 +575,10 @@ def run_objects_mode(interval: int, output: Optional[str], object_limit: int):
# 可视化模式 # 可视化模式
# ============================================================================ # ============================================================================
def load_jsonl(path: Path) -> List[Dict]: def load_jsonl(path: Path) -> list[dict]:
"""加载 JSONL 文件""" """加载 JSONL 文件"""
snapshots = [] snapshots = []
with open(path, "r", encoding="utf-8") as f: with open(path, encoding="utf-8") as f:
for line in f: for line in f:
line = line.strip() line = line.strip()
if not line: if not line:
@@ -592,7 +590,7 @@ def load_jsonl(path: Path) -> List[Dict]:
return snapshots return snapshots
def aggregate_top_types(snapshots: List[Dict], top_n: int = 10): def aggregate_top_types(snapshots: list[dict], top_n: int = 10):
"""聚合前 N 个对象类型的时间序列""" """聚合前 N 个对象类型的时间序列"""
type_max = defaultdict(int) type_max = defaultdict(int)
for snap in snapshots: for snap in snapshots:
@@ -622,7 +620,7 @@ def aggregate_top_types(snapshots: List[Dict], top_n: int = 10):
return times, series return times, series
def plot_series(times: List, series: Dict, output: Path, top_n: int): def plot_series(times: list, series: dict, output: Path, top_n: int):
"""绘制时间序列图""" """绘制时间序列图"""
plt.figure(figsize=(14, 8)) plt.figure(figsize=(14, 8))

View File

@@ -1,5 +1,4 @@
import os import os
import random
import time import time
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any

View File

@@ -127,7 +127,7 @@ class StyleLearner:
# 最后使用时间(越近越好) # 最后使用时间(越近越好)
last_used = self.learning_stats["style_last_used"].get(style_id, 0) last_used = self.learning_stats["style_last_used"].get(style_id, 0)
time_since_used = current_time - last_used if last_used > 0 else float('inf') time_since_used = current_time - last_used if last_used > 0 else float("inf")
# 综合分数:使用次数越多越好,距离上次使用时间越短越好 # 综合分数:使用次数越多越好,距离上次使用时间越短越好
# 使用对数来平滑使用次数的影响 # 使用对数来平滑使用次数的影响

View File

@@ -8,7 +8,6 @@ from datetime import datetime
from typing import Any from typing import Any
import numpy as np import numpy as np
import orjson
from sqlalchemy import select from sqlalchemy import select
from src.common.config_helpers import resolve_embedding_dimension from src.common.config_helpers import resolve_embedding_dimension
@@ -605,7 +604,7 @@ class BotInterestManager:
) )
# 如果有新生成的扩展embedding保存到缓存文件 # 如果有新生成的扩展embedding保存到缓存文件
if hasattr(self, '_new_expanded_embeddings_generated') and self._new_expanded_embeddings_generated: if hasattr(self, "_new_expanded_embeddings_generated") and self._new_expanded_embeddings_generated:
await self._save_embedding_cache_to_file(self.current_interests.personality_id) await self._save_embedding_cache_to_file(self.current_interests.personality_id)
self._new_expanded_embeddings_generated = False self._new_expanded_embeddings_generated = False
logger.debug("💾 已保存新生成的扩展embedding到缓存文件") logger.debug("💾 已保存新生成的扩展embedding到缓存文件")
@@ -670,59 +669,59 @@ class BotInterestManager:
tag_lower = tag_name.lower() tag_lower = tag_name.lower()
# 技术编程类标签(具体化描述) # 技术编程类标签(具体化描述)
if any(word in tag_lower for word in ['python', 'java', 'code', '代码', '编程', '脚本', '算法', '开发']): if any(word in tag_lower for word in ["python", "java", "code", "代码", "编程", "脚本", "算法", "开发"]):
if 'python' in tag_lower: if "python" in tag_lower:
return f"讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题" return "讨论Python编程语言、写Python代码、Python脚本开发、Python技术问题"
elif '算法' in tag_lower: elif "算法" in tag_lower:
return f"讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化" return "讨论算法题目、数据结构、编程竞赛、刷LeetCode题目、代码优化"
elif '代码' in tag_lower or '被窝' in tag_lower: elif "代码" in tag_lower or "被窝" in tag_lower:
return f"讨论写代码、编程开发、代码实现、技术方案、编程技巧" return "讨论写代码、编程开发、代码实现、技术方案、编程技巧"
else: else:
return f"讨论编程开发、软件技术、代码编写、技术实现" return "讨论编程开发、软件技术、代码编写、技术实现"
# 情感表达类标签(具体化为真实对话场景) # 情感表达类标签(具体化为真实对话场景)
elif any(word in tag_lower for word in ['治愈', '撒娇', '安慰', '呼噜', '', '卖萌']): elif any(word in tag_lower for word in ["治愈", "撒娇", "安慰", "呼噜", "", "卖萌"]):
return f"想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话" return "想要获得安慰、寻求温暖关怀、撒娇卖萌、表达亲昵、求抱抱求陪伴的对话"
# 游戏娱乐类标签(具体游戏场景) # 游戏娱乐类标签(具体游戏场景)
elif any(word in tag_lower for word in ['游戏', '网游', 'mmo', '', '']): elif any(word in tag_lower for word in ["游戏", "网游", "mmo", "", ""]):
return f"讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得" return "讨论网络游戏、MMO游戏、游戏玩法、组队打副本、游戏攻略心得"
# 动漫影视类标签(具体观看行为) # 动漫影视类标签(具体观看行为)
elif any(word in tag_lower for word in ['', '动漫', '视频', 'b站', '弹幕', '追番', '云新番']): elif any(word in tag_lower for word in ["", "动漫", "视频", "b站", "弹幕", "追番", "云新番"]):
# 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西" # 特别处理"云新番" - 它的意思是在网上看新动漫,不是泛泛的"新东西"
if '' in tag_lower or '新番' in tag_lower: if "" in tag_lower or "新番" in tag_lower:
return f"讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色" return "讨论正在播出的新动漫、新番剧集、动漫剧情、追番心得、动漫角色"
else: else:
return f"讨论动漫番剧内容、B站视频、弹幕文化、追番体验" return "讨论动漫番剧内容、B站视频、弹幕文化、追番体验"
# 社交平台类标签(具体平台行为) # 社交平台类标签(具体平台行为)
elif any(word in tag_lower for word in ['小红书', '贴吧', '论坛', '社区', '吃瓜', '八卦']): elif any(word in tag_lower for word in ["小红书", "贴吧", "论坛", "社区", "吃瓜", "八卦"]):
if '吃瓜' in tag_lower: if "吃瓜" in tag_lower:
return f"聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题" return "聊八卦爆料、吃瓜看热闹、网络热点事件、社交平台热议话题"
else: else:
return f"讨论社交平台内容、网络社区话题、论坛讨论、分享生活" return "讨论社交平台内容、网络社区话题、论坛讨论、分享生活"
# 生活日常类标签(具体萌宠场景) # 生活日常类标签(具体萌宠场景)
elif any(word in tag_lower for word in ['', '宠物', '尾巴', '耳朵', '毛绒']): elif any(word in tag_lower for word in ["", "宠物", "尾巴", "耳朵", "毛绒"]):
return f"讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得" return "讨论猫咪宠物、晒猫分享、萌宠日常、可爱猫猫、养猫心得"
# 状态心情类标签(具体情绪状态) # 状态心情类标签(具体情绪状态)
elif any(word in tag_lower for word in ['社恐', '隐身', '流浪', '深夜', '被窝']): elif any(word in tag_lower for word in ["社恐", "隐身", "流浪", "深夜", "被窝"]):
if '社恐' in tag_lower: if "社恐" in tag_lower:
return f"表达社交焦虑、不想见人、想躲起来、害怕社交的心情" return "表达社交焦虑、不想见人、想躲起来、害怕社交的心情"
elif '深夜' in tag_lower: elif "深夜" in tag_lower:
return f"深夜睡不着、熬夜、夜猫子、深夜思考人生的对话" return "深夜睡不着、熬夜、夜猫子、深夜思考人生的对话"
else: else:
return f"表达当前心情状态、个人感受、生活状态" return "表达当前心情状态、个人感受、生活状态"
# 物品装备类标签(具体使用场景) # 物品装备类标签(具体使用场景)
elif any(word in tag_lower for word in ['键盘', '耳机', '装备', '设备']): elif any(word in tag_lower for word in ["键盘", "耳机", "装备", "设备"]):
return f"讨论键盘耳机装备、数码产品、使用体验、装备推荐评测" return "讨论键盘耳机装备、数码产品、使用体验、装备推荐评测"
# 互动关系类标签 # 互动关系类标签
elif any(word in tag_lower for word in ['拾风', '互怼', '互动']): elif any(word in tag_lower for word in ["拾风", "互怼", "互动"]):
return f"聊天互动、开玩笑、友好互怼、日常对话交流" return "聊天互动、开玩笑、友好互怼、日常对话交流"
# 默认:尽量具体化 # 默认:尽量具体化
else: else:
@@ -1011,9 +1010,10 @@ class BotInterestManager:
async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None: async def _load_embedding_cache_from_file(self, personality_id: str) -> dict[str, list[float]] | None:
"""从文件加载embedding缓存""" """从文件加载embedding缓存"""
try: try:
import orjson
from pathlib import Path from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)
cache_file = cache_dir / f"{personality_id}_embeddings.json" cache_file = cache_dir / f"{personality_id}_embeddings.json"
@@ -1053,9 +1053,10 @@ class BotInterestManager:
async def _save_embedding_cache_to_file(self, personality_id: str): async def _save_embedding_cache_to_file(self, personality_id: str):
"""保存embedding缓存到文件包括扩展标签的embedding""" """保存embedding缓存到文件包括扩展标签的embedding"""
try: try:
import orjson
from pathlib import Path
from datetime import datetime from datetime import datetime
from pathlib import Path
import orjson
cache_dir = Path("data/embedding") cache_dir = Path("data/embedding")
cache_dir.mkdir(parents=True, exist_ok=True) cache_dir.mkdir(parents=True, exist_ok=True)

View File

@@ -9,8 +9,8 @@ from .scheduler_dispatcher import SchedulerDispatcher, scheduler_dispatcher
__all__ = [ __all__ = [
"MessageManager", "MessageManager",
"SingleStreamContextManager",
"SchedulerDispatcher", "SchedulerDispatcher",
"SingleStreamContextManager",
"message_manager", "message_manager",
"scheduler_dispatcher", "scheduler_dispatcher",
] ]

View File

@@ -73,7 +73,7 @@ class SingleStreamContextManager:
cache_enabled = global_config.chat.enable_message_cache cache_enabled = global_config.chat.enable_message_cache
use_cache_system = message_manager.is_running and cache_enabled use_cache_system = message_manager.is_running and cache_enabled
if not cache_enabled: if not cache_enabled:
logger.debug(f"消息缓存系统已在配置中禁用") logger.debug("消息缓存系统已在配置中禁用")
except Exception as e: except Exception as e:
logger.debug(f"MessageManager不可用使用直接添加: {e}") logger.debug(f"MessageManager不可用使用直接添加: {e}")
use_cache_system = False use_cache_system = False

View File

@@ -4,13 +4,11 @@
""" """
import asyncio import asyncio
import random
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from src.chat.chatter_manager import ChatterManager from src.chat.chatter_manager import ChatterManager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats

View File

@@ -367,7 +367,7 @@ class ChatBot:
message_segment = message_data.get("message_segment") message_segment = message_data.get("message_segment")
if message_segment and isinstance(message_segment, dict): if message_segment and isinstance(message_segment, dict):
if message_segment.get("type") == "adapter_response": if message_segment.get("type") == "adapter_response":
logger.info(f"[DEBUG bot.py message_process] 检测到adapter_response立即处理") logger.info("[DEBUG bot.py message_process] 检测到adapter_response立即处理")
await self._handle_adapter_response_from_dict(message_segment.get("data")) await self._handle_adapter_response_from_dict(message_segment.get("data"))
return return

View File

@@ -557,7 +557,7 @@ class DefaultReplyer:
participants = [] participants = []
try: try:
# 尝试从聊天流中获取参与者信息 # 尝试从聊天流中获取参与者信息
if hasattr(stream, 'chat_history_manager'): if hasattr(stream, "chat_history_manager"):
history_manager = stream.chat_history_manager history_manager = stream.chat_history_manager
# 获取最近的参与者列表 # 获取最近的参与者列表
recent_records = history_manager.get_memory_chat_history( recent_records = history_manager.get_memory_chat_history(
@@ -586,9 +586,9 @@ class DefaultReplyer:
formatted_history = "" formatted_history = ""
if chat_history: if chat_history:
# 移除过长的历史记录,只保留最近部分 # 移除过长的历史记录,只保留最近部分
lines = chat_history.strip().split('\n') lines = chat_history.strip().split("\n")
recent_lines = lines[-10:] if len(lines) > 10 else lines recent_lines = lines[-10:] if len(lines) > 10 else lines
formatted_history = '\n'.join(recent_lines) formatted_history = "\n".join(recent_lines)
query_context = { query_context = {
"chat_history": formatted_history, "chat_history": formatted_history,
@@ -725,7 +725,7 @@ class DefaultReplyer:
for tool_result in tool_results: for tool_result in tool_results:
tool_name = tool_result.get("tool_name", "unknown") tool_name = tool_result.get("tool_name", "unknown")
content = tool_result.get("content", "") content = tool_result.get("content", "")
result_type = tool_result.get("type", "tool_result") tool_result.get("type", "tool_result")
# 不进行截断,让工具自己处理结果长度 # 不进行截断,让工具自己处理结果长度
current_results_parts.append(f"- **{tool_name}**: {content}") current_results_parts.append(f"- **{tool_name}**: {content}")
@@ -1913,7 +1913,6 @@ class DefaultReplyer:
return return
# 使用统一记忆系统存储记忆 # 使用统一记忆系统存储记忆
from src.chat.memory_system import get_memory_system
stream = self.chat_stream stream = self.chat_stream
user_info_obj = getattr(stream, "user_info", None) user_info_obj = getattr(stream, "user_info", None)
@@ -2036,7 +2035,7 @@ class DefaultReplyer:
timestamp=time.time(), timestamp=time.time(),
limit=int(global_config.chat.max_context_size), limit=int(global_config.chat.max_context_size),
) )
chat_history = await build_readable_messages( await build_readable_messages(
message_list_before_short, message_list_before_short,
replace_bot_name=True, replace_bot_name=True,
merge_messages=False, merge_messages=False,

View File

@@ -91,13 +91,11 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
# 1. 判断是否为私聊(强提及) # 1. 判断是否为私聊(强提及)
group_info = getattr(message, "group_info", None) group_info = getattr(message, "group_info", None)
if not group_info or not getattr(group_info, "group_id", None): if not group_info or not getattr(group_info, "group_id", None):
is_private = True
mention_type = 2 mention_type = 2
logger.debug("检测到私聊消息 - 强提及") logger.debug("检测到私聊消息 - 强提及")
# 2. 判断是否被@(强提及) # 2. 判断是否被@(强提及)
if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text): if re.search(rf"@<(.+?):{global_config.bot.qq_account}>", processed_text):
is_at = True
mention_type = 2 mention_type = 2
logger.debug("检测到@提及 - 强提及") logger.debug("检测到@提及 - 强提及")
@@ -108,7 +106,6 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:", rf"\[回复<(.+?)(?=:{global_config.bot.qq_account!s}>)\:{global_config.bot.qq_account!s}>(.+?)\],说:",
processed_text, processed_text,
): ):
is_replied = True
mention_type = 2 mention_type = 2
logger.debug("检测到回复消息 - 强提及") logger.debug("检测到回复消息 - 强提及")
@@ -122,14 +119,12 @@ def is_mentioned_bot_in_message(message) -> tuple[bool, float]:
# 检查bot主名字 # 检查bot主名字
if global_config.bot.nickname in message_content: if global_config.bot.nickname in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及") logger.debug(f"检测到文本提及bot主名字 '{global_config.bot.nickname}' - 弱提及")
# 如果主名字没匹配,再检查别名 # 如果主名字没匹配,再检查别名
elif nicknames: elif nicknames:
for alias_name in nicknames: for alias_name in nicknames:
if alias_name in message_content: if alias_name in message_content:
is_text_mentioned = True
mention_type = 1 mention_type = 1
logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及") logger.debug(f"检测到文本提及bot别名 '{alias_name}' - 弱提及")
break break

View File

@@ -374,7 +374,7 @@ class CacheManager:
# 简化的健康统计,不包含内存监控(因为相关属性未定义) # 简化的健康统计,不包含内存监控(因为相关属性未定义)
return { return {
"l1_count": len(self.l1_kv_cache), "l1_count": len(self.l1_kv_cache),
"l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0, "l1_vector_count": self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0,
"tool_stats": { "tool_stats": {
"total_tool_calls": self.tool_stats.get("total_tool_calls", 0), "total_tool_calls": self.tool_stats.get("total_tool_calls", 0),
"tracked_tools": len(self.tool_stats.get("most_used_tools", {})), "tracked_tools": len(self.tool_stats.get("most_used_tools", {})),
@@ -397,7 +397,7 @@ class CacheManager:
warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}") warnings.append(f"⚠️ L1缓存条目数较多: {l1_size}")
# 检查向量索引大小 # 检查向量索引大小
vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, 'ntotal') else 0 vector_count = self.l1_vector_index.ntotal if hasattr(self.l1_vector_index, "ntotal") else 0
if isinstance(vector_count, int) and vector_count > 500: if isinstance(vector_count, int) and vector_count > 500:
warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}") warnings.append(f"⚠️ 向量索引条目数较多: {vector_count}")

View File

@@ -539,7 +539,6 @@ class AdaptiveBatchScheduler:
def _set_cache(self, cache_key: str, result: Any) -> None: def _set_cache(self, cache_key: str, result: Any) -> None:
"""设置缓存(改进版,带大小限制和内存统计)""" """设置缓存(改进版,带大小限制和内存统计)"""
import sys
# 🔧 检查缓存大小限制 # 🔧 检查缓存大小限制
if len(self._result_cache) >= self._cache_max_size: if len(self._result_cache) >= self._cache_max_size:

View File

@@ -4,9 +4,10 @@
提供比 sys.getsizeof() 更准确的内存占用估算方法 提供比 sys.getsizeof() 更准确的内存占用估算方法
""" """
import sys
import pickle import pickle
import sys
from typing import Any from typing import Any
import numpy as np import numpy as np
@@ -44,15 +45,15 @@ def get_accurate_size(obj: Any, seen: set | None = None) -> int:
for k, v in obj.items()) for k, v in obj.items())
# 列表、元组、集合:递归计算所有元素 # 列表、元组、集合:递归计算所有元素
elif isinstance(obj, (list, tuple, set, frozenset)): elif isinstance(obj, list | tuple | set | frozenset):
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
# 有 __dict__ 的对象:递归计算属性 # 有 __dict__ 的对象:递归计算属性
elif hasattr(obj, '__dict__'): elif hasattr(obj, "__dict__"):
size += get_accurate_size(obj.__dict__, seen) size += get_accurate_size(obj.__dict__, seen)
# 其他可迭代对象 # 其他可迭代对象
elif hasattr(obj, '__iter__') and not isinstance(obj, (str, bytes, bytearray)): elif hasattr(obj, "__iter__") and not isinstance(obj, str | bytes | bytearray):
try: try:
size += sum(get_accurate_size(item, seen) for item in obj) size += sum(get_accurate_size(item, seen) for item in obj)
except: except:
@@ -116,7 +117,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
size = sys.getsizeof(obj) size = sys.getsizeof(obj)
# 简单类型直接返回 # 简单类型直接返回
if isinstance(obj, (int, float, bool, type(None), str, bytes, bytearray)): if isinstance(obj, int | float | bool | type(None) | str | bytes | bytearray):
return size return size
# NumPy 数组特殊处理 # NumPy 数组特殊处理
@@ -144,7 +145,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
return size return size
# 列表、元组、集合递归 # 列表、元组、集合递归
if isinstance(obj, (list, tuple, set, frozenset)): if isinstance(obj, list | tuple | set | frozenset):
items = list(obj) items = list(obj)
if sample_large and len(items) > 100: if sample_large and len(items) > 100:
# 大容器采样前50 + 中间50 + 最后50 # 大容器采样前50 + 中间50 + 最后50
@@ -162,7 +163,7 @@ def _estimate_recursive(obj: Any, depth: int, seen: set, sample_large: bool) ->
return size return size
# 有 __dict__ 的对象 # 有 __dict__ 的对象
if hasattr(obj, '__dict__'): if hasattr(obj, "__dict__"):
size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large) size += _estimate_recursive(obj.__dict__, depth - 1, seen, sample_large)
return size return size

View File

@@ -2,7 +2,6 @@ import os
import shutil import shutil
import sys import sys
from datetime import datetime from datetime import datetime
from typing import Optional
import tomlkit import tomlkit
from pydantic import Field from pydantic import Field
@@ -381,7 +380,7 @@ class Config(ValidatedConfigBase):
notice: NoticeConfig = Field(..., description="Notice消息配置") notice: NoticeConfig = Field(..., description="Notice消息配置")
emoji: EmojiConfig = Field(..., description="表情配置") emoji: EmojiConfig = Field(..., description="表情配置")
expression: ExpressionConfig = Field(..., description="表达配置") expression: ExpressionConfig = Field(..., description="表达配置")
memory: Optional[MemoryConfig] = Field(default=None, description="记忆配置") memory: MemoryConfig | None = Field(default=None, description="记忆配置")
mood: MoodConfig = Field(..., description="情绪配置") mood: MoodConfig = Field(..., description="情绪配置")
reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置") reaction: ReactionConfig = Field(default_factory=ReactionConfig, description="反应规则配置")
chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置") chinese_typo: ChineseTypoConfig = Field(..., description="中文错别字配置")

View File

@@ -534,7 +534,7 @@ class _RequestExecutor:
model_name = model_info.name model_name = model_info.name
retry_interval = api_provider.retry_interval retry_interval = api_provider.retry_interval
if isinstance(e, (NetworkConnectionError, ReqAbortException)): if isinstance(e, NetworkConnectionError | ReqAbortException):
return await self._check_retry(remain_try, retry_interval, "连接异常", model_name) return await self._check_retry(remain_try, retry_interval, "连接异常", model_name)
elif isinstance(e, RespNotOkException): elif isinstance(e, RespNotOkException):
return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info) return await self._handle_resp_not_ok(e, model_info, api_provider, remain_try, messages_info)

View File

@@ -100,10 +100,10 @@ class VectorStore:
# 处理额外的元数据,将 list 转换为 JSON 字符串 # 处理额外的元数据,将 list 转换为 JSON 字符串
for key, value in node.metadata.items(): for key, value in node.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
import orjson import orjson
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value metadata[key] = value
else: else:
metadata[key] = str(value) metadata[key] = str(value)
@@ -149,9 +149,9 @@ class VectorStore:
"created_at": n.created_at.isoformat(), "created_at": n.created_at.isoformat(),
} }
for key, value in n.metadata.items(): for key, value in n.metadata.items():
if isinstance(value, (list, dict)): if isinstance(value, list | dict):
metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8") metadata[key] = orjson.dumps(value, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")
elif isinstance(value, (str, int, float, bool)) or value is None: elif isinstance(value, str | int | float | bool) or value is None:
metadata[key] = value # type: ignore metadata[key] = value # type: ignore
else: else:
metadata[key] = str(value) metadata[key] = str(value)

View File

@@ -4,8 +4,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import numpy as np import numpy as np
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -7,11 +7,12 @@
""" """
import atexit import atexit
import orjson
import os import os
import threading import threading
from typing import Any, ClassVar from typing import Any, ClassVar
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
# 获取日志记录器 # 获取日志记录器
@@ -125,7 +126,7 @@ class PluginStorage:
try: try:
with open(self.file_path, "w", encoding="utf-8") as f: with open(self.file_path, "w", encoding="utf-8") as f:
f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8')) f.write(orjson.dumps(self._data, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8"))
self._dirty = False # 保存后重置标志 self._dirty = False # 保存后重置标志
logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。") logger.debug(f"插件 '{self.name}' 的数据已成功保存到磁盘。")
except Exception as e: except Exception as e:

View File

@@ -5,12 +5,12 @@ MCP Client Manager
""" """
import asyncio import asyncio
import orjson
import shutil import shutil
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import mcp.types import mcp.types
import orjson
from fastmcp.client import Client, StdioTransport, StreamableHttpTransport from fastmcp.client import Client, StdioTransport, StreamableHttpTransport
from src.common.logger import get_logger from src.common.logger import get_logger

View File

@@ -4,11 +4,13 @@
""" """
import time import time
from typing import Any, Optional from dataclasses import dataclass, field
from dataclasses import dataclass, asdict, field from typing import Any
import orjson import orjson
from src.common.logger import get_logger
from src.common.cache_manager import tool_cache from src.common.cache_manager import tool_cache
from src.common.logger import get_logger
logger = get_logger("stream_tool_history") logger = get_logger("stream_tool_history")
@@ -18,10 +20,10 @@ class ToolCallRecord:
"""工具调用记录""" """工具调用记录"""
tool_name: str tool_name: str
args: dict[str, Any] args: dict[str, Any]
result: Optional[dict[str, Any]] = None result: dict[str, Any] | None = None
status: str = "success" # success, error, pending status: str = "success" # success, error, pending
timestamp: float = field(default_factory=time.time) timestamp: float = field(default_factory=time.time)
execution_time: Optional[float] = None # 执行耗时(秒) execution_time: float | None = None # 执行耗时(秒)
cache_hit: bool = False # 是否命中缓存 cache_hit: bool = False # 是否命中缓存
result_preview: str = "" # 结果预览 result_preview: str = "" # 结果预览
error_message: str = "" # 错误信息 error_message: str = "" # 错误信息
@@ -32,9 +34,9 @@ class ToolCallRecord:
content = self.result.get("content", "") content = self.result.get("content", "")
if isinstance(content, str): if isinstance(content, str):
self.result_preview = content[:500] + ("..." if len(content) > 500 else "") self.result_preview = content[:500] + ("..." if len(content) > 500 else "")
elif isinstance(content, (list, dict)): elif isinstance(content, list | dict):
try: try:
self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode('utf-8')[:500] + "..." self.result_preview = orjson.dumps(content, option=orjson.OPT_NON_STR_KEYS).decode("utf-8")[:500] + "..."
except Exception: except Exception:
self.result_preview = str(content)[:500] + "..." self.result_preview = str(content)[:500] + "..."
else: else:
@@ -105,7 +107,7 @@ class StreamToolHistoryManager:
logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}") logger.debug(f"[{self.chat_id}] 添加工具调用记录: {record.tool_name}, 缓存命中: {record.cache_hit}")
async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: async def get_cached_result(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""从缓存或历史记录中获取结果 """从缓存或历史记录中获取结果
Args: Args:
@@ -160,9 +162,9 @@ class StreamToolHistoryManager:
return None return None
async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any], async def cache_result(self, tool_name: str, args: dict[str, Any], result: dict[str, Any],
execution_time: Optional[float] = None, execution_time: float | None = None,
tool_file_path: Optional[str] = None, tool_file_path: str | None = None,
ttl: Optional[int] = None) -> None: ttl: int | None = None) -> None:
"""缓存工具调用结果 """缓存工具调用结果
Args: Args:
@@ -207,7 +209,7 @@ class StreamToolHistoryManager:
except Exception as e: except Exception as e:
logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}") logger.warning(f"[{self.chat_id}] 缓存设置失败: {e}")
async def get_recent_history(self, count: int = 5, status_filter: Optional[str] = None) -> list[ToolCallRecord]: async def get_recent_history(self, count: int = 5, status_filter: str | None = None) -> list[ToolCallRecord]:
"""获取最近的历史记录 """获取最近的历史记录
Args: Args:
@@ -295,7 +297,7 @@ class StreamToolHistoryManager:
self._history.clear() self._history.clear()
logger.info(f"[{self.chat_id}] 工具历史记录已清除") logger.info(f"[{self.chat_id}] 工具历史记录已清除")
def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> Optional[dict[str, Any]]: def _search_memory_cache(self, tool_name: str, args: dict[str, Any]) -> dict[str, Any] | None:
"""在内存历史记录中搜索缓存 """在内存历史记录中搜索缓存
Args: Args:
@@ -333,7 +335,7 @@ class StreamToolHistoryManager:
return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py") return tool_path_mapping.get(tool_name, f"src/plugins/tools/{tool_name}.py")
def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> Optional[str]: def _extract_semantic_query(self, tool_name: str, args: dict[str, Any]) -> str | None:
"""提取语义查询参数 """提取语义查询参数
Args: Args:
@@ -370,7 +372,7 @@ class StreamToolHistoryManager:
return "" return ""
try: try:
args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode('utf-8') args_str = orjson.dumps(args, option=orjson.OPT_SORT_KEYS).decode("utf-8")
if len(args_str) > max_length: if len(args_str) > max_length:
args_str = args_str[:max_length] + "..." args_str = args_str[:max_length] + "..."
return args_str return args_str

View File

@@ -1,5 +1,6 @@
import inspect import inspect
import time import time
from dataclasses import asdict
from typing import Any from typing import Any
from src.chat.utils.prompt import Prompt, global_prompt_manager from src.chat.utils.prompt import Prompt, global_prompt_manager
@@ -10,8 +11,7 @@ from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance from src.plugin_system.apis.tool_api import get_llm_available_tool_definitions, get_tool_instance
from src.plugin_system.base.base_tool import BaseTool from src.plugin_system.base.base_tool import BaseTool
from src.plugin_system.core.global_announcement_manager import global_announcement_manager from src.plugin_system.core.global_announcement_manager import global_announcement_manager
from src.plugin_system.core.stream_tool_history import get_stream_tool_history_manager, ToolCallRecord from src.plugin_system.core.stream_tool_history import ToolCallRecord, get_stream_tool_history_manager
from dataclasses import asdict
logger = get_logger("tool_use") logger = get_logger("tool_use")
@@ -338,9 +338,8 @@ class ToolExecutor:
if tool_instance and result and tool_instance.enable_cache: if tool_instance and result and tool_instance.enable_cache:
try: try:
tool_file_path = inspect.getfile(tool_instance.__class__) tool_file_path = inspect.getfile(tool_instance.__class__)
semantic_query = None
if tool_instance.semantic_cache_query_key: if tool_instance.semantic_cache_query_key:
semantic_query = function_args.get(tool_instance.semantic_cache_query_key) function_args.get(tool_instance.semantic_cache_query_key)
await self.history_manager.cache_result( await self.history_manager.cache_result(
tool_name=tool_call.func_name, tool_name=tool_call.func_name,

View File

@@ -217,7 +217,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
return 0.0 return 0.0
except asyncio.TimeoutError: except asyncio.TimeoutError:
logger.warning(f"⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数") logger.warning("⏱️ 兴趣匹配计算超时(>1.5秒)返回默认分值0.5以保留其他分数")
return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分 return 0.5 # 超时时返回默认分值,避免丢失提及分和关系分
except Exception as e: except Exception as e:
logger.warning(f"智能兴趣匹配失败: {e}") logger.warning(f"智能兴趣匹配失败: {e}")

View File

@@ -4,10 +4,10 @@ AffinityFlow Chatter 规划器模块
包含计划生成、过滤、执行等规划相关功能 包含计划生成、过滤、执行等规划相关功能
""" """
from . import planner_prompts
from .plan_executor import ChatterPlanExecutor from .plan_executor import ChatterPlanExecutor
from .plan_filter import ChatterPlanFilter from .plan_filter import ChatterPlanFilter
from .plan_generator import ChatterPlanGenerator from .plan_generator import ChatterPlanGenerator
from .planner import ChatterActionPlanner from .planner import ChatterActionPlanner
from . import planner_prompts
__all__ = ["ChatterActionPlanner", "planner_prompts", "ChatterPlanGenerator", "ChatterPlanFilter", "ChatterPlanExecutor"] __all__ = ["ChatterActionPlanner", "ChatterPlanExecutor", "ChatterPlanFilter", "ChatterPlanGenerator", "planner_prompts"]

View File

@@ -14,9 +14,7 @@ from json_repair import repair_json
# 旧的Hippocampus系统已被移除现在使用增强记忆系统 # 旧的Hippocampus系统已被移除现在使用增强记忆系统
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager # from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_actions,
build_readable_messages_with_id, build_readable_messages_with_id,
get_actions_by_timestamp_with_chat,
) )
from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.prompt import global_prompt_manager
from src.common.data_models.info_data_model import ActionPlannerInfo, Plan from src.common.data_models.info_data_model import ActionPlannerInfo, Plan

View File

@@ -21,7 +21,6 @@ if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
# 导入提示词模块以确保其被初始化 # 导入提示词模块以确保其被初始化
from src.plugins.built_in.affinity_flow_chatter.planner import planner_prompts
logger = get_logger("planner") logger = get_logger("planner")

View File

@@ -9,9 +9,9 @@ from .proactive_thinking_executor import execute_proactive_thinking
from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler from .proactive_thinking_scheduler import ProactiveThinkingScheduler, proactive_thinking_scheduler
__all__ = [ __all__ = [
"ProactiveThinkingReplyHandler",
"ProactiveThinkingMessageHandler", "ProactiveThinkingMessageHandler",
"execute_proactive_thinking", "ProactiveThinkingReplyHandler",
"ProactiveThinkingScheduler", "ProactiveThinkingScheduler",
"execute_proactive_thinking",
"proactive_thinking_scheduler", "proactive_thinking_scheduler",
] ]

View File

@@ -3,7 +3,6 @@
当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复 当定时任务触发时负责搜集信息、调用LLM决策、并根据决策生成回复
""" """
import orjson
from datetime import datetime from datetime import datetime
from typing import Any, Literal from typing import Any, Literal

View File

@@ -14,7 +14,6 @@ from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.config.api_ada_configs import TaskConfig
from src.llm_models.utils_model import LLMRequest from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import config_api, generator_api, llm_api from src.plugin_system.apis import config_api, generator_api, llm_api

View File

@@ -136,10 +136,10 @@ class QZoneService:
logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}") logger.info(f"[DEBUG] 准备获取API客户端qq_account={qq_account}")
api_client = await self._get_api_client(qq_account, stream_id) api_client = await self._get_api_client(qq_account, stream_id)
if not api_client: if not api_client:
logger.error(f"[DEBUG] API客户端获取失败返回错误") logger.error("[DEBUG] API客户端获取失败返回错误")
return {"success": False, "message": "获取QZone API客户端失败"} return {"success": False, "message": "获取QZone API客户端失败"}
logger.info(f"[DEBUG] API客户端获取成功准备读取说说") logger.info("[DEBUG] API客户端获取成功准备读取说说")
num_to_read = self.get_config("read.read_number", 5) num_to_read = self.get_config("read.read_number", 5)
# 尝试执行如果Cookie失效则自动重试一次 # 尝试执行如果Cookie失效则自动重试一次
@@ -186,7 +186,7 @@ class QZoneService:
# 检查是否是Cookie失效-3000错误 # 检查是否是Cookie失效-3000错误
if "错误码: -3000" in error_msg and retry_count == 0: if "错误码: -3000" in error_msg and retry_count == 0:
logger.warning(f"检测到Cookie失效-3000错误准备删除缓存并重试...") logger.warning("检测到Cookie失效-3000错误准备删除缓存并重试...")
# 删除Cookie缓存文件 # 删除Cookie缓存文件
cookie_file = self.cookie_service._get_cookie_file_path(qq_account) cookie_file = self.cookie_service._get_cookie_file_path(qq_account)
@@ -623,7 +623,7 @@ class QZoneService:
logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}") logger.error(f"获取API客户端失败Cookie中缺少关键的 'p_skey'。Cookie内容: {cookies}")
return None return None
logger.info(f"[DEBUG] p_skey获取成功") logger.info("[DEBUG] p_skey获取成功")
gtk = self._generate_gtk(p_skey) gtk = self._generate_gtk(p_skey)
uin = cookies.get("uin", "").lstrip("o") uin = cookies.get("uin", "").lstrip("o")
@@ -1230,7 +1230,7 @@ class QZoneService:
logger.error(f"监控好友动态失败: {e}", exc_info=True) logger.error(f"监控好友动态失败: {e}", exc_info=True)
return [] return []
logger.info(f"[DEBUG] API客户端构造完成返回包含6个方法的字典") logger.info("[DEBUG] API客户端构造完成返回包含6个方法的字典")
return { return {
"publish": _publish, "publish": _publish,
"list_feeds": _list_feeds, "list_feeds": _list_feeds,

View File

@@ -3,11 +3,12 @@
负责记录和管理已回复过的评论ID避免重复回复 负责记录和管理已回复过的评论ID避免重复回复
""" """
import orjson
import time import time
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("MaiZone.ReplyTrackerService") logger = get_logger("MaiZone.ReplyTrackerService")
@@ -117,8 +118,8 @@ class ReplyTrackerService:
temp_file = self.reply_record_file.with_suffix(".tmp") temp_file = self.reply_record_file.with_suffix(".tmp")
# 先写入临时文件 # 先写入临时文件
with open(temp_file, "w", encoding="utf-8") as f: with open(temp_file, "w", encoding="utf-8"):
orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode('utf-8') orjson.dumps(self.replied_comments, option=orjson.OPT_INDENT_2 | orjson.OPT_NON_STR_KEYS).decode("utf-8")
# 如果写入成功,重命名为正式文件 # 如果写入成功,重命名为正式文件
if temp_file.stat().st_size > 0: # 确保写入成功 if temp_file.stat().st_size > 0: # 确保写入成功

View File

@@ -1,7 +1,6 @@
import orjson import orjson
import random import random
import time import time
import random
import websockets as Server import websockets as Server
import uuid import uuid
from maim_message import ( from maim_message import (
@@ -205,7 +204,7 @@ class SendHandler:
# 发送响应回MoFox-Bot # 发送响应回MoFox-Bot
logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}") logger.debug(f"[DEBUG handle_adapter_command] 即将调用send_adapter_command_response, request_id={request_id}")
await self.send_adapter_command_response(raw_message_base, response, request_id) await self.send_adapter_command_response(raw_message_base, response, request_id)
logger.debug(f"[DEBUG handle_adapter_command] send_adapter_command_response调用完成") logger.debug("[DEBUG handle_adapter_command] send_adapter_command_response调用完成")
if response.get("status") == "ok": if response.get("status") == "ok":
logger.info(f"适配器命令 {action} 执行成功") logger.info(f"适配器命令 {action} 执行成功")

View File

@@ -1,10 +1,10 @@
""" """
Metaso Search Engine (Chat Completions Mode) Metaso Search Engine (Chat Completions Mode)
""" """
import orjson
from typing import Any from typing import Any
import httpx import httpx
import orjson
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -3,9 +3,10 @@ Serper search engine implementation
Google Search via Serper.dev API Google Search via Serper.dev API
""" """
import aiohttp
from typing import Any from typing import Any
import aiohttp
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api

View File

@@ -5,7 +5,7 @@ Web Search Tool Plugin
""" """
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, PythonDependency, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, ConfigField, register_plugin
from src.plugin_system.apis import config_api from src.plugin_system.apis import config_api
from .tools.url_parser import URLParserTool from .tools.url_parser import URLParserTool

View File

@@ -113,7 +113,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
search_tasks.append(engine.answer_search(custom_args)) search_tasks.append(engine.answer_search(custom_args))
else: else:
search_tasks.append(engine.search(custom_args)) search_tasks.append(engine.search(custom_args))
@@ -162,7 +162,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索fallback策略") logger.info("使用Exa答案模式进行搜索fallback策略")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else:
@@ -195,7 +195,7 @@ class WebSurfingTool(BaseTool):
custom_args["num_results"] = custom_args.get("num_results", 5) custom_args["num_results"] = custom_args.get("num_results", 5)
# 如果启用了answer模式且是Exa引擎使用answer_search方法 # 如果启用了answer模式且是Exa引擎使用answer_search方法
if answer_mode and engine_name == "exa" and hasattr(engine, 'answer_search'): if answer_mode and engine_name == "exa" and hasattr(engine, "answer_search"):
logger.info("使用Exa答案模式进行搜索") logger.info("使用Exa答案模式进行搜索")
results = await engine.answer_search(custom_args) results = await engine.answer_search(custom_args)
else: else: