ruff
This commit is contained in:
@@ -6,8 +6,8 @@ from pathlib import Path
|
||||
project_root = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(project_root))
|
||||
|
||||
from src.memory_graph.manager_singleton import get_unified_memory_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.memory_graph.manager_singleton import get_unified_memory_manager
|
||||
|
||||
logger = get_logger("memory_transfer_check")
|
||||
|
||||
@@ -22,20 +22,20 @@ def print_section(title: str):
|
||||
async def check_short_term_status():
|
||||
"""检查短期记忆状态"""
|
||||
print_section("1. 短期记忆状态检查")
|
||||
|
||||
|
||||
manager = get_unified_memory_manager()
|
||||
short_term = manager.short_term_manager
|
||||
|
||||
|
||||
# 获取统计信息
|
||||
stats = short_term.get_statistics()
|
||||
|
||||
|
||||
print(f"📊 当前记忆数量: {stats['total_memories']}/{stats['max_memories']}")
|
||||
|
||||
|
||||
# 计算占用率
|
||||
if stats['max_memories'] > 0:
|
||||
occupancy = stats['total_memories'] / stats['max_memories']
|
||||
if stats["max_memories"] > 0:
|
||||
occupancy = stats["total_memories"] / stats["max_memories"]
|
||||
print(f"📈 容量占用率: {occupancy:.1%}")
|
||||
|
||||
|
||||
# 根据占用率给出建议
|
||||
if occupancy >= 1.0:
|
||||
print("⚠️ 警告:已达到容量上限!应该触发紧急转移")
|
||||
@@ -43,25 +43,25 @@ async def check_short_term_status():
|
||||
print("✅ 占用率超过50%,符合自动转移条件")
|
||||
else:
|
||||
print(f"ℹ️ 占用率未达到50%阈值,当前 {occupancy:.1%}")
|
||||
|
||||
|
||||
print(f"🎯 可转移记忆数: {stats['transferable_count']}")
|
||||
print(f"📏 转移重要性阈值: {stats['transfer_threshold']}")
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def check_transfer_candidates():
|
||||
"""检查当前可转移的候选记忆"""
|
||||
print_section("2. 转移候选记忆分析")
|
||||
|
||||
|
||||
manager = get_unified_memory_manager()
|
||||
short_term = manager.short_term_manager
|
||||
|
||||
|
||||
# 获取转移候选
|
||||
candidates = short_term.get_memories_for_transfer()
|
||||
|
||||
|
||||
print(f"🎫 当前转移候选: {len(candidates)} 条\n")
|
||||
|
||||
|
||||
if not candidates:
|
||||
print("❌ 没有记忆符合转移条件!")
|
||||
print("\n可能原因:")
|
||||
@@ -69,7 +69,7 @@ async def check_transfer_candidates():
|
||||
print(" 2. 短期记忆数量未超过容量限制")
|
||||
print(" 3. 短期记忆列表为空")
|
||||
return []
|
||||
|
||||
|
||||
# 显示前5条候选的详细信息
|
||||
print("前 5 条候选记忆:\n")
|
||||
for i, mem in enumerate(candidates[:5], 1):
|
||||
@@ -78,38 +78,38 @@ async def check_transfer_candidates():
|
||||
print(f" 内容: {mem.content[:50]}...")
|
||||
print(f" 创建时间: {mem.created_at}")
|
||||
print()
|
||||
|
||||
|
||||
if len(candidates) > 5:
|
||||
print(f"... 还有 {len(candidates) - 5} 条候选记忆\n")
|
||||
|
||||
|
||||
# 分析重要性分布
|
||||
importance_levels = {
|
||||
"高 (>=0.8)": sum(1 for m in candidates if m.importance >= 0.8),
|
||||
"中 (0.6-0.8)": sum(1 for m in candidates if 0.6 <= m.importance < 0.8),
|
||||
"低 (<0.6)": sum(1 for m in candidates if m.importance < 0.6),
|
||||
}
|
||||
|
||||
|
||||
print("📊 重要性分布:")
|
||||
for level, count in importance_levels.items():
|
||||
print(f" {level}: {count} 条")
|
||||
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
async def check_auto_transfer_task():
|
||||
"""检查自动转移任务状态"""
|
||||
print_section("3. 自动转移任务状态")
|
||||
|
||||
|
||||
manager = get_unified_memory_manager()
|
||||
|
||||
|
||||
# 检查任务是否存在
|
||||
if not hasattr(manager, '_auto_transfer_task') or manager._auto_transfer_task is None:
|
||||
if not hasattr(manager, "_auto_transfer_task") or manager._auto_transfer_task is None:
|
||||
print("❌ 自动转移任务未创建!")
|
||||
print("\n建议:调用 manager.initialize() 初始化系统")
|
||||
return False
|
||||
|
||||
|
||||
task = manager._auto_transfer_task
|
||||
|
||||
|
||||
# 检查任务状态
|
||||
if task.done():
|
||||
print("❌ 自动转移任务已结束!")
|
||||
@@ -121,78 +121,78 @@ async def check_auto_transfer_task():
|
||||
pass
|
||||
print("\n建议:重启系统或手动重启任务")
|
||||
return False
|
||||
|
||||
|
||||
print("✅ 自动转移任务正在运行")
|
||||
|
||||
|
||||
# 检查转移缓存
|
||||
if hasattr(manager, '_transfer_cache'):
|
||||
if hasattr(manager, "_transfer_cache"):
|
||||
cache_size = len(manager._transfer_cache) if manager._transfer_cache else 0
|
||||
print(f"📦 转移缓存: {cache_size} 条记忆")
|
||||
|
||||
|
||||
# 检查上次转移时间
|
||||
if hasattr(manager, '_last_transfer_time'):
|
||||
if hasattr(manager, "_last_transfer_time"):
|
||||
from datetime import datetime
|
||||
last_time = manager._last_transfer_time
|
||||
if last_time:
|
||||
time_diff = (datetime.now() - last_time).total_seconds()
|
||||
print(f"⏱️ 距上次转移: {time_diff:.1f} 秒前")
|
||||
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def check_long_term_status():
|
||||
"""检查长期记忆状态"""
|
||||
print_section("4. 长期记忆图谱状态")
|
||||
|
||||
|
||||
manager = get_unified_memory_manager()
|
||||
long_term = manager.long_term_manager
|
||||
|
||||
|
||||
# 获取图谱统计
|
||||
stats = long_term.get_statistics()
|
||||
|
||||
|
||||
print(f"👥 人物节点数: {stats.get('person_count', 0)}")
|
||||
print(f"📅 事件节点数: {stats.get('event_count', 0)}")
|
||||
print(f"🔗 关系边数: {stats.get('edge_count', 0)}")
|
||||
print(f"💾 向量存储数: {stats.get('vector_count', 0)}")
|
||||
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def manual_transfer_test():
|
||||
"""手动触发转移测试"""
|
||||
print_section("5. 手动转移测试")
|
||||
|
||||
|
||||
manager = get_unified_memory_manager()
|
||||
|
||||
|
||||
# 询问用户是否执行
|
||||
print("⚠️ 即将手动触发一次记忆转移")
|
||||
print("这将把当前符合条件的短期记忆转移到长期记忆")
|
||||
response = input("\n是否继续? (y/n): ").strip().lower()
|
||||
|
||||
if response != 'y':
|
||||
|
||||
if response != "y":
|
||||
print("❌ 已取消手动转移")
|
||||
return None
|
||||
|
||||
|
||||
print("\n🚀 开始手动转移...")
|
||||
|
||||
|
||||
try:
|
||||
# 执行手动转移
|
||||
result = await manager.manual_transfer()
|
||||
|
||||
|
||||
print("\n✅ 转移完成!")
|
||||
print(f"\n转移结果:")
|
||||
print("\n转移结果:")
|
||||
print(f" 已处理: {result.get('processed_count', 0)} 条")
|
||||
print(f" 成功转移: {len(result.get('transferred_memory_ids', []))} 条")
|
||||
print(f" 失败: {result.get('failed_count', 0)} 条")
|
||||
print(f" 跳过: {result.get('skipped_count', 0)} 条")
|
||||
|
||||
if result.get('errors'):
|
||||
print(f"\n错误信息:")
|
||||
for error in result['errors'][:3]: # 只显示前3个错误
|
||||
|
||||
if result.get("errors"):
|
||||
print("\n错误信息:")
|
||||
for error in result["errors"][:3]: # 只显示前3个错误
|
||||
print(f" - {error}")
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 转移失败: {e}")
|
||||
logger.exception("手动转移失败")
|
||||
@@ -202,29 +202,29 @@ async def manual_transfer_test():
|
||||
async def check_configuration():
|
||||
"""检查相关配置"""
|
||||
print_section("6. 配置参数检查")
|
||||
|
||||
|
||||
from src.config.config import global_config
|
||||
|
||||
|
||||
config = global_config.memory
|
||||
|
||||
|
||||
print("📋 当前配置:")
|
||||
print(f" 短期记忆容量: {config.short_term_max_memories}")
|
||||
print(f" 转移重要性阈值: {config.short_term_transfer_threshold}")
|
||||
print(f" 批量转移大小: {config.long_term_batch_size}")
|
||||
print(f" 自动转移间隔: {config.long_term_auto_transfer_interval} 秒")
|
||||
print(f" 启用泄压清理: {config.short_term_enable_force_cleanup}")
|
||||
|
||||
|
||||
# 给出配置建议
|
||||
print("\n💡 配置建议:")
|
||||
|
||||
|
||||
if config.short_term_transfer_threshold > 0.6:
|
||||
print(" ⚠️ 转移阈值较高(>0.6),可能导致记忆难以转移")
|
||||
print(" 建议:降低到 0.4-0.5")
|
||||
|
||||
|
||||
if config.long_term_batch_size > 10:
|
||||
print(" ⚠️ 批量大小较大(>10),可能延迟转移触发")
|
||||
print(" 建议:设置为 5-10")
|
||||
|
||||
|
||||
if config.long_term_auto_transfer_interval > 300:
|
||||
print(" ⚠️ 转移间隔较长(>5分钟),可能导致转移不及时")
|
||||
print(" 建议:设置为 60-180 秒")
|
||||
@@ -235,37 +235,37 @@ async def main():
|
||||
print("\n" + "=" * 60)
|
||||
print(" MoFox-Bot 记忆转移诊断工具")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
try:
|
||||
# 初始化管理器
|
||||
print("\n⚙️ 正在初始化记忆管理器...")
|
||||
manager = get_unified_memory_manager()
|
||||
await manager.initialize()
|
||||
print("✅ 初始化完成\n")
|
||||
|
||||
|
||||
# 执行各项检查
|
||||
await check_short_term_status()
|
||||
candidates = await check_transfer_candidates()
|
||||
task_running = await check_auto_transfer_task()
|
||||
await check_long_term_status()
|
||||
await check_configuration()
|
||||
|
||||
|
||||
# 综合诊断
|
||||
print_section("7. 综合诊断结果")
|
||||
|
||||
|
||||
issues = []
|
||||
|
||||
|
||||
if not candidates:
|
||||
issues.append("❌ 没有符合条件的转移候选")
|
||||
|
||||
|
||||
if not task_running:
|
||||
issues.append("❌ 自动转移任务未运行")
|
||||
|
||||
|
||||
if issues:
|
||||
print("🚨 发现以下问题:\n")
|
||||
for issue in issues:
|
||||
print(f" {issue}")
|
||||
|
||||
|
||||
print("\n建议操作:")
|
||||
print(" 1. 检查短期记忆的重要性评分是否合理")
|
||||
print(" 2. 降低配置中的转移阈值")
|
||||
@@ -273,7 +273,7 @@ async def main():
|
||||
print(" 4. 尝试手动触发转移测试")
|
||||
else:
|
||||
print("✅ 系统运行正常,转移机制已就绪")
|
||||
|
||||
|
||||
if candidates:
|
||||
print(f"\n当前有 {len(candidates)} 条记忆等待转移")
|
||||
print("转移将在满足以下任一条件时自动触发:")
|
||||
@@ -281,20 +281,20 @@ async def main():
|
||||
print(" • 短期记忆占用率超过 50%")
|
||||
print(" • 距上次转移超过最大延迟")
|
||||
print(" • 短期记忆达到容量上限")
|
||||
|
||||
|
||||
# 询问是否手动触发转移
|
||||
if candidates:
|
||||
print()
|
||||
await manual_transfer_test()
|
||||
|
||||
|
||||
print_section("检查完成")
|
||||
print("详细诊断报告: docs/memory_transfer_diagnostic_report.md")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
print(f"\n❌ 检查过程出错: {e}")
|
||||
logger.exception("检查脚本执行失败")
|
||||
return 1
|
||||
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ from pathlib import Path
|
||||
PROJECT_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(PROJECT_ROOT))
|
||||
|
||||
from src.config.config import global_config # noqa: E402
|
||||
from src.memory_graph.short_term_manager import ShortTermMemoryManager # noqa: E402
|
||||
from src.config.config import global_config
|
||||
from src.memory_graph.short_term_manager import ShortTermMemoryManager
|
||||
|
||||
|
||||
def resolve_data_dir() -> Path:
|
||||
|
||||
@@ -12,17 +12,16 @@ from typing import Any, Optional, cast
|
||||
|
||||
import json_repair
|
||||
from PIL import Image
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
|
||||
from src.chat.emoji_system.emoji_constants import EMOJI_DIR, EMOJI_REGISTERED_DIR, MAX_EMOJI_FOR_PROMPT
|
||||
from src.chat.emoji_system.emoji_entities import MaiEmoji
|
||||
from src.chat.emoji_system.emoji_utils import (
|
||||
_emoji_objects_to_readable_list,
|
||||
_to_emoji_objects,
|
||||
_ensure_emoji_dir,
|
||||
clear_temp_emoji,
|
||||
_to_emoji_objects,
|
||||
clean_unused_emojis,
|
||||
clear_temp_emoji,
|
||||
list_image_files,
|
||||
)
|
||||
from src.chat.utils.utils_image import get_image_manager, image_path_to_base64
|
||||
|
||||
@@ -415,20 +415,20 @@ class ExpressionLearner:
|
||||
.offset(offset)
|
||||
)
|
||||
batch_expressions = list(batch_result.scalars())
|
||||
|
||||
|
||||
if not batch_expressions:
|
||||
break # 没有更多数据
|
||||
|
||||
|
||||
# 批量处理当前批次
|
||||
to_delete = []
|
||||
for expr in batch_expressions:
|
||||
# 计算时间差
|
||||
time_diff_days = (current_time - expr.last_active_time) / (24 * 3600)
|
||||
|
||||
|
||||
# 计算衰减值
|
||||
decay_value = self.calculate_decay_factor(time_diff_days)
|
||||
new_count = max(0.01, expr.count - decay_value)
|
||||
|
||||
|
||||
if new_count <= 0.01:
|
||||
# 标记删除
|
||||
to_delete.append(expr)
|
||||
@@ -436,22 +436,22 @@ class ExpressionLearner:
|
||||
# 更新count
|
||||
expr.count = new_count
|
||||
updated_count += 1
|
||||
|
||||
|
||||
# 批量删除
|
||||
if to_delete:
|
||||
for expr in to_delete:
|
||||
await session.delete(expr)
|
||||
deleted_count += len(to_delete)
|
||||
|
||||
|
||||
# 提交当前批次
|
||||
await session.commit()
|
||||
|
||||
|
||||
# 如果批次不满,说明已经处理完所有数据
|
||||
if len(batch_expressions) < BATCH_SIZE:
|
||||
break
|
||||
|
||||
|
||||
offset += BATCH_SIZE
|
||||
|
||||
|
||||
if updated_count > 0 or deleted_count > 0:
|
||||
logger.info(f"全局衰减完成:更新了 {updated_count} 个表达方式,删除了 {deleted_count} 个表达方式")
|
||||
|
||||
@@ -544,12 +544,12 @@ class ExpressionLearner:
|
||||
)
|
||||
)
|
||||
existing_exprs = list(existing_exprs_result.scalars())
|
||||
|
||||
|
||||
# 构建快速查找索引
|
||||
exact_match_map = {} # (situation, style) -> Expression
|
||||
situation_map = {} # situation -> Expression
|
||||
style_map = {} # style -> Expression
|
||||
|
||||
|
||||
for expr in existing_exprs:
|
||||
key = (expr.situation, expr.style)
|
||||
exact_match_map[key] = expr
|
||||
@@ -558,13 +558,13 @@ class ExpressionLearner:
|
||||
situation_map[expr.situation] = expr
|
||||
if expr.style not in style_map:
|
||||
style_map[expr.style] = expr
|
||||
|
||||
|
||||
# 批量处理所有新表达方式
|
||||
for new_expr in expr_list:
|
||||
situation = new_expr["situation"]
|
||||
style_val = new_expr["style"]
|
||||
exact_key = (situation, style_val)
|
||||
|
||||
|
||||
# 优先处理完全匹配的情况
|
||||
if exact_key in exact_match_map:
|
||||
# 完全相同:增加count,更新时间
|
||||
@@ -578,8 +578,7 @@ class ExpressionLearner:
|
||||
logger.info(f"相同情景覆盖:'{same_situation_expr.situation}' 的表达从 '{same_situation_expr.style}' 更新为 '{style_val}'")
|
||||
# 更新映射
|
||||
old_key = (same_situation_expr.situation, same_situation_expr.style)
|
||||
if old_key in exact_match_map:
|
||||
del exact_match_map[old_key]
|
||||
exact_match_map.pop(old_key, None)
|
||||
same_situation_expr.style = style_val
|
||||
same_situation_expr.count = same_situation_expr.count + 1
|
||||
same_situation_expr.last_active_time = current_time
|
||||
@@ -591,8 +590,7 @@ class ExpressionLearner:
|
||||
logger.info(f"相同表达覆盖:'{same_style_expr.style}' 的情景从 '{same_style_expr.situation}' 更新为 '{situation}'")
|
||||
# 更新映射
|
||||
old_key = (same_style_expr.situation, same_style_expr.style)
|
||||
if old_key in exact_match_map:
|
||||
del exact_match_map[old_key]
|
||||
exact_match_map.pop(old_key, None)
|
||||
same_style_expr.situation = situation
|
||||
same_style_expr.count = same_style_expr.count + 1
|
||||
same_style_expr.last_active_time = current_time
|
||||
@@ -627,8 +625,7 @@ class ExpressionLearner:
|
||||
await session.delete(expr)
|
||||
# 从映射中移除
|
||||
key = (expr.situation, expr.style)
|
||||
if key in exact_match_map:
|
||||
del exact_match_map[key]
|
||||
exact_match_map.pop(key, None)
|
||||
logger.debug(f"已删除 {len(all_current_exprs) - MAX_EXPRESSION_COUNT} 个低频表达方式")
|
||||
|
||||
# 提交数据库更改
|
||||
@@ -658,31 +655,31 @@ class ExpressionLearner:
|
||||
# 为每个共享组内的 chat_id 训练其 StyleLearner
|
||||
for target_chat_id in related_chat_ids:
|
||||
learner = style_learner_manager.get_learner(target_chat_id)
|
||||
|
||||
|
||||
# 收集该 target_chat_id 对应的所有表达方式
|
||||
# 如果是源 chat_id,使用 chat_dict 中的数据;否则也要训练(共享组特性)
|
||||
total_success = 0
|
||||
total_samples = 0
|
||||
|
||||
|
||||
for source_chat_id, expr_list in chat_dict.items():
|
||||
# 为每个学习到的表达方式训练模型
|
||||
# 使用 situation 作为输入,style 作为目标
|
||||
for expr in expr_list:
|
||||
situation = expr["situation"]
|
||||
style = expr["style"]
|
||||
|
||||
|
||||
# 训练映射关系: situation -> style
|
||||
if learner.learn_mapping(situation, style):
|
||||
total_success += 1
|
||||
total_samples += 1
|
||||
|
||||
|
||||
# 保存模型
|
||||
if total_samples > 0:
|
||||
if learner.save(style_learner_manager.model_save_path):
|
||||
logger.debug(f"StyleLearner 模型保存成功: {target_chat_id}")
|
||||
else:
|
||||
logger.error(f"StyleLearner 模型保存失败: {target_chat_id}")
|
||||
|
||||
|
||||
if target_chat_id == self.chat_id:
|
||||
# 只为当前 chat_id 记录详细日志
|
||||
logger.info(
|
||||
|
||||
@@ -218,7 +218,7 @@ class ExpressionSelector:
|
||||
"type": expr_type,
|
||||
"create_date": expr.create_date if expr.create_date is not None else expr.last_active_time,
|
||||
}
|
||||
|
||||
|
||||
style_exprs = [expr_to_dict(expr, "style") for expr in style_query.scalars()]
|
||||
grammar_exprs = [expr_to_dict(expr, "grammar") for expr in grammar_query.scalars()]
|
||||
|
||||
@@ -246,7 +246,7 @@ class ExpressionSelector:
|
||||
"""
|
||||
if not expressions_to_update:
|
||||
return
|
||||
|
||||
|
||||
# 去重处理
|
||||
updates_by_key = {}
|
||||
affected_chat_ids = set()
|
||||
@@ -524,7 +524,7 @@ class ExpressionSelector:
|
||||
|
||||
# 预处理:提前计算所有预测 style 的小写版本,避免重复计算
|
||||
predicted_styles_lower = [(s.lower(), score) for s, score in predicted_styles[:20]]
|
||||
|
||||
|
||||
matched_expressions = []
|
||||
for expr in all_expressions:
|
||||
db_style = expr.style or ""
|
||||
@@ -539,7 +539,7 @@ class ExpressionSelector:
|
||||
max_similarity = 1.0
|
||||
best_predicted = predicted_style_lower
|
||||
break
|
||||
|
||||
|
||||
# 快速检查:子串匹配
|
||||
if len(predicted_style_lower) >= 2 and len(db_style_lower) >= 2:
|
||||
if predicted_style_lower in db_style_lower or db_style_lower in predicted_style_lower:
|
||||
@@ -548,7 +548,7 @@ class ExpressionSelector:
|
||||
max_similarity = similarity
|
||||
best_predicted = predicted_style_lower
|
||||
continue
|
||||
|
||||
|
||||
# 计算字符串相似度(较慢,只在必要时使用)
|
||||
similarity = SequenceMatcher(None, predicted_style_lower, db_style_lower).ratio()
|
||||
if similarity > max_similarity:
|
||||
|
||||
@@ -38,7 +38,7 @@ class InterestManager:
|
||||
self._calculation_queue = asyncio.Queue()
|
||||
self._worker_task = None
|
||||
self._shutdown_event = asyncio.Event()
|
||||
|
||||
|
||||
# 性能优化相关字段
|
||||
self._result_cache: OrderedDict[str, InterestCalculationResult] = OrderedDict() # LRU缓存
|
||||
self._cache_max_size = 1000 # 最大缓存数量
|
||||
@@ -48,13 +48,13 @@ class InterestManager:
|
||||
self._batch_timeout = 0.1 # 批处理超时(秒)
|
||||
self._batch_task = None
|
||||
self._is_warmed_up = False # 预热状态标记
|
||||
|
||||
|
||||
# 性能统计
|
||||
self._cache_hits = 0
|
||||
self._cache_misses = 0
|
||||
self._batch_calculations = 0
|
||||
self._total_calculation_time = 0.0
|
||||
|
||||
|
||||
self._initialized = True
|
||||
|
||||
async def initialize(self):
|
||||
@@ -67,7 +67,7 @@ class InterestManager:
|
||||
async def shutdown(self):
|
||||
"""关闭管理器"""
|
||||
self._shutdown_event.set()
|
||||
|
||||
|
||||
# 取消批处理任务
|
||||
if self._batch_task and not self._batch_task.done():
|
||||
self._batch_task.cancel()
|
||||
@@ -79,7 +79,7 @@ class InterestManager:
|
||||
if self._current_calculator:
|
||||
await self._current_calculator.cleanup()
|
||||
self._current_calculator = None
|
||||
|
||||
|
||||
# 清理缓存
|
||||
self._result_cache.clear()
|
||||
|
||||
@@ -142,9 +142,9 @@ class InterestManager:
|
||||
interest_value=0.3,
|
||||
error_message="没有可用的兴趣值计算组件",
|
||||
)
|
||||
|
||||
|
||||
message_id = getattr(message, "message_id", "")
|
||||
|
||||
|
||||
# 缓存查询
|
||||
if use_cache and message_id:
|
||||
cached_result = self._get_from_cache(message_id)
|
||||
@@ -183,11 +183,11 @@ class InterestManager:
|
||||
interest_value=0.3,
|
||||
error_message=f"计算异常: {e!s}",
|
||||
)
|
||||
|
||||
|
||||
# 缓存结果
|
||||
if use_cache and result.success and message_id:
|
||||
self._put_to_cache(message_id, result)
|
||||
|
||||
|
||||
return result
|
||||
|
||||
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
|
||||
@@ -249,36 +249,36 @@ class InterestManager:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"计算工作线程异常: {e}")
|
||||
|
||||
|
||||
def _get_from_cache(self, message_id: str) -> InterestCalculationResult | None:
|
||||
"""从缓存中获取结果(LRU策略)"""
|
||||
if message_id not in self._result_cache:
|
||||
return None
|
||||
|
||||
|
||||
# 检查TTL
|
||||
result = self._result_cache[message_id]
|
||||
if time.time() - result.timestamp > self._cache_ttl:
|
||||
# 过期,删除
|
||||
del self._result_cache[message_id]
|
||||
return None
|
||||
|
||||
|
||||
# 更新访问顺序(LRU)
|
||||
self._result_cache.move_to_end(message_id)
|
||||
return result
|
||||
|
||||
|
||||
def _put_to_cache(self, message_id: str, result: InterestCalculationResult):
|
||||
"""将结果放入缓存(LRU策略)"""
|
||||
# 如果已存在,更新
|
||||
if message_id in self._result_cache:
|
||||
self._result_cache.move_to_end(message_id)
|
||||
|
||||
|
||||
self._result_cache[message_id] = result
|
||||
|
||||
|
||||
# 限制缓存大小
|
||||
while len(self._result_cache) > self._cache_max_size:
|
||||
# 删除最旧的项
|
||||
self._result_cache.popitem(last=False)
|
||||
|
||||
|
||||
async def calculate_interest_batch(self, messages: list["DatabaseMessages"], timeout: float | None = None) -> list[InterestCalculationResult]:
|
||||
"""批量计算消息兴趣值(并发优化)
|
||||
|
||||
@@ -291,11 +291,11 @@ class InterestManager:
|
||||
"""
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
|
||||
# 并发计算所有消息
|
||||
tasks = [self.calculate_interest(msg, timeout=timeout) for msg in messages]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
|
||||
# 处理异常
|
||||
final_results = []
|
||||
for i, result in enumerate(results):
|
||||
@@ -309,44 +309,44 @@ class InterestManager:
|
||||
))
|
||||
else:
|
||||
final_results.append(result)
|
||||
|
||||
|
||||
self._batch_calculations += 1
|
||||
return final_results
|
||||
|
||||
|
||||
async def _batch_processing_worker(self):
|
||||
"""批处理工作线程"""
|
||||
while not self._shutdown_event.is_set():
|
||||
batch = []
|
||||
deadline = time.time() + self._batch_timeout
|
||||
|
||||
|
||||
try:
|
||||
# 收集批次
|
||||
while len(batch) < self._batch_size and time.time() < deadline:
|
||||
remaining_time = deadline - time.time()
|
||||
if remaining_time <= 0:
|
||||
break
|
||||
|
||||
|
||||
try:
|
||||
item = await asyncio.wait_for(self._batch_queue.get(), timeout=remaining_time)
|
||||
batch.append(item)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
|
||||
|
||||
# 处理批次
|
||||
if batch:
|
||||
await self._process_batch(batch)
|
||||
|
||||
|
||||
except asyncio.CancelledError:
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"批处理工作线程异常: {e}")
|
||||
|
||||
|
||||
async def _process_batch(self, batch: list):
|
||||
"""处理批次消息"""
|
||||
# 这里可以实现具体的批处理逻辑
|
||||
# 当前版本只是占位,实际的批处理逻辑可以根据具体需求实现
|
||||
pass
|
||||
|
||||
|
||||
async def warmup(self, sample_messages: list["DatabaseMessages"] | None = None):
|
||||
"""预热兴趣计算器
|
||||
|
||||
@@ -356,10 +356,10 @@ class InterestManager:
|
||||
if not self._current_calculator:
|
||||
logger.warning("无法预热:没有可用的兴趣值计算组件")
|
||||
return
|
||||
|
||||
|
||||
logger.info("开始预热兴趣值计算器...")
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
# 如果提供了样本消息,进行预热计算
|
||||
if sample_messages:
|
||||
try:
|
||||
@@ -370,15 +370,15 @@ class InterestManager:
|
||||
logger.error(f"预热过程中出现异常: {e}")
|
||||
else:
|
||||
logger.info(f"预热完成:计算器已就绪,耗时 {time.time() - start_time:.2f}s")
|
||||
|
||||
|
||||
self._is_warmed_up = True
|
||||
|
||||
|
||||
def clear_cache(self):
|
||||
"""清空缓存"""
|
||||
cleared_count = len(self._result_cache)
|
||||
self._result_cache.clear()
|
||||
logger.info(f"已清空 {cleared_count} 条缓存记录")
|
||||
|
||||
|
||||
def set_cache_config(self, max_size: int | None = None, ttl: int | None = None):
|
||||
"""设置缓存配置
|
||||
|
||||
@@ -389,11 +389,11 @@ class InterestManager:
|
||||
if max_size is not None:
|
||||
self._cache_max_size = max_size
|
||||
logger.info(f"缓存最大容量设置为: {max_size}")
|
||||
|
||||
|
||||
if ttl is not None:
|
||||
self._cache_ttl = ttl
|
||||
logger.info(f"缓存TTL设置为: {ttl}秒")
|
||||
|
||||
|
||||
# 如果当前缓存超过新的最大值,清理旧数据
|
||||
if max_size is not None:
|
||||
while len(self._result_cache) > self._cache_max_size:
|
||||
@@ -446,14 +446,14 @@ class InterestManager:
|
||||
def has_calculator(self) -> bool:
|
||||
"""检查是否有可用的计算组件"""
|
||||
return self._current_calculator is not None and self._current_calculator.is_enabled
|
||||
|
||||
|
||||
async def adaptive_optimize(self):
|
||||
"""自适应优化:根据性能统计自动调整参数"""
|
||||
if not self._current_calculator:
|
||||
return
|
||||
|
||||
|
||||
stats = self.get_statistics()["manager_statistics"]
|
||||
|
||||
|
||||
# 根据缓存命中率调整缓存大小
|
||||
cache_hit_rate = stats["cache_hit_rate"]
|
||||
if cache_hit_rate < 0.5 and self._cache_max_size < 5000:
|
||||
@@ -469,7 +469,7 @@ class InterestManager:
|
||||
# 清理多余缓存
|
||||
while len(self._result_cache) > self._cache_max_size:
|
||||
self._result_cache.popitem(last=False)
|
||||
|
||||
|
||||
# 根据平均计算时间调整批处理参数
|
||||
avg_calc_time = stats["average_calculation_time"]
|
||||
if avg_calc_time > 0.5 and self._batch_size < 50:
|
||||
@@ -482,11 +482,11 @@ class InterestManager:
|
||||
new_batch_size = max(self._batch_size // 2, 5)
|
||||
logger.info(f"自适应优化:平均计算时间较短 ({avg_calc_time:.3f}s),减小批次大小 {self._batch_size} -> {new_batch_size}")
|
||||
self._batch_size = new_batch_size
|
||||
|
||||
|
||||
def get_performance_report(self) -> str:
|
||||
"""生成性能报告"""
|
||||
stats = self.get_statistics()["manager_statistics"]
|
||||
|
||||
|
||||
report = [
|
||||
"=" * 60,
|
||||
"兴趣值管理器性能报告",
|
||||
@@ -504,7 +504,7 @@ class InterestManager:
|
||||
f"当前计算器: {stats['current_calculator'] or '无'}",
|
||||
"=" * 60,
|
||||
]
|
||||
|
||||
|
||||
# 添加计算器统计
|
||||
if self._current_calculator:
|
||||
calc_stats = self.get_statistics()["calculator_statistics"]
|
||||
@@ -520,7 +520,7 @@ class InterestManager:
|
||||
f" 平均耗时: {calc_stats['average_calculation_time']:.4f}s",
|
||||
"=" * 60,
|
||||
])
|
||||
|
||||
|
||||
return "\n".join(report)
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ logger = get_logger("message_manager")
|
||||
class MessageManager:
|
||||
"""消息管理器"""
|
||||
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: asyncio.Task | None = None
|
||||
|
||||
@@ -348,12 +348,12 @@ class StatisticOutputTask(AsyncTask):
|
||||
prompt_tokens = int(record.get("prompt_tokens") or 0)
|
||||
except (ValueError, TypeError):
|
||||
prompt_tokens = 0
|
||||
|
||||
|
||||
try:
|
||||
completion_tokens = int(record.get("completion_tokens") or 0)
|
||||
except (ValueError, TypeError):
|
||||
completion_tokens = 0
|
||||
|
||||
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
|
||||
stats[period_key][IN_TOK_BY_TYPE][request_type] += prompt_tokens
|
||||
@@ -378,7 +378,7 @@ class StatisticOutputTask(AsyncTask):
|
||||
cost = float(cost) if cost else 0.0
|
||||
except (ValueError, TypeError):
|
||||
cost = 0.0
|
||||
|
||||
|
||||
stats[period_key][TOTAL_COST] += cost
|
||||
stats[period_key][COST_BY_TYPE][request_type] += cost
|
||||
stats[period_key][COST_BY_USER][user_id] += cost
|
||||
|
||||
@@ -969,7 +969,7 @@ class LongTermMemoryManager:
|
||||
content=f"临时节点 - {source_id}",
|
||||
metadata={"placeholder": True, "created_by": "long_term_manager_edge_creation"}
|
||||
)
|
||||
|
||||
|
||||
if not self.memory_manager.graph_store.graph.has_node(target_id):
|
||||
logger.debug(f"目标节点不存在,创建占位符节点: {target_id}")
|
||||
self.memory_manager.graph_store.add_node(
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# ruff: noqa: G004, BLE001
|
||||
# pylint: disable=logging-fstring-interpolation,broad-except,unused-argument
|
||||
# pyright: reportOptionalMemberAccess=false
|
||||
"""
|
||||
|
||||
@@ -658,7 +658,7 @@ class ShortTermMemoryManager:
|
||||
return self._get_transfer_all_strategy()
|
||||
else: # "selective_cleanup" 或其他值默认使用选择性清理
|
||||
return self._get_selective_cleanup_strategy()
|
||||
|
||||
|
||||
def _get_transfer_all_strategy(self) -> list[ShortTermMemory]:
|
||||
"""
|
||||
"一次性转移所有"策略:当短期记忆满了以后,将所有记忆转移到长期记忆
|
||||
@@ -673,24 +673,24 @@ class ShortTermMemoryManager:
|
||||
f"将转移所有 {len(self.memories)} 条记忆到长期记忆"
|
||||
)
|
||||
return self.memories.copy()
|
||||
|
||||
|
||||
# 如果还没满,检查是否有高重要性记忆需要转移
|
||||
high_importance_memories = [
|
||||
mem for mem in self.memories
|
||||
mem for mem in self.memories
|
||||
if mem.importance >= self.transfer_importance_threshold
|
||||
]
|
||||
|
||||
|
||||
if high_importance_memories:
|
||||
logger.debug(
|
||||
f"转移策略(transfer_all): 发现 {len(high_importance_memories)} 条高重要性记忆待转移"
|
||||
)
|
||||
return high_importance_memories
|
||||
|
||||
|
||||
logger.debug(
|
||||
f"转移策略(transfer_all): 无需转移 (当前容量 {len(self.memories)}/{self.max_memories})"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _get_selective_cleanup_strategy(self) -> list[ShortTermMemory]:
|
||||
"""
|
||||
"选择性清理"策略(原有策略):优先转移重要记忆,低重要性记忆考虑直接删除
|
||||
@@ -720,11 +720,11 @@ class ShortTermMemoryManager:
|
||||
if len(self.memories) > self.max_memories:
|
||||
# 计算需要转移的数量(目标:降到上限)
|
||||
num_to_transfer = len(self.memories) - self.max_memories
|
||||
|
||||
|
||||
# 按创建时间排序低重要性记忆,优先转移最早的(可能包含过时信息)
|
||||
low_importance_memories.sort(key=lambda x: x.created_at)
|
||||
to_transfer = low_importance_memories[:num_to_transfer]
|
||||
|
||||
|
||||
if to_transfer:
|
||||
logger.debug(
|
||||
f"转移策略(selective): 发现 {len(to_transfer)} 条低重要性记忆待转移 "
|
||||
@@ -757,7 +757,7 @@ class ShortTermMemoryManager:
|
||||
# 使用实例配置或传入参数
|
||||
if keep_ratio is None:
|
||||
keep_ratio = self.cleanup_keep_ratio
|
||||
|
||||
|
||||
current = len(self.memories)
|
||||
limit = int(self.max_memories * keep_ratio)
|
||||
if current <= self.max_memories:
|
||||
@@ -804,28 +804,28 @@ class ShortTermMemoryManager:
|
||||
self._similarity_cache.pop(mem_id, None)
|
||||
|
||||
logger.info(f"清除 {len(memory_ids)} 条已转移的短期记忆")
|
||||
|
||||
|
||||
# 在 "transfer_all" 策略下,进一步删除不重要的短期记忆
|
||||
if self.overflow_strategy == "transfer_all":
|
||||
# 计算需要删除的低重要性记忆数量
|
||||
low_importance_memories = [
|
||||
mem for mem in self.memories
|
||||
mem for mem in self.memories
|
||||
if mem.importance < self.transfer_importance_threshold
|
||||
]
|
||||
|
||||
|
||||
if low_importance_memories:
|
||||
# 按重要性和创建时间排序,删除最不重要的
|
||||
low_importance_memories.sort(key=lambda m: (m.importance, m.created_at))
|
||||
|
||||
|
||||
# 删除所有低重要性记忆
|
||||
to_delete = {mem.id for mem in low_importance_memories}
|
||||
self.memories = [mem for mem in self.memories if mem.id not in to_delete]
|
||||
|
||||
|
||||
# 更新索引
|
||||
for mem_id in to_delete:
|
||||
self._memory_id_index.pop(mem_id, None)
|
||||
self._similarity_cache.pop(mem_id, None)
|
||||
|
||||
|
||||
logger.info(
|
||||
f"transfer_all 策略: 额外删除了 {len(to_delete)} 条低重要性记忆 "
|
||||
f"(重要性 < {self.transfer_importance_threshold:.2f})"
|
||||
|
||||
@@ -936,7 +936,7 @@ class GraphStore:
|
||||
edge_type_enum = EdgeType.RELATION
|
||||
else:
|
||||
edge_type_enum = edge_type_value
|
||||
|
||||
|
||||
mem_edge = MemoryEdge(
|
||||
id=edge_dict["id"] or "",
|
||||
source_id=edge_dict["source_id"],
|
||||
|
||||
@@ -124,7 +124,7 @@ class BaseInterestCalculator(ABC):
|
||||
logger.error(f"初始化兴趣计算器失败: {e}")
|
||||
self._enabled = False
|
||||
return False
|
||||
|
||||
|
||||
async def on_initialize(self):
|
||||
"""子类可重写的初始化钩子"""
|
||||
pass
|
||||
@@ -143,7 +143,7 @@ class BaseInterestCalculator(ABC):
|
||||
except Exception as e:
|
||||
logger.error(f"清理兴趣计算器失败: {e}")
|
||||
return False
|
||||
|
||||
|
||||
async def on_cleanup(self):
|
||||
"""子类可重写的清理钩子"""
|
||||
pass
|
||||
|
||||
@@ -3,7 +3,6 @@ MaiZone(麦麦空间)- 重构版
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.plugin_system import BasePlugin, ComponentInfo, register_plugin
|
||||
|
||||
@@ -151,7 +151,7 @@ class ContentService:
|
||||
bot_personality_side = config_api.get_global_config("personality.personality_side", "")
|
||||
bot_reply_style = config_api.get_global_config("personality.reply_style", "内容积极向上")
|
||||
qq_account = config_api.get_global_config("bot.qq_account", "")
|
||||
|
||||
|
||||
# 获取角色外貌描述(用于告知LLM)
|
||||
character_prompt = self.get_config("novelai.character_prompt", "")
|
||||
|
||||
@@ -163,21 +163,21 @@ class ContentService:
|
||||
|
||||
# 构建提示词
|
||||
prompt_topic = f"主题是'{topic}'" if topic else "主题不限"
|
||||
|
||||
|
||||
# 构建人设描述
|
||||
personality_desc = f"你的核心人格:{bot_personality_core}"
|
||||
if bot_personality_side:
|
||||
personality_desc += f"\n你的人格侧面:{bot_personality_side}"
|
||||
personality_desc += f"\n\n你的表达方式:{bot_reply_style}"
|
||||
|
||||
|
||||
# 检查是否启用AI配图(统一开关)
|
||||
ai_image_enabled = self.get_config("ai_image.enable_ai_image", False)
|
||||
provider = self.get_config("ai_image.provider", "siliconflow")
|
||||
|
||||
|
||||
# NovelAI配图指引(内置)
|
||||
novelai_guide = ""
|
||||
output_format = '{"text": "说说正文内容"}'
|
||||
|
||||
|
||||
if ai_image_enabled and provider == "novelai":
|
||||
# 构建角色信息提示
|
||||
character_info = ""
|
||||
@@ -195,7 +195,7 @@ class ContentService:
|
||||
- 例如:可以搭配各种表情(smile, laugh, serious, thinking, surprised等)
|
||||
- **鼓励创意**:根据说说内容自由发挥,让画面更丰富生动!
|
||||
"""
|
||||
|
||||
|
||||
novelai_guide = f"""
|
||||
**配图说明:**
|
||||
这条说说会使用NovelAI Diffusion模型(二次元风格)生成配图。
|
||||
@@ -258,7 +258,7 @@ class ContentService:
|
||||
- 运动风:"masterpiece, best quality, 1girl, sportswear, running in park, energetic, morning light, trees background, dynamic pose, healthy lifestyle"
|
||||
- 咖啡馆:"masterpiece, best quality, 1girl, sitting in cozy cafe, holding coffee cup, warm lighting, wooden table, books beside, peaceful atmosphere"
|
||||
"""
|
||||
output_format = '''{"text": "说说正文内容", "image": {"prompt": "详细的英文提示词(包含画质+主体+场景+氛围+光线+色彩)", "negative_prompt": "负面词", "include_character": true/false, "aspect_ratio": "方图/横图/竖图"}}'''
|
||||
output_format = """{"text": "说说正文内容", "image": {"prompt": "详细的英文提示词(包含画质+主体+场景+氛围+光线+色彩)", "negative_prompt": "负面词", "include_character": true/false, "aspect_ratio": "方图/横图/竖图"}}"""
|
||||
elif ai_image_enabled and provider == "siliconflow":
|
||||
novelai_guide = """
|
||||
**配图说明:**
|
||||
@@ -277,8 +277,8 @@ class ContentService:
|
||||
- "sunset over the calm ocean, golden hour, orange and purple sky, gentle waves, peaceful and serene mood, wide angle view"
|
||||
- "cherry blossoms in spring, soft pink petals falling, blue sky, sunlight filtering through branches, peaceful park scene, gentle breeze"
|
||||
"""
|
||||
output_format = '''{"text": "说说正文内容", "image": {"prompt": "详细的英文描述(主体+场景+氛围+光线+细节)"}}'''
|
||||
|
||||
output_format = """{"text": "说说正文内容", "image": {"prompt": "详细的英文描述(主体+场景+氛围+光线+细节)"}}"""
|
||||
|
||||
prompt = f"""
|
||||
{personality_desc}
|
||||
|
||||
@@ -333,20 +333,20 @@ class ContentService:
|
||||
if json_text.endswith("```"):
|
||||
json_text = json_text[:-3]
|
||||
json_text = json_text.strip()
|
||||
|
||||
|
||||
data = json5.loads(json_text)
|
||||
story_text = data.get("text", "")
|
||||
image_info = data.get("image", {})
|
||||
|
||||
|
||||
# 确保图片信息完整
|
||||
if not isinstance(image_info, dict):
|
||||
image_info = {}
|
||||
|
||||
|
||||
logger.info(f"成功生成说说:'{story_text}'")
|
||||
logger.info(f"配图信息: {image_info}")
|
||||
|
||||
|
||||
return story_text, image_info
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解析JSON失败: {e}, 原始响应: {response[:200]}")
|
||||
# 降级处理:只返回文本,空配图信息
|
||||
|
||||
@@ -42,7 +42,7 @@ class ImageService:
|
||||
try:
|
||||
api_key = str(self.get_config("siliconflow.api_key", ""))
|
||||
image_num = self.get_config("ai_image.image_number", 1)
|
||||
|
||||
|
||||
if not api_key:
|
||||
logger.warning("硅基流动API未配置,跳过图片生成")
|
||||
return False, None
|
||||
@@ -237,7 +237,7 @@ class ImageService:
|
||||
image.save(save_path, format="PNG")
|
||||
logger.info(f"图片已保存至: {save_path}")
|
||||
success_count += 1
|
||||
|
||||
|
||||
# 记录第一张图片路径
|
||||
if first_img_path is None:
|
||||
first_img_path = save_path
|
||||
|
||||
@@ -2,14 +2,11 @@
|
||||
NovelAI图片生成服务 - 空间插件专用
|
||||
独立实现,不依赖其他插件
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
import io
|
||||
import random
|
||||
import uuid
|
||||
import zipfile
|
||||
import io
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp
|
||||
from PIL import Image
|
||||
@@ -21,50 +18,50 @@ logger = get_logger("MaiZone.NovelAIService")
|
||||
|
||||
class MaiZoneNovelAIService:
|
||||
"""空间插件的NovelAI图片生成服务(独立实现)"""
|
||||
|
||||
|
||||
def __init__(self, get_config):
|
||||
self.get_config = get_config
|
||||
|
||||
|
||||
# NovelAI配置
|
||||
self.api_key = self.get_config("novelai.api_key", "")
|
||||
self.base_url = "https://image.novelai.net/ai/generate-image"
|
||||
self.model = "nai-diffusion-4-5-full"
|
||||
|
||||
|
||||
# 代理配置
|
||||
proxy_host = self.get_config("novelai.proxy_host", "")
|
||||
proxy_port = self.get_config("novelai.proxy_port", 0)
|
||||
self.proxy = f"http://{proxy_host}:{proxy_port}" if proxy_host and proxy_port else ""
|
||||
|
||||
|
||||
# 生成参数
|
||||
self.steps = 28
|
||||
self.scale = 5.0
|
||||
self.sampler = "k_euler"
|
||||
self.noise_schedule = "karras"
|
||||
|
||||
|
||||
# 角色提示词(当LLM决定包含角色时使用)
|
||||
self.character_prompt = self.get_config("novelai.character_prompt", "")
|
||||
self.base_negative_prompt = self.get_config("novelai.base_negative_prompt", "nsfw, nude, explicit, sexual content, lowres, bad anatomy, bad hands")
|
||||
|
||||
|
||||
# 图片保存目录(使用统一配置)
|
||||
plugin_dir = Path(__file__).parent.parent
|
||||
self.image_dir = plugin_dir / "images"
|
||||
self.image_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
if self.api_key:
|
||||
logger.info(f"NovelAI图片生成已配置,模型: {self.model}")
|
||||
|
||||
|
||||
def is_available(self) -> bool:
|
||||
"""检查NovelAI服务是否可用"""
|
||||
return bool(self.api_key)
|
||||
|
||||
|
||||
async def generate_image_from_prompt_data(
|
||||
self,
|
||||
prompt: str,
|
||||
negative_prompt: Optional[str] = None,
|
||||
negative_prompt: str | None = None,
|
||||
include_character: bool = False,
|
||||
width: int = 1024,
|
||||
height: int = 1024
|
||||
) -> tuple[bool, Optional[Path], str]:
|
||||
) -> tuple[bool, Path | None, str]:
|
||||
"""根据提示词生成图片
|
||||
|
||||
Args:
|
||||
@@ -79,14 +76,14 @@ class MaiZoneNovelAIService:
|
||||
"""
|
||||
if not self.api_key:
|
||||
return False, None, "NovelAI API Key未配置"
|
||||
|
||||
|
||||
try:
|
||||
# 处理角色提示词
|
||||
final_prompt = prompt
|
||||
if include_character and self.character_prompt:
|
||||
final_prompt = f"{self.character_prompt}, {prompt}"
|
||||
logger.info(f"包含角色形象,添加角色提示词")
|
||||
|
||||
logger.info("包含角色形象,添加角色提示词")
|
||||
|
||||
# 合并负面提示词
|
||||
final_negative = self.base_negative_prompt
|
||||
if negative_prompt:
|
||||
@@ -94,37 +91,37 @@ class MaiZoneNovelAIService:
|
||||
final_negative = f"{final_negative}, {negative_prompt}"
|
||||
else:
|
||||
final_negative = negative_prompt
|
||||
|
||||
logger.info(f"🎨 开始生成图片...")
|
||||
|
||||
logger.info("🎨 开始生成图片...")
|
||||
logger.info(f" 尺寸: {width}x{height}")
|
||||
logger.info(f" 正面提示词: {final_prompt[:100]}...")
|
||||
logger.info(f" 负面提示词: {final_negative[:100]}...")
|
||||
|
||||
|
||||
# 构建请求payload
|
||||
payload = self._build_payload(final_prompt, final_negative, width, height)
|
||||
|
||||
|
||||
# 发送请求
|
||||
image_data = await self._call_novelai_api(payload)
|
||||
if not image_data:
|
||||
return False, None, "API请求失败"
|
||||
|
||||
|
||||
# 保存图片
|
||||
image_path = await self._save_image(image_data)
|
||||
if not image_path:
|
||||
return False, None, "图片保存失败"
|
||||
|
||||
|
||||
logger.info(f"✅ 图片生成成功: {image_path}")
|
||||
return True, image_path, "生成成功"
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"生成图片时出错: {e}", exc_info=True)
|
||||
return False, None, f"生成失败: {str(e)}"
|
||||
|
||||
return False, None, f"生成失败: {e!s}"
|
||||
|
||||
def _build_payload(self, prompt: str, negative_prompt: str, width: int, height: int) -> dict:
|
||||
"""构建NovelAI API请求payload"""
|
||||
is_v4_model = "diffusion-4" in self.model
|
||||
is_v3_model = "diffusion-3" in self.model
|
||||
|
||||
|
||||
parameters = {
|
||||
"width": width,
|
||||
"height": height,
|
||||
@@ -139,7 +136,7 @@ class MaiZoneNovelAIService:
|
||||
"sm_dyn": False,
|
||||
"noise_schedule": self.noise_schedule if is_v4_model else "native",
|
||||
}
|
||||
|
||||
|
||||
# V4.5模型使用新格式
|
||||
if is_v4_model:
|
||||
parameters.update({
|
||||
@@ -183,39 +180,39 @@ class MaiZoneNovelAIService:
|
||||
# V3使用negative_prompt字段
|
||||
elif is_v3_model:
|
||||
parameters["negative_prompt"] = negative_prompt
|
||||
|
||||
|
||||
payload = {
|
||||
"input": prompt,
|
||||
"model": self.model,
|
||||
"action": "generate",
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
|
||||
# V4.5需要额外字段
|
||||
if is_v4_model:
|
||||
payload["use_new_shared_trial"] = True
|
||||
|
||||
|
||||
return payload
|
||||
|
||||
async def _call_novelai_api(self, payload: dict) -> Optional[bytes]:
|
||||
|
||||
async def _call_novelai_api(self, payload: dict) -> bytes | None:
|
||||
"""调用NovelAI API"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
connector = None
|
||||
request_kwargs = {
|
||||
"json": payload,
|
||||
"headers": headers,
|
||||
"timeout": aiohttp.ClientTimeout(total=120)
|
||||
}
|
||||
|
||||
|
||||
if self.proxy:
|
||||
request_kwargs["proxy"] = self.proxy
|
||||
connector = aiohttp.TCPConnector()
|
||||
logger.info(f"使用代理: {self.proxy}")
|
||||
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession(connector=connector) as session:
|
||||
async with session.post(self.base_url, **request_kwargs) as resp:
|
||||
@@ -223,31 +220,31 @@ class MaiZoneNovelAIService:
|
||||
error_text = await resp.text()
|
||||
logger.error(f"API请求失败 ({resp.status}): {error_text[:200]}")
|
||||
return None
|
||||
|
||||
|
||||
img_data = await resp.read()
|
||||
logger.info(f"收到响应数据: {len(img_data)} bytes")
|
||||
|
||||
|
||||
# 检查是否是ZIP文件
|
||||
if img_data[:4] == b'PK\x03\x04':
|
||||
if img_data[:4] == b"PK\x03\x04":
|
||||
logger.info("检测到ZIP格式,解压中...")
|
||||
return self._extract_from_zip(img_data)
|
||||
elif img_data[:4] == b'\x89PNG':
|
||||
elif img_data[:4] == b"\x89PNG":
|
||||
logger.info("检测到PNG格式")
|
||||
return img_data
|
||||
else:
|
||||
logger.warning(f"未知文件格式,前4字节: {img_data[:4].hex()}")
|
||||
return img_data
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"API调用失败: {e}", exc_info=True)
|
||||
return None
|
||||
|
||||
def _extract_from_zip(self, zip_data: bytes) -> Optional[bytes]:
|
||||
|
||||
def _extract_from_zip(self, zip_data: bytes) -> bytes | None:
|
||||
"""从ZIP中提取PNG"""
|
||||
try:
|
||||
with zipfile.ZipFile(io.BytesIO(zip_data)) as zf:
|
||||
for filename in zf.namelist():
|
||||
if filename.lower().endswith('.png'):
|
||||
if filename.lower().endswith(".png"):
|
||||
img_data = zf.read(filename)
|
||||
logger.info(f"从ZIP提取: {filename} ({len(img_data)} bytes)")
|
||||
return img_data
|
||||
@@ -256,20 +253,20 @@ class MaiZoneNovelAIService:
|
||||
except Exception as e:
|
||||
logger.error(f"解压ZIP失败: {e}")
|
||||
return None
|
||||
|
||||
async def _save_image(self, image_data: bytes) -> Optional[Path]:
|
||||
|
||||
async def _save_image(self, image_data: bytes) -> Path | None:
|
||||
"""保存图片到本地"""
|
||||
try:
|
||||
filename = f"novelai_{uuid.uuid4().hex[:12]}.png"
|
||||
filepath = self.image_dir / filename
|
||||
|
||||
|
||||
# 写入文件
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image_data)
|
||||
f.flush()
|
||||
import os
|
||||
os.fsync(f.fileno())
|
||||
|
||||
|
||||
# 验证图片
|
||||
try:
|
||||
with Image.open(filepath) as img:
|
||||
@@ -278,9 +275,9 @@ class MaiZoneNovelAIService:
|
||||
logger.info(f"图片验证成功: {img.format} {img.size}")
|
||||
except Exception as e:
|
||||
logger.warning(f"图片验证失败: {e}")
|
||||
|
||||
|
||||
return filepath
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"保存图片失败: {e}")
|
||||
return None
|
||||
|
||||
@@ -5,7 +5,6 @@ QQ空间服务模块
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -85,25 +84,25 @@ class QZoneService:
|
||||
async def send_feed(self, topic: str, stream_id: str | None) -> dict[str, Any]:
|
||||
"""发送一条说说(支持AI配图)"""
|
||||
cross_context = await self._get_cross_context()
|
||||
|
||||
|
||||
# 检查是否启用AI配图
|
||||
ai_image_enabled = self.get_config("ai_image.enable_ai_image", False)
|
||||
provider = self.get_config("ai_image.provider", "siliconflow")
|
||||
|
||||
|
||||
image_path = None
|
||||
|
||||
|
||||
if ai_image_enabled:
|
||||
# 启用AI配图:文本模型生成说说+图片提示词
|
||||
story, image_info = await self.content_service.generate_story_with_image_info(topic, context=cross_context)
|
||||
if not story:
|
||||
return {"success": False, "message": "生成说说内容失败"}
|
||||
|
||||
|
||||
# 根据provider调用对应的生图服务
|
||||
if provider == "novelai":
|
||||
try:
|
||||
from .novelai_service import MaiZoneNovelAIService
|
||||
novelai_service = MaiZoneNovelAIService(self.get_config)
|
||||
|
||||
|
||||
if novelai_service.is_available():
|
||||
# 解析画幅
|
||||
aspect_ratio = image_info.get("aspect_ratio", "方图")
|
||||
@@ -113,8 +112,8 @@ class QZoneService:
|
||||
"竖图": (832, 1216),
|
||||
}
|
||||
width, height = size_map.get(aspect_ratio, (1024, 1024))
|
||||
|
||||
logger.info(f"🎨 开始生成NovelAI配图...")
|
||||
|
||||
logger.info("🎨 开始生成NovelAI配图...")
|
||||
success, img_path, msg = await novelai_service.generate_image_from_prompt_data(
|
||||
prompt=image_info.get("prompt", ""),
|
||||
negative_prompt=image_info.get("negative_prompt"),
|
||||
@@ -122,18 +121,18 @@ class QZoneService:
|
||||
width=width,
|
||||
height=height
|
||||
)
|
||||
|
||||
|
||||
if success and img_path:
|
||||
image_path = img_path
|
||||
logger.info(f"✅ NovelAI配图生成成功")
|
||||
logger.info("✅ NovelAI配图生成成功")
|
||||
else:
|
||||
logger.warning(f"⚠️ NovelAI配图生成失败: {msg}")
|
||||
else:
|
||||
logger.warning("NovelAI服务不可用(未配置API Key)")
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"NovelAI配图生成出错: {e}", exc_info=True)
|
||||
|
||||
|
||||
elif provider == "siliconflow":
|
||||
try:
|
||||
# 调用硅基流动生成图片
|
||||
@@ -143,9 +142,9 @@ class QZoneService:
|
||||
)
|
||||
if success and img_path:
|
||||
image_path = img_path
|
||||
logger.info(f"✅ 硅基流动配图生成成功")
|
||||
logger.info("✅ 硅基流动配图生成成功")
|
||||
else:
|
||||
logger.warning(f"⚠️ 硅基流动配图生成失败")
|
||||
logger.warning("⚠️ 硅基流动配图生成失败")
|
||||
except Exception as e:
|
||||
logger.error(f"硅基流动配图生成出错: {e}", exc_info=True)
|
||||
else:
|
||||
@@ -161,13 +160,13 @@ class QZoneService:
|
||||
|
||||
# 加载图片
|
||||
images_bytes = []
|
||||
|
||||
|
||||
# 使用AI生成的图片
|
||||
if image_path and image_path.exists():
|
||||
try:
|
||||
with open(image_path, "rb") as f:
|
||||
images_bytes.append(f.read())
|
||||
logger.info(f"添加AI配图到说说")
|
||||
logger.info("添加AI配图到说说")
|
||||
except Exception as e:
|
||||
logger.error(f"读取AI配图失败: {e}")
|
||||
|
||||
|
||||
@@ -416,18 +416,18 @@ class NapcatAdapterPlugin(BasePlugin):
|
||||
"reply_at_rate": ConfigField(type=float, default=0.5, description="回复时@的概率(0.0-1.0)"),
|
||||
# ========== 视频消息处理配置 ==========
|
||||
"enable_video_processing": ConfigField(
|
||||
type=bool,
|
||||
default=True,
|
||||
type=bool,
|
||||
default=True,
|
||||
description="是否启用视频消息处理(下载和解析)。关闭后视频消息将显示为 [视频消息] 占位符,不会进行下载"
|
||||
),
|
||||
"video_max_size_mb": ConfigField(
|
||||
type=int,
|
||||
default=100,
|
||||
type=int,
|
||||
default=100,
|
||||
description="允许下载的视频文件最大大小(MB),超过此大小的视频将被跳过"
|
||||
),
|
||||
"video_download_timeout": ConfigField(
|
||||
type=int,
|
||||
default=60,
|
||||
type=int,
|
||||
default=60,
|
||||
description="视频下载超时时间(秒),若超时将中止下载"
|
||||
),
|
||||
},
|
||||
|
||||
@@ -42,14 +42,14 @@ class MessageHandler:
|
||||
def set_plugin_config(self, config: dict[str, Any]) -> None:
|
||||
"""设置插件配置,并根据配置初始化视频下载器"""
|
||||
self.plugin_config = config
|
||||
|
||||
|
||||
# 如果启用了视频处理,根据配置初始化视频下载器
|
||||
if config_api.get_plugin_config(config, "features.enable_video_processing", True):
|
||||
from ..video_handler import VideoDownloader
|
||||
|
||||
|
||||
max_size = config_api.get_plugin_config(config, "features.video_max_size_mb", 100)
|
||||
timeout = config_api.get_plugin_config(config, "features.video_download_timeout", 60)
|
||||
|
||||
|
||||
self._video_downloader = VideoDownloader(max_size_mb=max_size, download_timeout=timeout)
|
||||
logger.debug(f"视频下载器已初始化: max_size={max_size}MB, timeout={timeout}s")
|
||||
|
||||
@@ -341,7 +341,7 @@ class MessageHandler:
|
||||
if not downloader:
|
||||
from ..video_handler import get_video_downloader
|
||||
downloader = get_video_downloader()
|
||||
|
||||
|
||||
download_result = await downloader.download_video(video_url)
|
||||
|
||||
if not download_result["success"]:
|
||||
|
||||
Reference in New Issue
Block a user