This commit is contained in:
LuiKlee
2025-12-16 16:18:59 +08:00
parent c2a1d7b00b
commit 0feb878830
20 changed files with 251 additions and 261 deletions

View File

@@ -6,8 +6,8 @@ from pathlib import Path
project_root = Path(__file__).parent.parent project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root)) sys.path.insert(0, str(project_root))
from src.memory_graph.manager_singleton import get_unified_memory_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.memory_graph.manager_singleton import get_unified_memory_manager
logger = get_logger("memory_transfer_check") logger = get_logger("memory_transfer_check")
@@ -22,20 +22,20 @@ def print_section(title: str):
async def check_short_term_status(): async def check_short_term_status():
"""检查短期记忆状态""" """检查短期记忆状态"""
print_section("1. 短期记忆状态检查") print_section("1. 短期记忆状态检查")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
short_term = manager.short_term_manager short_term = manager.short_term_manager
# 获取统计信息 # 获取统计信息
stats = short_term.get_statistics() stats = short_term.get_statistics()
print(f"📊 当前记忆数量: {stats['total_memories']}/{stats['max_memories']}") print(f"📊 当前记忆数量: {stats['total_memories']}/{stats['max_memories']}")
# 计算占用率 # 计算占用率
if stats['max_memories'] > 0: if stats["max_memories"] > 0:
occupancy = stats['total_memories'] / stats['max_memories'] occupancy = stats["total_memories"] / stats["max_memories"]
print(f"📈 容量占用率: {occupancy:.1%}") print(f"📈 容量占用率: {occupancy:.1%}")
# 根据占用率给出建议 # 根据占用率给出建议
if occupancy >= 1.0: if occupancy >= 1.0:
print("⚠️ 警告:已达到容量上限!应该触发紧急转移") print("⚠️ 警告:已达到容量上限!应该触发紧急转移")
@@ -43,25 +43,25 @@ async def check_short_term_status():
print("✅ 占用率超过50%,符合自动转移条件") print("✅ 占用率超过50%,符合自动转移条件")
else: else:
print(f" 占用率未达到50%阈值,当前 {occupancy:.1%}") print(f" 占用率未达到50%阈值,当前 {occupancy:.1%}")
print(f"🎯 可转移记忆数: {stats['transferable_count']}") print(f"🎯 可转移记忆数: {stats['transferable_count']}")
print(f"📏 转移重要性阈值: {stats['transfer_threshold']}") print(f"📏 转移重要性阈值: {stats['transfer_threshold']}")
return stats return stats
async def check_transfer_candidates(): async def check_transfer_candidates():
"""检查当前可转移的候选记忆""" """检查当前可转移的候选记忆"""
print_section("2. 转移候选记忆分析") print_section("2. 转移候选记忆分析")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
short_term = manager.short_term_manager short_term = manager.short_term_manager
# 获取转移候选 # 获取转移候选
candidates = short_term.get_memories_for_transfer() candidates = short_term.get_memories_for_transfer()
print(f"🎫 当前转移候选: {len(candidates)}\n") print(f"🎫 当前转移候选: {len(candidates)}\n")
if not candidates: if not candidates:
print("❌ 没有记忆符合转移条件!") print("❌ 没有记忆符合转移条件!")
print("\n可能原因:") print("\n可能原因:")
@@ -69,7 +69,7 @@ async def check_transfer_candidates():
print(" 2. 短期记忆数量未超过容量限制") print(" 2. 短期记忆数量未超过容量限制")
print(" 3. 短期记忆列表为空") print(" 3. 短期记忆列表为空")
return [] return []
# 显示前5条候选的详细信息 # 显示前5条候选的详细信息
print("前 5 条候选记忆:\n") print("前 5 条候选记忆:\n")
for i, mem in enumerate(candidates[:5], 1): for i, mem in enumerate(candidates[:5], 1):
@@ -78,38 +78,38 @@ async def check_transfer_candidates():
print(f" 内容: {mem.content[:50]}...") print(f" 内容: {mem.content[:50]}...")
print(f" 创建时间: {mem.created_at}") print(f" 创建时间: {mem.created_at}")
print() print()
if len(candidates) > 5: if len(candidates) > 5:
print(f"... 还有 {len(candidates) - 5} 条候选记忆\n") print(f"... 还有 {len(candidates) - 5} 条候选记忆\n")
# 分析重要性分布 # 分析重要性分布
importance_levels = { importance_levels = {
"高 (>=0.8)": sum(1 for m in candidates if m.importance >= 0.8), "高 (>=0.8)": sum(1 for m in candidates if m.importance >= 0.8),
"中 (0.6-0.8)": sum(1 for m in candidates if 0.6 <= m.importance < 0.8), "中 (0.6-0.8)": sum(1 for m in candidates if 0.6 <= m.importance < 0.8),
"低 (<0.6)": sum(1 for m in candidates if m.importance < 0.6), "低 (<0.6)": sum(1 for m in candidates if m.importance < 0.6),
} }
print("📊 重要性分布:") print("📊 重要性分布:")
for level, count in importance_levels.items(): for level, count in importance_levels.items():
print(f" {level}: {count}") print(f" {level}: {count}")
return candidates return candidates
async def check_auto_transfer_task(): async def check_auto_transfer_task():
"""检查自动转移任务状态""" """检查自动转移任务状态"""
print_section("3. 自动转移任务状态") print_section("3. 自动转移任务状态")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
# 检查任务是否存在 # 检查任务是否存在
if not hasattr(manager, '_auto_transfer_task') or manager._auto_transfer_task is None: if not hasattr(manager, "_auto_transfer_task") or manager._auto_transfer_task is None:
print("❌ 自动转移任务未创建!") print("❌ 自动转移任务未创建!")
print("\n建议:调用 manager.initialize() 初始化系统") print("\n建议:调用 manager.initialize() 初始化系统")
return False return False
task = manager._auto_transfer_task task = manager._auto_transfer_task
# 检查任务状态 # 检查任务状态
if task.done(): if task.done():
print("❌ 自动转移任务已结束!") print("❌ 自动转移任务已结束!")
@@ -121,78 +121,78 @@ async def check_auto_transfer_task():
pass pass
print("\n建议:重启系统或手动重启任务") print("\n建议:重启系统或手动重启任务")
return False return False
print("✅ 自动转移任务正在运行") print("✅ 自动转移任务正在运行")
# 检查转移缓存 # 检查转移缓存
if hasattr(manager, '_transfer_cache'): if hasattr(manager, "_transfer_cache"):
cache_size = len(manager._transfer_cache) if manager._transfer_cache else 0 cache_size = len(manager._transfer_cache) if manager._transfer_cache else 0
print(f"📦 转移缓存: {cache_size} 条记忆") print(f"📦 转移缓存: {cache_size} 条记忆")
# 检查上次转移时间 # 检查上次转移时间
if hasattr(manager, '_last_transfer_time'): if hasattr(manager, "_last_transfer_time"):
from datetime import datetime from datetime import datetime
last_time = manager._last_transfer_time last_time = manager._last_transfer_time
if last_time: if last_time:
time_diff = (datetime.now() - last_time).total_seconds() time_diff = (datetime.now() - last_time).total_seconds()
print(f"⏱️ 距上次转移: {time_diff:.1f} 秒前") print(f"⏱️ 距上次转移: {time_diff:.1f} 秒前")
return True return True
async def check_long_term_status(): async def check_long_term_status():
"""检查长期记忆状态""" """检查长期记忆状态"""
print_section("4. 长期记忆图谱状态") print_section("4. 长期记忆图谱状态")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
long_term = manager.long_term_manager long_term = manager.long_term_manager
# 获取图谱统计 # 获取图谱统计
stats = long_term.get_statistics() stats = long_term.get_statistics()
print(f"👥 人物节点数: {stats.get('person_count', 0)}") print(f"👥 人物节点数: {stats.get('person_count', 0)}")
print(f"📅 事件节点数: {stats.get('event_count', 0)}") print(f"📅 事件节点数: {stats.get('event_count', 0)}")
print(f"🔗 关系边数: {stats.get('edge_count', 0)}") print(f"🔗 关系边数: {stats.get('edge_count', 0)}")
print(f"💾 向量存储数: {stats.get('vector_count', 0)}") print(f"💾 向量存储数: {stats.get('vector_count', 0)}")
return stats return stats
async def manual_transfer_test(): async def manual_transfer_test():
"""手动触发转移测试""" """手动触发转移测试"""
print_section("5. 手动转移测试") print_section("5. 手动转移测试")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
# 询问用户是否执行 # 询问用户是否执行
print("⚠️ 即将手动触发一次记忆转移") print("⚠️ 即将手动触发一次记忆转移")
print("这将把当前符合条件的短期记忆转移到长期记忆") print("这将把当前符合条件的短期记忆转移到长期记忆")
response = input("\n是否继续? (y/n): ").strip().lower() response = input("\n是否继续? (y/n): ").strip().lower()
if response != 'y': if response != "y":
print("❌ 已取消手动转移") print("❌ 已取消手动转移")
return None return None
print("\n🚀 开始手动转移...") print("\n🚀 开始手动转移...")
try: try:
# 执行手动转移 # 执行手动转移
result = await manager.manual_transfer() result = await manager.manual_transfer()
print("\n✅ 转移完成!") print("\n✅ 转移完成!")
print(f"\n转移结果:") print("\n转移结果:")
print(f" 已处理: {result.get('processed_count', 0)}") print(f" 已处理: {result.get('processed_count', 0)}")
print(f" 成功转移: {len(result.get('transferred_memory_ids', []))}") print(f" 成功转移: {len(result.get('transferred_memory_ids', []))}")
print(f" 失败: {result.get('failed_count', 0)}") print(f" 失败: {result.get('failed_count', 0)}")
print(f" 跳过: {result.get('skipped_count', 0)}") print(f" 跳过: {result.get('skipped_count', 0)}")
if result.get('errors'): if result.get("errors"):
print(f"\n错误信息:") print("\n错误信息:")
for error in result['errors'][:3]: # 只显示前3个错误 for error in result["errors"][:3]: # 只显示前3个错误
print(f" - {error}") print(f" - {error}")
return result return result
except Exception as e: except Exception as e:
print(f"\n❌ 转移失败: {e}") print(f"\n❌ 转移失败: {e}")
logger.exception("手动转移失败") logger.exception("手动转移失败")
@@ -202,29 +202,29 @@ async def manual_transfer_test():
async def check_configuration(): async def check_configuration():
"""检查相关配置""" """检查相关配置"""
print_section("6. 配置参数检查") print_section("6. 配置参数检查")
from src.config.config import global_config from src.config.config import global_config
config = global_config.memory config = global_config.memory
print("📋 当前配置:") print("📋 当前配置:")
print(f" 短期记忆容量: {config.short_term_max_memories}") print(f" 短期记忆容量: {config.short_term_max_memories}")
print(f" 转移重要性阈值: {config.short_term_transfer_threshold}") print(f" 转移重要性阈值: {config.short_term_transfer_threshold}")
print(f" 批量转移大小: {config.long_term_batch_size}") print(f" 批量转移大小: {config.long_term_batch_size}")
print(f" 自动转移间隔: {config.long_term_auto_transfer_interval}") print(f" 自动转移间隔: {config.long_term_auto_transfer_interval}")
print(f" 启用泄压清理: {config.short_term_enable_force_cleanup}") print(f" 启用泄压清理: {config.short_term_enable_force_cleanup}")
# 给出配置建议 # 给出配置建议
print("\n💡 配置建议:") print("\n💡 配置建议:")
if config.short_term_transfer_threshold > 0.6: if config.short_term_transfer_threshold > 0.6:
print(" ⚠️ 转移阈值较高(>0.6),可能导致记忆难以转移") print(" ⚠️ 转移阈值较高(>0.6),可能导致记忆难以转移")
print(" 建议:降低到 0.4-0.5") print(" 建议:降低到 0.4-0.5")
if config.long_term_batch_size > 10: if config.long_term_batch_size > 10:
print(" ⚠️ 批量大小较大(>10),可能延迟转移触发") print(" ⚠️ 批量大小较大(>10),可能延迟转移触发")
print(" 建议:设置为 5-10") print(" 建议:设置为 5-10")
if config.long_term_auto_transfer_interval > 300: if config.long_term_auto_transfer_interval > 300:
print(" ⚠️ 转移间隔较长(>5分钟),可能导致转移不及时") print(" ⚠️ 转移间隔较长(>5分钟),可能导致转移不及时")
print(" 建议:设置为 60-180 秒") print(" 建议:设置为 60-180 秒")
@@ -235,37 +235,37 @@ async def main():
print("\n" + "=" * 60) print("\n" + "=" * 60)
print(" MoFox-Bot 记忆转移诊断工具") print(" MoFox-Bot 记忆转移诊断工具")
print("=" * 60) print("=" * 60)
try: try:
# 初始化管理器 # 初始化管理器
print("\n⚙️ 正在初始化记忆管理器...") print("\n⚙️ 正在初始化记忆管理器...")
manager = get_unified_memory_manager() manager = get_unified_memory_manager()
await manager.initialize() await manager.initialize()
print("✅ 初始化完成\n") print("✅ 初始化完成\n")
# 执行各项检查 # 执行各项检查
await check_short_term_status() await check_short_term_status()
candidates = await check_transfer_candidates() candidates = await check_transfer_candidates()
task_running = await check_auto_transfer_task() task_running = await check_auto_transfer_task()
await check_long_term_status() await check_long_term_status()
await check_configuration() await check_configuration()
# 综合诊断 # 综合诊断
print_section("7. 综合诊断结果") print_section("7. 综合诊断结果")
issues = [] issues = []
if not candidates: if not candidates:
issues.append("❌ 没有符合条件的转移候选") issues.append("❌ 没有符合条件的转移候选")
if not task_running: if not task_running:
issues.append("❌ 自动转移任务未运行") issues.append("❌ 自动转移任务未运行")
if issues: if issues:
print("🚨 发现以下问题:\n") print("🚨 发现以下问题:\n")
for issue in issues: for issue in issues:
print(f" {issue}") print(f" {issue}")
print("\n建议操作:") print("\n建议操作:")
print(" 1. 检查短期记忆的重要性评分是否合理") print(" 1. 检查短期记忆的重要性评分是否合理")
print(" 2. 降低配置中的转移阈值") print(" 2. 降低配置中的转移阈值")
@@ -273,7 +273,7 @@ async def main():
print(" 4. 尝试手动触发转移测试") print(" 4. 尝试手动触发转移测试")
else: else:
print("✅ 系统运行正常,转移机制已就绪") print("✅ 系统运行正常,转移机制已就绪")
if candidates: if candidates:
print(f"\n当前有 {len(candidates)} 条记忆等待转移") print(f"\n当前有 {len(candidates)} 条记忆等待转移")
print("转移将在满足以下任一条件时自动触发:") print("转移将在满足以下任一条件时自动触发:")
@@ -281,20 +281,20 @@ async def main():
print(" • 短期记忆占用率超过 50%") print(" • 短期记忆占用率超过 50%")
print(" • 距上次转移超过最大延迟") print(" • 距上次转移超过最大延迟")
print(" • 短期记忆达到容量上限") print(" • 短期记忆达到容量上限")
# 询问是否手动触发转移 # 询问是否手动触发转移
if candidates: if candidates:
print() print()
await manual_transfer_test() await manual_transfer_test()
print_section("检查完成") print_section("检查完成")
print("详细诊断报告: docs/memory_transfer_diagnostic_report.md") print("详细诊断报告: docs/memory_transfer_diagnostic_report.md")
except Exception as e: except Exception as e:
print(f"\n❌ 检查过程出错: {e}") print(f"\n❌ 检查过程出错: {e}")
logger.exception("检查脚本执行失败") logger.exception("检查脚本执行失败")
return 1 return 1
return 0 return 0

View File

@@ -17,8 +17,8 @@ from pathlib import Path
PROJECT_ROOT = Path(__file__).parent.parent PROJECT_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(PROJECT_ROOT)) sys.path.insert(0, str(PROJECT_ROOT))
from src.config.config import global_config # noqa: E402 from src.config.config import global_config
from src.memory_graph.short_term_manager import ShortTermMemoryManager # noqa: E402 from src.memory_graph.short_term_manager import ShortTermMemoryManager
def resolve_data_dir() -> Path: def resolve_data_dir() -> Path:

View File

@@ -12,17 +12,16 @@ from typing import Any, Optional, cast
import json_repair import json_repair
from PIL import Image from PIL import Image
from rich.traceback import install
from sqlalchemy import select from sqlalchemy import select
from src.chat.emoji_system.emoji_constants import EMOJI_DIR, EMOJI_REGISTERED_DIR, MAX_EMOJI_FOR_PROMPT from src.chat.emoji_system.emoji_constants import EMOJI_DIR, EMOJI_REGISTERED_DIR, MAX_EMOJI_FOR_PROMPT
from src.chat.emoji_system.emoji_entities import MaiEmoji from src.chat.emoji_system.emoji_entities import MaiEmoji
from src.chat.emoji_system.emoji_utils import ( from src.chat.emoji_system.emoji_utils import (
_emoji_objects_to_readable_list, _emoji_objects_to_readable_list,
_to_emoji_objects,
_ensure_emoji_dir, _ensure_emoji_dir,
clear_temp_emoji, _to_emoji_objects,
clean_unused_emojis, clean_unused_emojis,
clear_temp_emoji,
list_image_files, list_image_files,
) )
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64 from src.chat.utils.utils_image import get_image_manager, image_path_to_base64

View File

@@ -415,20 +415,20 @@ class ExpressionLearner:
.offset(offset) .offset(offset)
) )
batch_expressions = list(batch_result.scalars()) batch_expressions = list(batch_result.scalars())
if not batch_expressions: if not batch_expressions:
break # 没有更多数据 break # 没有更多数据
# 批量处理当前批次 # 批量处理当前批次
to_delete = [] to_delete = []
for expr in batch_expressions: for expr in batch_expressions:
# 计算时间差 # 计算时间差
time_diff_days = (current_time - expr.last_active_time) / (24 * 3600) time_diff_days = (current_time - expr.last_active_time) / (24 * 3600)
# 计算衰减值 # 计算衰减值
decay_value = self.calculate_decay_factor(time_diff_days) decay_value = self.calculate_decay_factor(time_diff_days)
new_count = max(0.01, expr.count - decay_value) new_count = max(0.01, expr.count - decay_value)
if new_count <= 0.01: if new_count <= 0.01:
# 标记删除 # 标记删除
to_delete.append(expr) to_delete.append(expr)
@@ -436,22 +436,22 @@ class ExpressionLearner:
# 更新count # 更新count
expr.count = new_count expr.count = new_count
updated_count += 1 updated_count += 1
# 批量删除 # 批量删除
if to_delete: if to_delete:
for expr in to_delete: for expr in to_delete:
await session.delete(expr) await session.delete(expr)
deleted_count += len(to_delete) deleted_count += len(to_delete)
# 提交当前批次 # 提交当前批次
await session.commit() await session.commit()
# 如果批次不满,说明已经处理完所有数据 # 如果批次不满,说明已经处理完所有数据
if len(batch_expressions) < BATCH_SIZE: if len(batch_expressions) < BATCH_SIZE:
break break
offset += BATCH_SIZE offset += BATCH_SIZE
if updated_count > 0 or deleted_count > 0: if updated_count > 0 or deleted_count > 0:
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式") logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
@@ -544,12 +544,12 @@ class ExpressionLearner:
) )
) )
existing_exprs = list(existing_exprs_result.scalars()) existing_exprs = list(existing_exprs_result.scalars())
# 构建快速查找索引 # 构建快速查找索引
exact_match_map = {} # (situation, style) -> Expression exact_match_map = {} # (situation, style) -> Expression
situation_map = {} # situation -> Expression situation_map = {} # situation -> Expression
style_map = {} # style -> Expression style_map = {} # style -> Expression
for expr in existing_exprs: for expr in existing_exprs:
key = (expr.situation, expr.style) key = (expr.situation, expr.style)
exact_match_map[key] = expr exact_match_map[key] = expr
@@ -558,13 +558,13 @@ class ExpressionLearner:
situation_map[expr.situation] = expr situation_map[expr.situation] = expr
if expr.style not in style_map: if expr.style not in style_map:
style_map[expr.style] = expr style_map[expr.style] = expr
# 批量处理所有新表达方式 # 批量处理所有新表达方式
for new_expr in expr_list: for new_expr in expr_list:
situation = new_expr["situation"] situation = new_expr["situation"]
style_val = new_expr["style"] style_val = new_expr["style"]
exact_key = (situation, style_val) exact_key = (situation, style_val)
# 优先处理完全匹配的情况 # 优先处理完全匹配的情况
if exact_key in exact_match_map: if exact_key in exact_match_map:
# 完全相同增加count更新时间 # 完全相同增加count更新时间
@@ -578,8 +578,7 @@ class ExpressionLearner:
logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'") logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'")
# 更新映射 # 更新映射
old_key = (same_situation_expr.situation, same_situation_expr.style) old_key = (same_situation_expr.situation, same_situation_expr.style)
if old_key in exact_match_map: exact_match_map.pop(old_key, None)
del exact_match_map[old_key]
same_situation_expr.style = style_val same_situation_expr.style = style_val
same_situation_expr.count = same_situation_expr.count + 1 same_situation_expr.count = same_situation_expr.count + 1
same_situation_expr.last_active_time = current_time same_situation_expr.last_active_time = current_time
@@ -591,8 +590,7 @@ class ExpressionLearner:
logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'") logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'")
# 更新映射 # 更新映射
old_key = (same_style_expr.situation, same_style_expr.style) old_key = (same_style_expr.situation, same_style_expr.style)
if old_key in exact_match_map: exact_match_map.pop(old_key, None)
del exact_match_map[old_key]
same_style_expr.situation = situation same_style_expr.situation = situation
same_style_expr.count = same_style_expr.count + 1 same_style_expr.count = same_style_expr.count + 1
same_style_expr.last_active_time = current_time same_style_expr.last_active_time = current_time
@@ -627,8 +625,7 @@ class ExpressionLearner:
await session.delete(expr) await session.delete(expr)
# 从映射中移除 # 从映射中移除
key = (expr.situation, expr.style) key = (expr.situation, expr.style)
if key in exact_match_map: exact_match_map.pop(key, None)
del exact_match_map[key]
logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式") logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式")
# 提交数据库更改 # 提交数据库更改
@@ -658,31 +655,31 @@ class ExpressionLearner:
# 为每个共享组内的 chat_id 训练其 StyleLearner # 为每个共享组内的 chat_id 训练其 StyleLearner
for target_chat_id in related_chat_ids: for target_chat_id in related_chat_ids:
learner = style_learner_manager.get_learner(target_chat_id) learner = style_learner_manager.get_learner(target_chat_id)
# 收集该 target_chat_id 对应的所有表达方式 # 收集该 target_chat_id 对应的所有表达方式
# 如果是源 chat_id使用 chat_dict 中的数据;否则也要训练(共享组特性) # 如果是源 chat_id使用 chat_dict 中的数据;否则也要训练(共享组特性)
total_success = 0 total_success = 0
total_samples = 0 total_samples = 0
for source_chat_id, expr_list in chat_dict.items(): for source_chat_id, expr_list in chat_dict.items():
# 为每个学习到的表达方式训练模型 # 为每个学习到的表达方式训练模型
# 使用 situation 作为输入style 作为目标 # 使用 situation 作为输入style 作为目标
for expr in expr_list: for expr in expr_list:
situation = expr["situation"] situation = expr["situation"]
style = expr["style"] style = expr["style"]
# 训练映射关系: situation -> style # 训练映射关系: situation -> style
if learner.learn_mapping(situation, style): if learner.learn_mapping(situation, style):
total_success += 1 total_success += 1
total_samples += 1 total_samples += 1
# 保存模型 # 保存模型
if total_samples > 0: if total_samples > 0:
if learner.save(style_learner_manager.model_save_path): if learner.save(style_learner_manager.model_save_path):
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}") logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
else: else:
logger.error(f"StyleLearner 模型保存失败: {target_chat_id}") logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
if target_chat_id == self.chat_id: if target_chat_id == self.chat_id:
# 只为当前 chat_id 记录详细日志 # 只为当前 chat_id 记录详细日志
logger.info( logger.info(

View File

@@ -218,7 +218,7 @@ class ExpressionSelector:
"type": expr_type, "type": expr_type,
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time, "create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
} }
style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()] style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()]
grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()] grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()]
@@ -246,7 +246,7 @@ class ExpressionSelector:
""" """
if not expressions_to_update: if not expressions_to_update:
return return
# 去重处理 # 去重处理
updates_by_key = {} updates_by_key = {}
affected_chat_ids = set() affected_chat_ids = set()
@@ -524,7 +524,7 @@ class ExpressionSelector:
# 预处理:提前计算所有预测 style 的小写版本,避免重复计算 # 预处理:提前计算所有预测 style 的小写版本,避免重复计算
predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]] predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]]
matched_expressions = [] matched_expressions = []
for expr in all_expressions: for expr in all_expressions:
db_style = expr.style or "" db_style = expr.style or ""
@@ -539,7 +539,7 @@ class ExpressionSelector:
max_similarity = 1.0 max_similarity = 1.0
best_predicted = predicted_style_lower best_predicted = predicted_style_lower
break break
# 快速检查:子串匹配 # 快速检查:子串匹配
if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2: if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2:
if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower: if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower:
@@ -548,7 +548,7 @@ class ExpressionSelector:
max_similarity = similarity max_similarity = similarity
best_predicted = predicted_style_lower best_predicted = predicted_style_lower
continue continue
# 计算字符串相似度(较慢,只在必要时使用) # 计算字符串相似度(较慢,只在必要时使用)
similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio() similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio()
if similarity > max_similarity: if similarity > max_similarity:

View File

@@ -38,7 +38,7 @@ class InterestManager:
self._calculation_queue = asyncio.Queue() self._calculation_queue = asyncio.Queue()
self._worker_task = None self._worker_task = None
self._shutdown_event = asyncio.Event() self._shutdown_event = asyncio.Event()
# 性能优化相关字段 # 性能优化相关字段
self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict() # LRU缓存 self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict() # LRU缓存
self._cache_max_size = 1000 # 最大缓存数量 self._cache_max_size = 1000 # 最大缓存数量
@@ -48,13 +48,13 @@ class InterestManager:
self._batch_timeout = 0.1 # 批处理超时(秒) self._batch_timeout = 0.1 # 批处理超时(秒)
self._batch_task = None self._batch_task = None
self._is_warmed_up = False # 预热状态标记 self._is_warmed_up = False # 预热状态标记
# 性能统计 # 性能统计
self._cache_hits = 0 self._cache_hits = 0
self._cache_misses = 0 self._cache_misses = 0
self._batch_calculations = 0 self._batch_calculations = 0
self._total_calculation_time = 0.0 self._total_calculation_time = 0.0
self._initialized = True self._initialized = True
async def initialize(self): async def initialize(self):
@@ -67,7 +67,7 @@ class InterestManager:
async def shutdown(self): async def shutdown(self):
"""关闭管理器""" """关闭管理器"""
self._shutdown_event.set() self._shutdown_event.set()
# 取消批处理任务 # 取消批处理任务
if self._batch_task and not self._batch_task.done(): if self._batch_task and not self._batch_task.done():
self._batch_task.cancel() self._batch_task.cancel()
@@ -79,7 +79,7 @@ class InterestManager:
if self._current_calculator: if self._current_calculator:
await self._current_calculator.cleanup() await self._current_calculator.cleanup()
self._current_calculator = None self._current_calculator = None
# 清理缓存 # 清理缓存
self._result_cache.clear() self._result_cache.clear()
@@ -142,9 +142,9 @@ class InterestManager:
interest_value=0.3, interest_value=0.3,
error_message="没有可用的兴趣值计算组件", error_message="没有可用的兴趣值计算组件",
) )
message_id = getattr(message, "message_id", "") message_id = getattr(message, "message_id", "")
# 缓存查询 # 缓存查询
if use_cache and message_id: if use_cache and message_id:
cached_result = self._get_from_cache(message_id) cached_result = self._get_from_cache(message_id)
@@ -183,11 +183,11 @@ class InterestManager:
interest_value=0.3, interest_value=0.3,
error_message=f"计算异常: {e!s}", error_message=f"计算异常: {e!s}",
) )
# 缓存结果 # 缓存结果
if use_cache and result.success and message_id: if use_cache and result.success and message_id:
self._put_to_cache(message_id, result) self._put_to_cache(message_id, result)
return result return result
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult: async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
@@ -249,36 +249,36 @@ class InterestManager:
break break
except Exception as e: except Exception as e:
logger.error(f"计算工作线程异常: {e}") logger.error(f"计算工作线程异常: {e}")
def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None: def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None:
"""从缓存中获取结果LRU策略""" """从缓存中获取结果LRU策略"""
if message_id not in self._result_cache: if message_id not in self._result_cache:
return None return None
# 检查TTL # 检查TTL
result = self._result_cache[message_id] result = self._result_cache[message_id]
if time.time() - result.timestamp > self._cache_ttl: if time.time() - result.timestamp > self._cache_ttl:
# 过期,删除 # 过期,删除
del self._result_cache[message_id] del self._result_cache[message_id]
return None return None
# 更新访问顺序LRU # 更新访问顺序LRU
self._result_cache.move_to_end(message_id) self._result_cache.move_to_end(message_id)
return result return result
def _put_to_cache(self, message_id: str, result: InterestCalculationResult): def _put_to_cache(self, message_id: str, result: InterestCalculationResult):
"""将结果放入缓存LRU策略""" """将结果放入缓存LRU策略"""
# 如果已存在,更新 # 如果已存在,更新
if message_id in self._result_cache: if message_id in self._result_cache:
self._result_cache.move_to_end(message_id) self._result_cache.move_to_end(message_id)
self._result_cache[message_id] = result self._result_cache[message_id] = result
# 限制缓存大小 # 限制缓存大小
while len(self._result_cache) > self._cache_max_size: while len(self._result_cache) > self._cache_max_size:
# 删除最旧的项 # 删除最旧的项
self._result_cache.popitem(last=False) self._result_cache.popitem(last=False)
async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]: async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]:
"""批量计算消息兴趣值(并发优化) """批量计算消息兴趣值(并发优化)
@@ -291,11 +291,11 @@ class InterestManager:
""" """
if not messages: if not messages:
return [] return []
# 并发计算所有消息 # 并发计算所有消息
tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages] tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages]
results = await asyncio.gather(*tasks, return_exceptions=True) results = await asyncio.gather(*tasks, return_exceptions=True)
# 处理异常 # 处理异常
final_results = [] final_results = []
for i, result in enumerate(results): for i, result in enumerate(results):
@@ -309,44 +309,44 @@ class InterestManager:
)) ))
else: else:
final_results.append(result) final_results.append(result)
self._batch_calculations += 1 self._batch_calculations += 1
return final_results return final_results
async def _batch_processing_worker(self): async def _batch_processing_worker(self):
"""批处理工作线程""" """批处理工作线程"""
while not self._shutdown_event.is_set(): while not self._shutdown_event.is_set():
batch = [] batch = []
deadline = time.time() + self._batch_timeout deadline = time.time() + self._batch_timeout
try: try:
# 收集批次 # 收集批次
while len(batch) < self._batch_size and time.time() < deadline: while len(batch) < self._batch_size and time.time() < deadline:
remaining_time = deadline - time.time() remaining_time = deadline - time.time()
if remaining_time <= 0: if remaining_time <= 0:
break break
try: try:
item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time) item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time)
batch.append(item) batch.append(item)
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
# 处理批次 # 处理批次
if batch: if batch:
await self._process_batch(batch) await self._process_batch(batch)
except asyncio.CancelledError: except asyncio.CancelledError:
break break
except Exception as e: except Exception as e:
logger.error(f"批处理工作线程异常: {e}") logger.error(f"批处理工作线程异常: {e}")
async def _process_batch(self, batch: list): async def _process_batch(self, batch: list):
"""处理批次消息""" """处理批次消息"""
# 这里可以实现具体的批处理逻辑 # 这里可以实现具体的批处理逻辑
# 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现 # 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现
pass pass
async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None): async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None):
"""预热兴趣计算器 """预热兴趣计算器
@@ -356,10 +356,10 @@ class InterestManager:
if not self._current_calculator: if not self._current_calculator:
logger.warning("无法预热:没有可用的兴趣值计算组件") logger.warning("无法预热:没有可用的兴趣值计算组件")
return return
logger.info("开始预热兴趣值计算器...") logger.info("开始预热兴趣值计算器...")
start_time = time.time() start_time = time.time()
# 如果提供了样本消息,进行预热计算 # 如果提供了样本消息,进行预热计算
if sample_messages: if sample_messages:
try: try:
@@ -370,15 +370,15 @@ class InterestManager:
logger.error(f"预热过程中出现异常: {e}") logger.error(f"预热过程中出现异常: {e}")
else: else:
logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s") logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s")
self._is_warmed_up = True self._is_warmed_up = True
def clear_cache(self): def clear_cache(self):
"""清空缓存""" """清空缓存"""
cleared_count = len(self._result_cache) cleared_count = len(self._result_cache)
self._result_cache.clear() self._result_cache.clear()
logger.info(f"已清空 {cleared_count} 条缓存记录") logger.info(f"已清空 {cleared_count} 条缓存记录")
def set_cache_config(self, max_size: int | None = None, ttl: int | None = None): def set_cache_config(self, max_size: int | None = None, ttl: int | None = None):
"""设置缓存配置 """设置缓存配置
@@ -389,11 +389,11 @@ class InterestManager:
if max_size is not None: if max_size is not None:
self._cache_max_size = max_size self._cache_max_size = max_size
logger.info(f"缓存最大容量设置为: {max_size}") logger.info(f"缓存最大容量设置为: {max_size}")
if ttl is not None: if ttl is not None:
self._cache_ttl = ttl self._cache_ttl = ttl
logger.info(f"缓存TTL设置为: {ttl}") logger.info(f"缓存TTL设置为: {ttl}")
# 如果当前缓存超过新的最大值,清理旧数据 # 如果当前缓存超过新的最大值,清理旧数据
if max_size is not None: if max_size is not None:
while len(self._result_cache) > self._cache_max_size: while len(self._result_cache) > self._cache_max_size:
@@ -446,14 +446,14 @@ class InterestManager:
def has_calculator(self) -> bool: def has_calculator(self) -> bool:
"""检查是否有可用的计算组件""" """检查是否有可用的计算组件"""
return self._current_calculator is not None and self._current_calculator.is_enabled return self._current_calculator is not None and self._current_calculator.is_enabled
async def adaptive_optimize(self): async def adaptive_optimize(self):
"""自适应优化:根据性能统计自动调整参数""" """自适应优化:根据性能统计自动调整参数"""
if not self._current_calculator: if not self._current_calculator:
return return
stats = self.get_statistics()["manager_statistics"] stats = self.get_statistics()["manager_statistics"]
# 根据缓存命中率调整缓存大小 # 根据缓存命中率调整缓存大小
cache_hit_rate = stats["cache_hit_rate"] cache_hit_rate = stats["cache_hit_rate"]
if cache_hit_rate < 0.5 and self._cache_max_size < 5000: if cache_hit_rate < 0.5 and self._cache_max_size < 5000:
@@ -469,7 +469,7 @@ class InterestManager:
# 清理多余缓存 # 清理多余缓存
while len(self._result_cache) > self._cache_max_size: while len(self._result_cache) > self._cache_max_size:
self._result_cache.popitem(last=False) self._result_cache.popitem(last=False)
# 根据平均计算时间调整批处理参数 # 根据平均计算时间调整批处理参数
avg_calc_time = stats["average_calculation_time"] avg_calc_time = stats["average_calculation_time"]
if avg_calc_time > 0.5 and self._batch_size < 50: if avg_calc_time > 0.5 and self._batch_size < 50:
@@ -482,11 +482,11 @@ class InterestManager:
new_batch_size = max(self._batch_size // 2, 5) new_batch_size = max(self._batch_size // 2, 5)
logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}") logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}")
self._batch_size = new_batch_size self._batch_size = new_batch_size
def get_performance_report(self) -> str: def get_performance_report(self) -> str:
"""生成性能报告""" """生成性能报告"""
stats = self.get_statistics()["manager_statistics"] stats = self.get_statistics()["manager_statistics"]
report = [ report = [
"=" * 60, "=" * 60,
"兴趣值管理器性能报告", "兴趣值管理器性能报告",
@@ -504,7 +504,7 @@ class InterestManager:
f"当前计算器: {stats['current_calculator'] or ''}", f"当前计算器: {stats['current_calculator'] or ''}",
"=" * 60, "=" * 60,
] ]
# 添加计算器统计 # 添加计算器统计
if self._current_calculator: if self._current_calculator:
calc_stats = self.get_statistics()["calculator_statistics"] calc_stats = self.get_statistics()["calculator_statistics"]
@@ -520,7 +520,7 @@ class InterestManager:
f" 平均耗时: {calc_stats['average_calculation_time']:.4f}s", f" 平均耗时: {calc_stats['average_calculation_time']:.4f}s",
"=" * 60, "=" * 60,
]) ])
return "\n".join(report) return "\n".join(report)

View File

@@ -30,7 +30,7 @@ logger = get_logger("message_manager")
class MessageManager: class MessageManager:
"""消息管理器""" """消息管理器"""
def __init__(self, check_interval: float = 5.0): def __init__(self, check_interval: float = 5.0):
self.check_interval = check_interval # 检查间隔(秒) self.check_interval = check_interval # 检查间隔(秒)
self.is_running = False self.is_running = False
self.manager_task: asyncio.Task | None = None self.manager_task: asyncio.Task | None = None

View File

@@ -348,12 +348,12 @@ class StatisticOutputTask(AsyncTask):
prompt_tokens = int(record.get("prompt_tokens") or 0) prompt_tokens = int(record.get("prompt_tokens") or 0)
except (ValueError, TypeError): except (ValueError, TypeError):
prompt_tokens = 0 prompt_tokens = 0
try: try:
completion_tokens = int(record.get("completion_tokens") or 0) completion_tokens = int(record.get("completion_tokens") or 0)
except (ValueError, TypeError): except (ValueError, TypeError):
completion_tokens = 0 completion_tokens = 0
total_tokens = prompt_tokens + completion_tokens total_tokens = prompt_tokens + completion_tokens
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
@@ -378,7 +378,7 @@ class StatisticOutputTask(AsyncTask):
cost = float(cost) if cost else 0.0 cost = float(cost) if cost else 0.0
except (ValueError, TypeError): except (ValueError, TypeError):
cost = 0.0 cost = 0.0
stats[period_key][TOTAL_COST] += cost stats[period_key][TOTAL_COST] += cost
stats[period_key][COST_BY_TYPE][request_type] += cost stats[period_key][COST_BY_TYPE][request_type] += cost
stats[period_key][COST_BY_USER][user_id] += cost stats[period_key][COST_BY_USER][user_id] += cost

View File

@@ -969,7 +969,7 @@ class LongTermMemoryManager:
content=f"临时节点 - {source_id}", content=f"临时节点 - {source_id}",
metadata={"placeholder": True, "created_by": "long_term_manager_edge_creation"} metadata={"placeholder": True, "created_by": "long_term_manager_edge_creation"}
) )
if not self.memory_manager.graph_store.graph.has_node(target_id): if not self.memory_manager.graph_store.graph.has_node(target_id):
logger.debug(f"目标节点不存在,创建占位符节点: {target_id}") logger.debug(f"目标节点不存在,创建占位符节点: {target_id}")
self.memory_manager.graph_store.add_node( self.memory_manager.graph_store.add_node(

View File

@@ -1,4 +1,3 @@
# ruff: noqa: G004, BLE001
# pylint: disable=logging-fstring-interpolation,broad-except,unused-argument # pylint: disable=logging-fstring-interpolation,broad-except,unused-argument
# pyright: reportOptionalMemberAccess=false # pyright: reportOptionalMemberAccess=false
""" """

View File

@@ -658,7 +658,7 @@ class ShortTermMemoryManager:
return self._get_transfer_all_strategy() return self._get_transfer_all_strategy()
else: # "selective_cleanup" 或其他值默认使用选择性清理 else: # "selective_cleanup" 或其他值默认使用选择性清理
return self._get_selective_cleanup_strategy() return self._get_selective_cleanup_strategy()
def _get_transfer_all_strategy(self) -> list[ShortTermMemory]: def _get_transfer_all_strategy(self) -> list[ShortTermMemory]:
""" """
"一次性转移所有"策略:当短期记忆满了以后,将所有记忆转移到长期记忆 "一次性转移所有"策略:当短期记忆满了以后,将所有记忆转移到长期记忆
@@ -673,24 +673,24 @@ class ShortTermMemoryManager:
f"将转移所有 {len(self.memories)} 条记忆到长期记忆" f"将转移所有 {len(self.memories)} 条记忆到长期记忆"
) )
return self.memories.copy() return self.memories.copy()
# 如果还没满,检查是否有高重要性记忆需要转移 # 如果还没满,检查是否有高重要性记忆需要转移
high_importance_memories = [ high_importance_memories = [
mem for mem in self.memories mem for mem in self.memories
if mem.importance >= self.transfer_importance_threshold if mem.importance >= self.transfer_importance_threshold
] ]
if high_importance_memories: if high_importance_memories:
logger.debug( logger.debug(
f"转移策略(transfer_all): 发现 {len(high_importance_memories)} 条高重要性记忆待转移" f"转移策略(transfer_all): 发现 {len(high_importance_memories)} 条高重要性记忆待转移"
) )
return high_importance_memories return high_importance_memories
logger.debug( logger.debug(
f"转移策略(transfer_all): 无需转移 (当前容量 {len(self.memories)}/{self.max_memories})" f"转移策略(transfer_all): 无需转移 (当前容量 {len(self.memories)}/{self.max_memories})"
) )
return [] return []
def _get_selective_cleanup_strategy(self) -> list[ShortTermMemory]: def _get_selective_cleanup_strategy(self) -> list[ShortTermMemory]:
""" """
"选择性清理"策略(原有策略):优先转移重要记忆,低重要性记忆考虑直接删除 "选择性清理"策略(原有策略):优先转移重要记忆,低重要性记忆考虑直接删除
@@ -720,11 +720,11 @@ class ShortTermMemoryManager:
if len(self.memories) > self.max_memories: if len(self.memories) > self.max_memories:
# 计算需要转移的数量(目标:降到上限) # 计算需要转移的数量(目标:降到上限)
num_to_transfer = len(self.memories) - self.max_memories num_to_transfer = len(self.memories) - self.max_memories
# 按创建时间排序低重要性记忆,优先转移最早的(可能包含过时信息) # 按创建时间排序低重要性记忆,优先转移最早的(可能包含过时信息)
low_importance_memories.sort(key=lambda x: x.created_at) low_importance_memories.sort(key=lambda x: x.created_at)
to_transfer = low_importance_memories[:num_to_transfer] to_transfer = low_importance_memories[:num_to_transfer]
if to_transfer: if to_transfer:
logger.debug( logger.debug(
f"转移策略(selective): 发现 {len(to_transfer)} 条低重要性记忆待转移 " f"转移策略(selective): 发现 {len(to_transfer)} 条低重要性记忆待转移 "
@@ -757,7 +757,7 @@ class ShortTermMemoryManager:
# 使用实例配置或传入参数 # 使用实例配置或传入参数
if keep_ratio is None: if keep_ratio is None:
keep_ratio = self.cleanup_keep_ratio keep_ratio = self.cleanup_keep_ratio
current = len(self.memories) current = len(self.memories)
limit = int(self.max_memories * keep_ratio) limit = int(self.max_memories * keep_ratio)
if current <= self.max_memories: if current <= self.max_memories:
@@ -804,28 +804,28 @@ class ShortTermMemoryManager:
self._similarity_cache.pop(mem_id, None) self._similarity_cache.pop(mem_id, None)
logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆") logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆")
# 在 "transfer_all" 策略下,进一步删除不重要的短期记忆 # 在 "transfer_all" 策略下,进一步删除不重要的短期记忆
if self.overflow_strategy == "transfer_all": if self.overflow_strategy == "transfer_all":
# 计算需要删除的低重要性记忆数量 # 计算需要删除的低重要性记忆数量
low_importance_memories = [ low_importance_memories = [
mem for mem in self.memories mem for mem in self.memories
if mem.importance < self.transfer_importance_threshold if mem.importance < self.transfer_importance_threshold
] ]
if low_importance_memories: if low_importance_memories:
# 按重要性和创建时间排序,删除最不重要的 # 按重要性和创建时间排序,删除最不重要的
low_importance_memories.sort(key=lambda m: (m.importance, m.created_at)) low_importance_memories.sort(key=lambda m: (m.importance, m.created_at))
# 删除所有低重要性记忆 # 删除所有低重要性记忆
to_delete = {mem.id for mem in low_importance_memories} to_delete = {mem.id for mem in low_importance_memories}
self.memories = [mem for mem in self.memories if mem.id not in to_delete] self.memories = [mem for mem in self.memories if mem.id not in to_delete]
# 更新索引 # 更新索引
for mem_id in to_delete: for mem_id in to_delete:
self._memory_id_index.pop(mem_id, None) self._memory_id_index.pop(mem_id, None)
self._similarity_cache.pop(mem_id, None) self._similarity_cache.pop(mem_id, None)
logger.info( logger.info(
f"transfer_all 策略: 额外删除了 {len(to_delete)} 条低重要性记忆 " f"transfer_all 策略: 额外删除了 {len(to_delete)} 条低重要性记忆 "
f"(重要性 < {self.transfer_importance_threshold:.2f})" f"(重要性 < {self.transfer_importance_threshold:.2f})"

View File

@@ -936,7 +936,7 @@ class GraphStore:
edge_type_enum = EdgeType.RELATION edge_type_enum = EdgeType.RELATION
else: else:
edge_type_enum = edge_type_value edge_type_enum = edge_type_value
mem_edge = MemoryEdge( mem_edge = MemoryEdge(
id=edge_dict["id"] or "", id=edge_dict["id"] or "",
source_id=edge_dict["source_id"], source_id=edge_dict["source_id"],

View File

@@ -124,7 +124,7 @@ class BaseInterestCalculator(ABC):
logger.error(f"初始化兴趣计算器失败: {e}") logger.error(f"初始化兴趣计算器失败: {e}")
self._enabled = False self._enabled = False
return False return False
async def on_initialize(self): async def on_initialize(self):
"""子类可重写的初始化钩子""" """子类可重写的初始化钩子"""
pass pass
@@ -143,7 +143,7 @@ class BaseInterestCalculator(ABC):
except Exception as e: except Exception as e:
logger.error(f"清理兴趣计算器失败: {e}") logger.error(f"清理兴趣计算器失败: {e}")
return False return False
async def on_cleanup(self): async def on_cleanup(self):
"""子类可重写的清理钩子""" """子类可重写的清理钩子"""
pass pass

View File

@@ -3,7 +3,6 @@ MaiZone麦麦空间- 重构版
""" """
import asyncio import asyncio
from pathlib import Path
from src.common.logger import get_logger from src.common.logger import get_logger
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin from src.plugin_system import BasePlugin, ComponentInfo, register_plugin

View File

@@ -151,7 +151,7 @@ class ContentService:
bot_personality_side = config_api.get_global_config("personality.personality_side", "") bot_personality_side = config_api.get_global_config("personality.personality_side", "")
bot_reply_style = config_api.get_global_config("personality.reply_style", "内容积极向上") bot_reply_style = config_api.get_global_config("personality.reply_style", "内容积极向上")
qq_account = config_api.get_global_config("bot.qq_account", "") qq_account = config_api.get_global_config("bot.qq_account", "")
# 获取角色外貌描述用于告知LLM # 获取角色外貌描述用于告知LLM
character_prompt = self.get_config("novelai.character_prompt", "") character_prompt = self.get_config("novelai.character_prompt", "")
@@ -163,21 +163,21 @@ class ContentService:
# 构建提示词 # 构建提示词
prompt_topic = f"主题是'{topic}'" if topic else "主题不限" prompt_topic = f"主题是'{topic}'" if topic else "主题不限"
# 构建人设描述 # 构建人设描述
personality_desc = f"你的核心人格:{bot_personality_core}" personality_desc = f"你的核心人格:{bot_personality_core}"
if bot_personality_side: if bot_personality_side:
personality_desc += f"\n你的人格侧面:{bot_personality_side}" personality_desc += f"\n你的人格侧面:{bot_personality_side}"
personality_desc += f"\n\n你的表达方式:{bot_reply_style}" personality_desc += f"\n\n你的表达方式:{bot_reply_style}"
# 检查是否启用AI配图统一开关 # 检查是否启用AI配图统一开关
ai_image_enabled = self.get_config("ai_image.enable_ai_image", False) ai_image_enabled = self.get_config("ai_image.enable_ai_image", False)
provider = self.get_config("ai_image.provider", "siliconflow") provider = self.get_config("ai_image.provider", "siliconflow")
# NovelAI配图指引内置 # NovelAI配图指引内置
novelai_guide = "" novelai_guide = ""
output_format = '{"text": "说说正文内容"}' output_format = '{"text": "说说正文内容"}'
if ai_image_enabled and provider == "novelai": if ai_image_enabled and provider == "novelai":
# 构建角色信息提示 # 构建角色信息提示
character_info = "" character_info = ""
@@ -195,7 +195,7 @@ class ContentService:
- 例如可以搭配各种表情smile, laugh, serious, thinking, surprised等 - 例如可以搭配各种表情smile, laugh, serious, thinking, surprised等
- **鼓励创意**:根据说说内容自由发挥,让画面更丰富生动! - **鼓励创意**:根据说说内容自由发挥,让画面更丰富生动!
""" """
novelai_guide = f""" novelai_guide = f"""
**配图说明:** **配图说明:**
这条说说会使用NovelAI Diffusion模型二次元风格生成配图。 这条说说会使用NovelAI Diffusion模型二次元风格生成配图。
@@ -258,7 +258,7 @@ class ContentService:
- 运动风:"masterpiece, best quality, 1girl, sportswear, running in park, energetic, morning light, trees background, dynamic pose, healthy lifestyle" - 运动风:"masterpiece, best quality, 1girl, sportswear, running in park, energetic, morning light, trees background, dynamic pose, healthy lifestyle"
- 咖啡馆:"masterpiece, best quality, 1girl, sitting in cozy cafe, holding coffee cup, warm lighting, wooden table, books beside, peaceful atmosphere" - 咖啡馆:"masterpiece, best quality, 1girl, sitting in cozy cafe, holding coffee cup, warm lighting, wooden table, books beside, peaceful atmosphere"
""" """
output_format = '''{"text": "说说正文内容", "image": {"prompt": "详细的英文提示词(包含画质+主体+场景+氛围+光线+色彩)", "negative_prompt": "负面词", "include_character": true/false, "aspect_ratio": "方图/横图/竖图"}}''' output_format = """{"text": "说说正文内容", "image": {"prompt": "详细的英文提示词(包含画质+主体+场景+氛围+光线+色彩)", "negative_prompt": "负面词", "include_character": true/false, "aspect_ratio": "方图/横图/竖图"}}"""
elif ai_image_enabled and provider == "siliconflow": elif ai_image_enabled and provider == "siliconflow":
novelai_guide = """ novelai_guide = """
**配图说明:** **配图说明:**
@@ -277,8 +277,8 @@ class ContentService:
- "sunset over the calm ocean, golden hour, orange and purple sky, gentle waves, peaceful and serene mood, wide angle view" - "sunset over the calm ocean, golden hour, orange and purple sky, gentle waves, peaceful and serene mood, wide angle view"
- "cherry blossoms in spring, soft pink petals falling, blue sky, sunlight filtering through branches, peaceful park scene, gentle breeze" - "cherry blossoms in spring, soft pink petals falling, blue sky, sunlight filtering through branches, peaceful park scene, gentle breeze"
""" """
output_format = '''{"text": "说说正文内容", "image": {"prompt": "详细的英文描述(主体+场景+氛围+光线+细节)"}}''' output_format = """{"text": "说说正文内容", "image": {"prompt": "详细的英文描述(主体+场景+氛围+光线+细节)"}}"""
prompt = f""" prompt = f"""
{personality_desc} {personality_desc}
@@ -333,20 +333,20 @@ class ContentService:
if json_text.endswith("```"): if json_text.endswith("```"):
json_text = json_text[:-3] json_text = json_text[:-3]
json_text = json_text.strip() json_text = json_text.strip()
data = json5.loads(json_text) data = json5.loads(json_text)
story_text = data.get("text", "") story_text = data.get("text", "")
image_info = data.get("image", {}) image_info = data.get("image", {})
# 确保图片信息完整 # 确保图片信息完整
if not isinstance(image_info, dict): if not isinstance(image_info, dict):
image_info = {} image_info = {}
logger.info(f"成功生成说说:'{story_text}'") logger.info(f"成功生成说说:'{story_text}'")
logger.info(f"配图信息: {image_info}") logger.info(f"配图信息: {image_info}")
return story_text, image_info return story_text, image_info
except Exception as e: except Exception as e:
logger.error(f"解析JSON失败: {e}, 原始响应: {response[:200]}") logger.error(f"解析JSON失败: {e}, 原始响应: {response[:200]}")
# 降级处理:只返回文本,空配图信息 # 降级处理:只返回文本,空配图信息

View File

@@ -42,7 +42,7 @@ class ImageService:
try: try:
api_key = str(self.get_config("siliconflow.api_key", "")) api_key = str(self.get_config("siliconflow.api_key", ""))
image_num = self.get_config("ai_image.image_number", 1) image_num = self.get_config("ai_image.image_number", 1)
if not api_key: if not api_key:
logger.warning("硅基流动API未配置跳过图片生成") logger.warning("硅基流动API未配置跳过图片生成")
return False, None return False, None
@@ -237,7 +237,7 @@ class ImageService:
image.save(save_path, format="PNG") image.save(save_path, format="PNG")
logger.info(f"图片已保存至: {save_path}") logger.info(f"图片已保存至: {save_path}")
success_count += 1 success_count += 1
# 记录第一张图片路径 # 记录第一张图片路径
if first_img_path is None: if first_img_path is None:
first_img_path = save_path first_img_path = save_path

View File

@@ -2,14 +2,11 @@
NovelAI图片生成服务 - 空间插件专用 NovelAI图片生成服务 - 空间插件专用
独立实现,不依赖其他插件 独立实现,不依赖其他插件
""" """
import asyncio import io
import base64
import random import random
import uuid import uuid
import zipfile import zipfile
import io
from pathlib import Path from pathlib import Path
from typing import Optional
import aiohttp import aiohttp
from PIL import Image from PIL import Image
@@ -21,50 +18,50 @@ logger = get_logger("MaiZone.NovelAIService")
class MaiZoneNovelAIService: class MaiZoneNovelAIService:
"""空间插件的NovelAI图片生成服务独立实现""" """空间插件的NovelAI图片生成服务独立实现"""
def __init__(self, get_config): def __init__(self, get_config):
self.get_config = get_config self.get_config = get_config
# NovelAI配置 # NovelAI配置
self.api_key = self.get_config("novelai.api_key", "") self.api_key = self.get_config("novelai.api_key", "")
self.base_url = "https://image.novelai.net/ai/generate-image" self.base_url = "https://image.novelai.net/ai/generate-image"
self.model = "nai-diffusion-4-5-full" self.model = "nai-diffusion-4-5-full"
# 代理配置 # 代理配置
proxy_host = self.get_config("novelai.proxy_host", "") proxy_host = self.get_config("novelai.proxy_host", "")
proxy_port = self.get_config("novelai.proxy_port", 0) proxy_port = self.get_config("novelai.proxy_port", 0)
self.proxy = f"http://{proxy_host}:{proxy_port}" if proxy_host and proxy_port else "" self.proxy = f"http://{proxy_host}:{proxy_port}" if proxy_host and proxy_port else ""
# 生成参数 # 生成参数
self.steps = 28 self.steps = 28
self.scale = 5.0 self.scale = 5.0
self.sampler = "k_euler" self.sampler = "k_euler"
self.noise_schedule = "karras" self.noise_schedule = "karras"
# 角色提示词当LLM决定包含角色时使用 # 角色提示词当LLM决定包含角色时使用
self.character_prompt = self.get_config("novelai.character_prompt", "") self.character_prompt = self.get_config("novelai.character_prompt", "")
self.base_negative_prompt = self.get_config("novelai.base_negative_prompt", "nsfw, nude, explicit, sexual content, lowres, bad anatomy, bad hands") self.base_negative_prompt = self.get_config("novelai.base_negative_prompt", "nsfw, nude, explicit, sexual content, lowres, bad anatomy, bad hands")
# 图片保存目录(使用统一配置) # 图片保存目录(使用统一配置)
plugin_dir = Path(__file__).parent.parent plugin_dir = Path(__file__).parent.parent
self.image_dir = plugin_dir / "images" self.image_dir = plugin_dir / "images"
self.image_dir.mkdir(parents=True, exist_ok=True) self.image_dir.mkdir(parents=True, exist_ok=True)
if self.api_key: if self.api_key:
logger.info(f"NovelAI图片生成已配置模型: {self.model}") logger.info(f"NovelAI图片生成已配置模型: {self.model}")
def is_available(self) -> bool: def is_available(self) -> bool:
"""检查NovelAI服务是否可用""" """检查NovelAI服务是否可用"""
return bool(self.api_key) return bool(self.api_key)
async def generate_image_from_prompt_data( async def generate_image_from_prompt_data(
self, self,
prompt: str, prompt: str,
negative_prompt: Optional[str] = None, negative_prompt: str | None = None,
include_character: bool = False, include_character: bool = False,
width: int = 1024, width: int = 1024,
height: int = 1024 height: int = 1024
) -> tuple[bool, Optional[Path], str]: ) -> tuple[bool, Path | None, str]:
"""根据提示词生成图片 """根据提示词生成图片
Args: Args:
@@ -79,14 +76,14 @@ class MaiZoneNovelAIService:
""" """
if not self.api_key: if not self.api_key:
return False, None, "NovelAI API Key未配置" return False, None, "NovelAI API Key未配置"
try: try:
# 处理角色提示词 # 处理角色提示词
final_prompt = prompt final_prompt = prompt
if include_character and self.character_prompt: if include_character and self.character_prompt:
final_prompt = f"{self.character_prompt}, {prompt}" final_prompt = f"{self.character_prompt}, {prompt}"
logger.info(f"包含角色形象,添加角色提示词") logger.info("包含角色形象,添加角色提示词")
# 合并负面提示词 # 合并负面提示词
final_negative = self.base_negative_prompt final_negative = self.base_negative_prompt
if negative_prompt: if negative_prompt:
@@ -94,37 +91,37 @@ class MaiZoneNovelAIService:
final_negative = f"{final_negative}, {negative_prompt}" final_negative = f"{final_negative}, {negative_prompt}"
else: else:
final_negative = negative_prompt final_negative = negative_prompt
logger.info(f"🎨 开始生成图片...") logger.info("🎨 开始生成图片...")
logger.info(f" 尺寸: {width}x{height}") logger.info(f" 尺寸: {width}x{height}")
logger.info(f" 正面提示词: {final_prompt[:100]}...") logger.info(f" 正面提示词: {final_prompt[:100]}...")
logger.info(f" 负面提示词: {final_negative[:100]}...") logger.info(f" 负面提示词: {final_negative[:100]}...")
# 构建请求payload # 构建请求payload
payload = self._build_payload(final_prompt, final_negative, width, height) payload = self._build_payload(final_prompt, final_negative, width, height)
# 发送请求 # 发送请求
image_data = await self._call_novelai_api(payload) image_data = await self._call_novelai_api(payload)
if not image_data: if not image_data:
return False, None, "API请求失败" return False, None, "API请求失败"
# 保存图片 # 保存图片
image_path = await self._save_image(image_data) image_path = await self._save_image(image_data)
if not image_path: if not image_path:
return False, None, "图片保存失败" return False, None, "图片保存失败"
logger.info(f"✅ 图片生成成功: {image_path}") logger.info(f"✅ 图片生成成功: {image_path}")
return True, image_path, "生成成功" return True, image_path, "生成成功"
except Exception as e: except Exception as e:
logger.error(f"生成图片时出错: {e}", exc_info=True) logger.error(f"生成图片时出错: {e}", exc_info=True)
return False, None, f"生成失败: {str(e)}" return False, None, f"生成失败: {e!s}"
def _build_payload(self, prompt: str, negative_prompt: str, width: int, height: int) -> dict: def _build_payload(self, prompt: str, negative_prompt: str, width: int, height: int) -> dict:
"""构建NovelAI API请求payload""" """构建NovelAI API请求payload"""
is_v4_model = "diffusion-4" in self.model is_v4_model = "diffusion-4" in self.model
is_v3_model = "diffusion-3" in self.model is_v3_model = "diffusion-3" in self.model
parameters = { parameters = {
"width": width, "width": width,
"height": height, "height": height,
@@ -139,7 +136,7 @@ class MaiZoneNovelAIService:
"sm_dyn": False, "sm_dyn": False,
"noise_schedule": self.noise_schedule if is_v4_model else "native", "noise_schedule": self.noise_schedule if is_v4_model else "native",
} }
# V4.5模型使用新格式 # V4.5模型使用新格式
if is_v4_model: if is_v4_model:
parameters.update({ parameters.update({
@@ -183,39 +180,39 @@ class MaiZoneNovelAIService:
# V3使用negative_prompt字段 # V3使用negative_prompt字段
elif is_v3_model: elif is_v3_model:
parameters["negative_prompt"] = negative_prompt parameters["negative_prompt"] = negative_prompt
payload = { payload = {
"input": prompt, "input": prompt,
"model": self.model, "model": self.model,
"action": "generate", "action": "generate",
"parameters": parameters "parameters": parameters
} }
# V4.5需要额外字段 # V4.5需要额外字段
if is_v4_model: if is_v4_model:
payload["use_new_shared_trial"] = True payload["use_new_shared_trial"] = True
return payload return payload
async def _call_novelai_api(self, payload: dict) -> Optional[bytes]: async def _call_novelai_api(self, payload: dict) -> bytes | None:
"""调用NovelAI API""" """调用NovelAI API"""
headers = { headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
connector = None connector = None
request_kwargs = { request_kwargs = {
"json": payload, "json": payload,
"headers": headers, "headers": headers,
"timeout": aiohttp.ClientTimeout(total=120) "timeout": aiohttp.ClientTimeout(total=120)
} }
if self.proxy: if self.proxy:
request_kwargs["proxy"] = self.proxy request_kwargs["proxy"] = self.proxy
connector = aiohttp.TCPConnector() connector = aiohttp.TCPConnector()
logger.info(f"使用代理: {self.proxy}") logger.info(f"使用代理: {self.proxy}")
try: try:
async with aiohttp.ClientSession(connector=connector) as session: async with aiohttp.ClientSession(connector=connector) as session:
async with session.post(self.base_url, **request_kwargs) as resp: async with session.post(self.base_url, **request_kwargs) as resp:
@@ -223,31 +220,31 @@ class MaiZoneNovelAIService:
error_text = await resp.text() error_text = await resp.text()
logger.error(f"API请求失败 ({resp.status}): {error_text[:200]}") logger.error(f"API请求失败 ({resp.status}): {error_text[:200]}")
return None return None
img_data = await resp.read() img_data = await resp.read()
logger.info(f"收到响应数据: {len(img_data)} bytes") logger.info(f"收到响应数据: {len(img_data)} bytes")
# 检查是否是ZIP文件 # 检查是否是ZIP文件
if img_data[:4] == b'PK\x03\x04': if img_data[:4] == b"PK\x03\x04":
logger.info("检测到ZIP格式解压中...") logger.info("检测到ZIP格式解压中...")
return self._extract_from_zip(img_data) return self._extract_from_zip(img_data)
elif img_data[:4] == b'\x89PNG': elif img_data[:4] == b"\x89PNG":
logger.info("检测到PNG格式") logger.info("检测到PNG格式")
return img_data return img_data
else: else:
logger.warning(f"未知文件格式前4字节: {img_data[:4].hex()}") logger.warning(f"未知文件格式前4字节: {img_data[:4].hex()}")
return img_data return img_data
except Exception as e: except Exception as e:
logger.error(f"API调用失败: {e}", exc_info=True) logger.error(f"API调用失败: {e}", exc_info=True)
return None return None
def _extract_from_zip(self, zip_data: bytes) -> Optional[bytes]: def _extract_from_zip(self, zip_data: bytes) -> bytes | None:
"""从ZIP中提取PNG""" """从ZIP中提取PNG"""
try: try:
with zipfile.ZipFile(io.BytesIO(zip_data)) as zf: with zipfile.ZipFile(io.BytesIO(zip_data)) as zf:
for filename in zf.namelist(): for filename in zf.namelist():
if filename.lower().endswith('.png'): if filename.lower().endswith(".png"):
img_data = zf.read(filename) img_data = zf.read(filename)
logger.info(f"从ZIP提取: {filename} ({len(img_data)} bytes)") logger.info(f"从ZIP提取: {filename} ({len(img_data)} bytes)")
return img_data return img_data
@@ -256,20 +253,20 @@ class MaiZoneNovelAIService:
except Exception as e: except Exception as e:
logger.error(f"解压ZIP失败: {e}") logger.error(f"解压ZIP失败: {e}")
return None return None
async def _save_image(self, image_data: bytes) -> Optional[Path]: async def _save_image(self, image_data: bytes) -> Path | None:
"""保存图片到本地""" """保存图片到本地"""
try: try:
filename = f"novelai_{uuid.uuid4().hex[:12]}.png" filename = f"novelai_{uuid.uuid4().hex[:12]}.png"
filepath = self.image_dir / filename filepath = self.image_dir / filename
# 写入文件 # 写入文件
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
f.write(image_data) f.write(image_data)
f.flush() f.flush()
import os import os
os.fsync(f.fileno()) os.fsync(f.fileno())
# 验证图片 # 验证图片
try: try:
with Image.open(filepath) as img: with Image.open(filepath) as img:
@@ -278,9 +275,9 @@ class MaiZoneNovelAIService:
logger.info(f"图片验证成功: {img.format} {img.size}") logger.info(f"图片验证成功: {img.format} {img.size}")
except Exception as e: except Exception as e:
logger.warning(f"图片验证失败: {e}") logger.warning(f"图片验证失败: {e}")
return filepath return filepath
except Exception as e: except Exception as e:
logger.error(f"保存图片失败: {e}") logger.error(f"保存图片失败: {e}")
return None return None

View File

@@ -5,7 +5,6 @@ QQ空间服务模块
import asyncio import asyncio
import base64 import base64
import os
import random import random
import time import time
from collections.abc import Callable from collections.abc import Callable
@@ -85,25 +84,25 @@ class QZoneService:
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]: async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
"""发送一条说说支持AI配图""" """发送一条说说支持AI配图"""
cross_context = await self._get_cross_context() cross_context = await self._get_cross_context()
# 检查是否启用AI配图 # 检查是否启用AI配图
ai_image_enabled = self.get_config("ai_image.enable_ai_image", False) ai_image_enabled = self.get_config("ai_image.enable_ai_image", False)
provider = self.get_config("ai_image.provider", "siliconflow") provider = self.get_config("ai_image.provider", "siliconflow")
image_path = None image_path = None
if ai_image_enabled: if ai_image_enabled:
# 启用AI配图文本模型生成说说+图片提示词 # 启用AI配图文本模型生成说说+图片提示词
story, image_info = await self.content_service.generate_story_with_image_info(topic, context=cross_context) story, image_info = await self.content_service.generate_story_with_image_info(topic, context=cross_context)
if not story: if not story:
return {"success": False, "message": "生成说说内容失败"} return {"success": False, "message": "生成说说内容失败"}
# 根据provider调用对应的生图服务 # 根据provider调用对应的生图服务
if provider == "novelai": if provider == "novelai":
try: try:
from .novelai_service import MaiZoneNovelAIService from .novelai_service import MaiZoneNovelAIService
novelai_service = MaiZoneNovelAIService(self.get_config) novelai_service = MaiZoneNovelAIService(self.get_config)
if novelai_service.is_available(): if novelai_service.is_available():
# 解析画幅 # 解析画幅
aspect_ratio = image_info.get("aspect_ratio", "方图") aspect_ratio = image_info.get("aspect_ratio", "方图")
@@ -113,8 +112,8 @@ class QZoneService:
"竖图": (832, 1216), "竖图": (832, 1216),
} }
width, height = size_map.get(aspect_ratio, (1024, 1024)) width, height = size_map.get(aspect_ratio, (1024, 1024))
logger.info(f"🎨 开始生成NovelAI配图...") logger.info("🎨 开始生成NovelAI配图...")
success, img_path, msg = await novelai_service.generate_image_from_prompt_data( success, img_path, msg = await novelai_service.generate_image_from_prompt_data(
prompt=image_info.get("prompt", ""), prompt=image_info.get("prompt", ""),
negative_prompt=image_info.get("negative_prompt"), negative_prompt=image_info.get("negative_prompt"),
@@ -122,18 +121,18 @@ class QZoneService:
width=width, width=width,
height=height height=height
) )
if success and img_path: if success and img_path:
image_path = img_path image_path = img_path
logger.info(f"✅ NovelAI配图生成成功") logger.info("✅ NovelAI配图生成成功")
else: else:
logger.warning(f"⚠️ NovelAI配图生成失败: {msg}") logger.warning(f"⚠️ NovelAI配图生成失败: {msg}")
else: else:
logger.warning("NovelAI服务不可用未配置API Key") logger.warning("NovelAI服务不可用未配置API Key")
except Exception as e: except Exception as e:
logger.error(f"NovelAI配图生成出错: {e}", exc_info=True) logger.error(f"NovelAI配图生成出错: {e}", exc_info=True)
elif provider == "siliconflow": elif provider == "siliconflow":
try: try:
# 调用硅基流动生成图片 # 调用硅基流动生成图片
@@ -143,9 +142,9 @@ class QZoneService:
) )
if success and img_path: if success and img_path:
image_path = img_path image_path = img_path
logger.info(f"✅ 硅基流动配图生成成功") logger.info("✅ 硅基流动配图生成成功")
else: else:
logger.warning(f"⚠️ 硅基流动配图生成失败") logger.warning("⚠️ 硅基流动配图生成失败")
except Exception as e: except Exception as e:
logger.error(f"硅基流动配图生成出错: {e}", exc_info=True) logger.error(f"硅基流动配图生成出错: {e}", exc_info=True)
else: else:
@@ -161,13 +160,13 @@ class QZoneService:
# 加载图片 # 加载图片
images_bytes = [] images_bytes = []
# 使用AI生成的图片 # 使用AI生成的图片
if image_path and image_path.exists(): if image_path and image_path.exists():
try: try:
with open(image_path, "rb") as f: with open(image_path, "rb") as f:
images_bytes.append(f.read()) images_bytes.append(f.read())
logger.info(f"添加AI配图到说说") logger.info("添加AI配图到说说")
except Exception as e: except Exception as e:
logger.error(f"读取AI配图失败: {e}") logger.error(f"读取AI配图失败: {e}")

View File

@@ -416,18 +416,18 @@ class NapcatAdapterPlugin(BasePlugin):
"reply_at_rate": ConfigField(type=float, default=0.5, description="回复时@的概率0.0-1.0"), "reply_at_rate": ConfigField(type=float, default=0.5, description="回复时@的概率0.0-1.0"),
# ========== 视频消息处理配置 ========== # ========== 视频消息处理配置 ==========
"enable_video_processing": ConfigField( "enable_video_processing": ConfigField(
type=bool, type=bool,
default=True, default=True,
description="是否启用视频消息处理(下载和解析)。关闭后视频消息将显示为 [视频消息] 占位符,不会进行下载" description="是否启用视频消息处理(下载和解析)。关闭后视频消息将显示为 [视频消息] 占位符,不会进行下载"
), ),
"video_max_size_mb": ConfigField( "video_max_size_mb": ConfigField(
type=int, type=int,
default=100, default=100,
description="允许下载的视频文件最大大小MB超过此大小的视频将被跳过" description="允许下载的视频文件最大大小MB超过此大小的视频将被跳过"
), ),
"video_download_timeout": ConfigField( "video_download_timeout": ConfigField(
type=int, type=int,
default=60, default=60,
description="视频下载超时时间(秒),若超时将中止下载" description="视频下载超时时间(秒),若超时将中止下载"
), ),
}, },

View File

@@ -42,14 +42,14 @@ class MessageHandler:
def set_plugin_config(self, config: dict[str, Any]) -> None: def set_plugin_config(self, config: dict[str, Any]) -> None:
"""设置插件配置,并根据配置初始化视频下载器""" """设置插件配置,并根据配置初始化视频下载器"""
self.plugin_config = config self.plugin_config = config
# 如果启用了视频处理,根据配置初始化视频下载器 # 如果启用了视频处理,根据配置初始化视频下载器
if config_api.get_plugin_config(config, "features.enable_video_processing", True): if config_api.get_plugin_config(config, "features.enable_video_processing", True):
from ..video_handler import VideoDownloader from ..video_handler import VideoDownloader
max_size = config_api.get_plugin_config(config, "features.video_max_size_mb", 100) max_size = config_api.get_plugin_config(config, "features.video_max_size_mb", 100)
timeout = config_api.get_plugin_config(config, "features.video_download_timeout", 60) timeout = config_api.get_plugin_config(config, "features.video_download_timeout", 60)
self._video_downloader = VideoDownloader(max_size_mb=max_size, download_timeout=timeout) self._video_downloader = VideoDownloader(max_size_mb=max_size, download_timeout=timeout)
logger.debug(f"视频下载器已初始化: max_size={max_size}MB, timeout={timeout}s") logger.debug(f"视频下载器已初始化: max_size={max_size}MB, timeout={timeout}s")
@@ -341,7 +341,7 @@ class MessageHandler:
if not downloader: if not downloader:
from ..video_handler import get_video_downloader from ..video_handler import get_video_downloader
downloader = get_video_downloader() downloader = get_video_downloader()
download_result = await downloader.download_video(video_url) download_result = await downloader.download_video(video_url)
if not download_result["success"]: if not download_result["success"]: