diff --git a/bot.py b/bot.py index 4d5e12a49..d2b9f4b3e 100644 --- a/bot.py +++ b/bot.py @@ -103,7 +103,7 @@ async def graceful_shutdown(main_system_instance): logger.info("正在优雅关闭麦麦...") # 停止MainSystem中的组件,它会处理服务器等 - if main_system_instance and hasattr(main_system_instance, 'shutdown'): + if main_system_instance and hasattr(main_system_instance, "shutdown"): logger.info("正在关闭MainSystem...") await main_system_instance.shutdown() @@ -111,7 +111,7 @@ async def graceful_shutdown(main_system_instance): try: from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() - if hasattr(chat_manager, '_stop_auto_save'): + if hasattr(chat_manager, "_stop_auto_save"): logger.info("正在停止聊天管理器...") chat_manager._stop_auto_save() except Exception as e: @@ -120,7 +120,7 @@ async def graceful_shutdown(main_system_instance): # 停止情绪管理器 try: from src.mood.mood_manager import mood_manager - if hasattr(mood_manager, 'stop'): + if hasattr(mood_manager, "stop"): logger.info("正在停止情绪管理器...") await mood_manager.stop() except Exception as e: @@ -129,7 +129,7 @@ async def graceful_shutdown(main_system_instance): # 停止记忆系统 try: from src.chat.memory_system.memory_manager import memory_manager - if hasattr(memory_manager, 'shutdown'): + if hasattr(memory_manager, "shutdown"): logger.info("正在停止记忆系统...") await memory_manager.shutdown() except Exception as e: diff --git a/plugins/bilibli/bilibli_base.py b/plugins/bilibli/bilibli_base.py index c35538dba..e6418f7c7 100644 --- a/plugins/bilibli/bilibli_base.py +++ b/plugins/bilibli/bilibli_base.py @@ -245,7 +245,7 @@ class BilibiliVideoAnalyzer: logger.exception("详细错误信息:") return None - async def analyze_bilibili_video(self, url: str, prompt: str = None) -> dict[str, Any]: + async def analyze_bilibili_video(self, url: str, prompt: str | None = None) -> dict[str, Any]: """分析哔哩哔哩视频并返回详细信息和AI分析结果""" try: logger.info(f"🎬 开始分析哔哩哔哩视频: {url}") diff --git a/plugins/echo_example/__init__.py b/plugins/echo_example/__init__.py index 8747cc51a..0a78bbfa7 100644 --- a/plugins/echo_example/__init__.py +++ b/plugins/echo_example/__init__.py @@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata( version="1.0.0", author="Your Name", license="MIT", -) \ No newline at end of file +) diff --git a/plugins/hello_world_plugin/__init__.py b/plugins/hello_world_plugin/__init__.py index fdca03bcd..f6c7ec72e 100644 --- a/plugins/hello_world_plugin/__init__.py +++ b/plugins/hello_world_plugin/__init__.py @@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata( version="1.0.0", author="Your Name", license="MIT", -) \ No newline at end of file +) diff --git a/src/api/__init__.py b/src/api/__init__.py index 359eff96e..a904023b0 100644 --- a/src/api/__init__.py +++ b/src/api/__init__.py @@ -1 +1 @@ -# This file makes src/api a Python package. \ No newline at end of file +# This file makes src/api a Python package. 
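Aside on the `graceful_shutdown` hunks above: the changes themselves are quote-style only, but the function's shape is worth preserving as components get added — every optional subsystem is probed with `hasattr`/`getattr` and wrapped in its own `try/except`, so a failure in one component cannot abort the rest of the shutdown sequence. A minimal sketch of that pattern (the component registry and names here are illustrative, not the project's real API):

```python
import asyncio
import logging

logger = logging.getLogger("shutdown")

async def _stop_component(name: str, obj: object, method: str) -> None:
    """Invoke obj.<method>() if it exists; log failures instead of propagating them."""
    fn = getattr(obj, method, None)
    if fn is None:
        return
    try:
        result = fn()
        if asyncio.iscoroutine(result):  # supports both sync and async shutdown hooks
            await result
        logger.info("stopped %s", name)
    except Exception as e:  # one failing component must not block the others
        logger.error("failed to stop %s: %s", name, e)

async def graceful_shutdown(components: dict[str, tuple[object, str]]) -> None:
    """Stop components one by one, in registration order."""
    for name, (obj, method) in components.items():
        await _stop_component(name, obj, method)
```

The `asyncio.iscoroutine` check mirrors the mix in `bot.py` of synchronous hooks (`chat_manager._stop_auto_save()`) and awaited ones (`await mood_manager.stop()`).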
diff --git a/src/api/message_router.py b/src/api/message_router.py index 9595463d1..47ae7771f 100644 --- a/src/api/message_router.py +++ b/src/api/message_router.py @@ -3,10 +3,10 @@ from typing import Literal from fastapi import APIRouter, HTTPException, Query -from src.config.config import global_config -from src.plugin_system.apis import message_api, chat_api, person_api from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger +from src.config.config import global_config +from src.plugin_system.apis import message_api, person_api logger = get_logger("HTTP消息API") @@ -86,7 +86,7 @@ async def get_message_stats_by_chat( if group_by_user: if user_id not in stats[chat_id]["user_stats"]: stats[chat_id]["user_stats"][user_id] = 0 - + stats[chat_id]["user_stats"][user_id] += 1 if not group_by_user: @@ -120,7 +120,7 @@ async def get_message_stats_by_chat( "nickname": nickname, "count": count } - + formatted_stats[chat_id] = formatted_data return formatted_stats @@ -164,7 +164,7 @@ async def get_bot_message_stats_by_chat( chat_name = stream.group_info.group_name elif stream.user_info and stream.user_info.user_nickname: chat_name = stream.user_info.user_nickname - + formatted_stats[chat_id] = { "chat_name": chat_name, "count": count @@ -174,4 +174,4 @@ async def get_bot_message_stats_by_chat( return stats except Exception as e: - raise HTTPException(status_code=500, detail=str(e)) \ No newline at end of file + raise HTTPException(status_code=500, detail=str(e)) diff --git a/src/chat/interest_system/interest_manager.py b/src/chat/interest_system/interest_manager.py index f415c4248..c77ffd25b 100644 --- a/src/chat/interest_system/interest_manager.py +++ b/src/chat/interest_system/interest_manager.py @@ -112,7 +112,7 @@ class InterestManager: # 返回默认结果 return InterestCalculationResult( success=False, - message_id=getattr(message, 'message_id', ''), + message_id=getattr(message, "message_id", ""), interest_value=0.3, error_message="没有可用的兴趣值计算组件" ) @@ -129,7 +129,7 @@ class InterestManager: logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5") return InterestCalculationResult( success=True, - message_id=getattr(message, 'message_id', ''), + message_id=getattr(message, "message_id", ""), interest_value=0.5, # 固定默认兴趣值 should_reply=False, should_act=False, @@ -140,9 +140,9 @@ class InterestManager: logger.error(f"兴趣值计算异常: {e}") return InterestCalculationResult( success=False, - message_id=getattr(message, 'message_id', ''), + message_id=getattr(message, "message_id", ""), interest_value=0.3, - error_message=f"计算异常: {str(e)}" + error_message=f"计算异常: {e!s}" ) async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult: @@ -168,9 +168,9 @@ class InterestManager: logger.error(f"兴趣值计算异常: {e}", exc_info=True) return InterestCalculationResult( success=False, - message_id=getattr(message, 'message_id', ''), + message_id=getattr(message, "message_id", ""), interest_value=0.0, - error_message=f"计算异常: {str(e)}", + error_message=f"计算异常: {e!s}", calculation_time=time.time() - start_time ) @@ -245,4 +245,4 @@ def get_interest_manager() -> InterestManager: global _interest_manager if _interest_manager is None: _interest_manager = InterestManager() - return _interest_manager \ No newline at end of file + return _interest_manager diff --git a/src/chat/memory_system/hippocampus_sampler.py b/src/chat/memory_system/hippocampus_sampler.py index 0cc6b61d5..fba3e439f 100644 --- 
a/src/chat/memory_system/hippocampus_sampler.py +++ b/src/chat/memory_system/hippocampus_sampler.py @@ -1,4 +1,3 @@ -# -*- coding: utf-8 -*- """ 海马体双峰分布采样器 基于旧版海马体的采样策略,适配新版记忆系统 @@ -8,16 +7,15 @@ import asyncio import random import time -from datetime import datetime, timedelta -from typing import List, Optional, Tuple, Dict, Any from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Any import numpy as np -import orjson from src.chat.utils.chat_message_builder import ( - get_raw_msg_by_timestamp, build_readable_messages, + get_raw_msg_by_timestamp, get_raw_msg_by_timestamp_with_chat, ) from src.chat.utils.utils import translate_timestamp_to_human_readable @@ -47,7 +45,7 @@ class HippocampusSampleConfig: batch_size: int = 5 # 批处理大小 @classmethod - def from_global_config(cls) -> 'HippocampusSampleConfig': + def from_global_config(cls) -> "HippocampusSampleConfig": """从全局配置创建海马体采样配置""" config = global_config.memory.hippocampus_distribution_config return cls( @@ -74,12 +72,12 @@ class HippocampusSampler: self.is_running = False # 记忆构建模型 - self.memory_builder_model: Optional[LLMRequest] = None + self.memory_builder_model: LLMRequest | None = None # 统计信息 self.sample_count = 0 self.success_count = 0 - self.last_sample_results: List[Dict[str, Any]] = [] + self.last_sample_results: list[dict[str, Any]] = [] async def initialize(self): """初始化采样器""" @@ -101,7 +99,7 @@ class HippocampusSampler: logger.error(f"❌ 海马体采样器初始化失败: {e}") raise - def generate_time_samples(self) -> List[datetime]: + def generate_time_samples(self) -> list[datetime]: """生成双峰分布的时间采样点""" # 计算每个分布的样本数 recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight)) @@ -132,7 +130,7 @@ class HippocampusSampler: # 按时间排序(从最早到最近) return sorted(timestamps) - async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]: + async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None: """收集指定时间戳附近的消息样本""" try: # 随机时间窗口:5-30分钟 @@ -190,7 +188,7 @@ class HippocampusSampler: logger.error(f"收集消息样本失败: {e}") return None - async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]: + async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None: """从消息样本构建记忆""" if not messages or not self.memory_system or not self.memory_builder_model: return None @@ -262,7 +260,7 @@ class HippocampusSampler: logger.error(f"海马体采样构建记忆失败: {e}") return None - async def perform_sampling_cycle(self) -> Dict[str, Any]: + async def perform_sampling_cycle(self) -> dict[str, Any]: """执行一次完整的采样周期(优化版:批量融合构建)""" if not self.should_sample(): return {"status": "skipped", "reason": "interval_not_met"} @@ -363,7 +361,7 @@ class HippocampusSampler: "duration": time.time() - start_time, } - async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]: + async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]: """批量收集所有时间点的消息样本""" collected_messages = [] max_concurrent = min(5, len(time_samples)) # 提高并发数到5 @@ -394,7 +392,7 @@ class HippocampusSampler: return collected_messages - async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]: + async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]: """融合和去重消息样本""" if not 
collected_messages: return [] @@ -450,7 +448,7 @@ class HippocampusSampler: # 返回原始消息组作为备选 return collected_messages[:5] # 限制返回数量 - def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]: + def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]: """合并时间间隔内的消息""" if not messages: return [] @@ -481,7 +479,7 @@ class HippocampusSampler: return result_groups - async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]: + async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]: """批量构建记忆""" if not fused_messages: return {"memory_count": 0, "memories": []} @@ -557,7 +555,7 @@ class HippocampusSampler: logger.error(f"批量构建记忆失败: {e}") return {"memory_count": 0, "error": str(e)} - async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str: + async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str: """构建融合后的对话文本""" try: # 添加批次标识 @@ -589,7 +587,7 @@ class HippocampusSampler: logger.error(f"构建融合文本失败: {e}") return "" - async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]: + async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]: """备选方案:单独构建每个消息组""" total_memories = [] total_count = 0 @@ -609,7 +607,7 @@ class HippocampusSampler: "fallback_mode": True } - async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]: + async def process_sample_timestamp(self, target_timestamp: float) -> str | None: """处理单个时间戳采样(保留作为备选方法)""" try: # 收集消息样本 @@ -676,7 +674,7 @@ class HippocampusSampler: self.is_running = False logger.info("🛑 停止海马体后台采样任务") - def get_sampling_stats(self) -> Dict[str, Any]: + def get_sampling_stats(self) -> dict[str, Any]: """获取采样统计信息""" success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0 @@ -713,7 +711,7 @@ class HippocampusSampler: # 全局海马体采样器实例 -_hippocampus_sampler: Optional[HippocampusSampler] = None +_hippocampus_sampler: HippocampusSampler | None = None def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler: @@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl """初始化全局海马体采样器""" sampler = get_hippocampus_sampler(memory_system) await sampler.initialize() - return sampler \ No newline at end of file + return sampler diff --git a/src/chat/memory_system/memory_builder.py b/src/chat/memory_system/memory_builder.py index 69fe7432f..764896a0c 100644 --- a/src/chat/memory_system/memory_builder.py +++ b/src/chat/memory_system/memory_builder.py @@ -32,7 +32,7 @@ import time from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Any, Type, TypeVar +from typing import Any, TypeVar E = TypeVar("E", bound=Enum) @@ -503,7 +503,7 @@ class MemoryBuilder: logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型") return MemoryType.CONTEXTUAL - def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E: + def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E: """解析枚举值,兼容数字/字符串表示""" if isinstance(raw_value, enum_cls): return raw_value diff --git a/src/chat/memory_system/memory_fusion.py 
b/src/chat/memory_system/memory_fusion.py index 59f36ed93..6e384ca8c 100644 --- a/src/chat/memory_system/memory_fusion.py +++ b/src/chat/memory_system/memory_fusion.py @@ -215,8 +215,8 @@ class MemoryFusionEngine: if not keywords1 or not keywords2: return 0.0 - set1 = set(k.lower() for k in keywords1) - set2 = set(k.lower() for k in keywords2) + set1 = {k.lower() for k in keywords1} + set2 = {k.lower() for k in keywords2} intersection = set1 & set2 union = set1 | set2 diff --git a/src/chat/memory_system/memory_manager.py b/src/chat/memory_system/memory_manager.py index 1ba79fe59..dd627c084 100644 --- a/src/chat/memory_system/memory_manager.py +++ b/src/chat/memory_system/memory_manager.py @@ -69,14 +69,11 @@ class MemoryManager: # 初始化记忆系统 self.memory_system = await initialize_memory_system(llm_model) - # 设置全局实例 - global_memory_manager = self.memory_system - self.is_initialized = True - logger.info("✅ 记忆系统初始化完成") + logger.info("记忆系统初始化完成") except Exception as e: - logger.error(f"❌ 记忆系统初始化失败: {e}") + logger.error(f"记忆系统初始化失败: {e}") # 如果系统初始化失败,创建一个空的管理器避免系统崩溃 self.memory_system = None self.is_initialized = True # 标记为已初始化但系统不可用 @@ -439,7 +436,7 @@ class MemoryManager: formatted_items = [self._format_object(item) for item in obj] filtered = [item for item in formatted_items if item] return self._clean_text("、".join(filtered)) if filtered else "" - if isinstance(obj, (int, float)): + if isinstance(obj, int | float): return str(obj) text = self._truncate(str(obj).strip()) return self._clean_text(text) @@ -449,12 +446,12 @@ class MemoryManager: for key in keys: if obj.get(key): value = obj[key] - if isinstance(value, (dict, list)): + if isinstance(value, dict | list): return self._clean_text(self._format_object(value)) return self._clean_text(value) if isinstance(obj, list) and obj: return self._clean_text(self._format_object(obj[0])) - if isinstance(obj, (str, int, float)): + if isinstance(obj, str | int | float): return self._clean_text(obj) return None @@ -471,7 +468,7 @@ class MemoryManager: try: if self.memory_system: await self.memory_system.shutdown() - logger.info("✅ 记忆系统已关闭") + logger.info("记忆系统已关闭") except Exception as e: logger.error(f"关闭记忆系统失败: {e}") diff --git a/src/chat/memory_system/memory_system.py b/src/chat/memory_system/memory_system.py index 6dc228f76..a2c0a0e83 100644 --- a/src/chat/memory_system/memory_system.py +++ b/src/chat/memory_system/memory_system.py @@ -19,6 +19,8 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio from src.chat.memory_system.memory_chunk import MemoryChunk from src.chat.memory_system.memory_fusion import MemoryFusionEngine from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner + + # 记忆采样模式枚举 class MemorySamplingMode(Enum): """记忆采样模式""" @@ -31,9 +33,10 @@ from src.llm_models.utils_model import LLMRequest if TYPE_CHECKING: from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine + from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage from src.common.data_models.database_data_model import DatabaseMessages -logger = get_logger(__name__) +logger = get_logger("memory_system") # 全局记忆作用域(共享记忆库) GLOBAL_MEMORY_SCOPE = "global" @@ -133,15 +136,15 @@ class MemorySystem: self.status = MemorySystemStatus.INITIALIZING # 核心组件(简化版) - self.memory_builder: MemoryBuilder = None - self.fusion_engine: MemoryFusionEngine = None - self.unified_storage = None # 统一存储系统 - self.query_planner: MemoryQueryPlanner = None + self.memory_builder:
MemoryBuilder | None = None + self.fusion_engine: MemoryFusionEngine | None = None + self.unified_storage: VectorMemoryStorage | None = None # 统一存储系统 + self.query_planner: MemoryQueryPlanner | None = None self.forgetting_engine: MemoryForgettingEngine | None = None # LLM模型 - self.value_assessment_model: LLMRequest = None - self.memory_extraction_model: LLMRequest = None + self.value_assessment_model: LLMRequest | None = None + self.memory_extraction_model: LLMRequest | None = None # 统计信息 self.total_memories = 0 @@ -162,7 +165,6 @@ class MemorySystem: async def initialize(self): """异步初始化记忆系统""" try: - logger.info("正在初始化记忆系统...") # 初始化LLM模型 fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None @@ -249,7 +251,7 @@ class MemorySystem: self.forgetting_engine = MemoryForgettingEngine(forgetting_config) - planner_task_config = getattr(model_config.model_task_config, "utils_small", None) + planner_task_config = model_config.model_task_config.utils_small planner_model: LLMRequest | None = None try: planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner") @@ -269,10 +271,8 @@ class MemorySystem: self.hippocampus_sampler = None # 统一存储已经自动加载数据,无需额外加载 - logger.info("✅ 简化版记忆系统初始化完成") self.status = MemorySystemStatus.READY - logger.info("✅ 记忆系统初始化完成") except Exception as e: self.status = MemorySystemStatus.ERROR @@ -479,7 +479,7 @@ class MemorySystem: existing_id = self._memory_fingerprints.get(fingerprint_key) if existing_id and existing_id not in new_memory_ids: candidate_ids.add(existing_id) - except Exception as exc: + except Exception as exc: # noqa: PERF203 logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc) # 基于主体索引的候选(使用统一存储) @@ -557,11 +557,11 @@ class MemorySystem: context = dict(context or {}) # 获取配置的采样模式 - sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision') + sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision") current_mode = MemorySamplingMode(sampling_mode) - context['__sampling_mode'] = current_mode.value + context["__sampling_mode"] = current_mode.value logger.debug(f"使用记忆采样模式: {current_mode.value}") # 根据采样模式处理记忆 @@ -637,7 +637,7 @@ class MemorySystem: # 检查信息价值阈值 value_score = await self._assess_information_value(conversation_text, normalized_context) - threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5) + threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5) if value_score < threshold: logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建") @@ -843,7 +843,7 @@ class MemorySystem: for i, (mem, score, details) in enumerate(scored_memories[:3], 1): try: summary = mem.content[:60] if hasattr(mem, "content") and mem.content else "" - except: + except Exception: summary = "" logger.info( f" #{i} | final={details['final']:.3f} " @@ -1440,8 +1440,8 @@ class MemorySystem: context_keywords = context.get("keywords") or [] keyword_overlap = 0.0 if context_keywords: - memory_keywords = set(k.lower() for k in memory.keywords) - keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max( + memory_keywords = {k.lower() for k in memory.keywords} + keyword_overlap = len(memory_keywords & {k.lower() for k in context_keywords}) / max( len(context_keywords), 1 ) @@ -1489,7 +1489,7 @@ class MemorySystem: """启动海马体采样""" if self.hippocampus_sampler: asyncio.create_task(self.hippocampus_sampler.start_background_sampling()) - logger.info("🚀 海马体后台采样已启动") + logger.info("海马体后台采样已启动") else: 
logger.warning("海马体采样器未初始化,无法启动采样") @@ -1497,7 +1497,7 @@ class MemorySystem: """停止海马体采样""" if self.hippocampus_sampler: self.hippocampus_sampler.stop_background_sampling() - logger.info("🛑 海马体后台采样已停止") + logger.info("海马体后台采样已停止") def get_system_stats(self) -> dict[str, Any]: """获取系统统计信息""" @@ -1536,10 +1536,10 @@ class MemorySystem: if self.unified_storage: self.unified_storage.cleanup() - logger.info("✅ 简化记忆系统已关闭") + logger.info("简化记忆系统已关闭") except Exception as e: - logger.error(f"❌ 记忆系统关闭失败: {e}", exc_info=True) + logger.error(f"记忆系统关闭失败: {e}", exc_info=True) async def _rebuild_vector_storage_if_needed(self): """重建向量存储(如果需要)""" @@ -1553,12 +1553,13 @@ class MemorySystem: # 收集需要重建向量的记忆 memories_to_rebuild = [] - for memory_id, memory in self.unified_storage.memory_cache.items(): - # 检查记忆是否有有效的 display 文本 - if memory.display and memory.display.strip(): - memories_to_rebuild.append(memory) - elif memory.text_content and memory.text_content.strip(): - memories_to_rebuild.append(memory) + if self.unified_storage: + for memory in self.unified_storage.memory_cache.values(): + # 检查记忆是否有有效的 display 文本 + if memory.display and memory.display.strip(): + memories_to_rebuild.append(memory) + elif memory.text_content and memory.text_content.strip(): + memories_to_rebuild.append(memory) if not memories_to_rebuild: logger.warning("没有找到可重建向量的记忆") @@ -1583,14 +1584,16 @@ class MemorySystem: logger.error(f"批量重建向量失败: {e}") continue - # 保存重建的向量存储 - await self.unified_storage.save_storage() - - final_count = self.unified_storage.storage_stats.get("total_vectors", 0) - logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}") + # 向量数据在 store_memories 中已保存,此处无需额外操作 + if self.unified_storage: + storage_stats = self.unified_storage.get_storage_stats() + final_count = storage_stats.get("total_vectors", 0) + logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}") + else: + logger.warning("向量存储重建完成,但无法获取最终向量数量,因为存储系统未初始化") except Exception as e: - logger.error(f"❌ 向量存储重建失败: {e}", exc_info=True) + logger.error(f"向量存储重建失败: {e}", exc_info=True) # 全局记忆系统实例 @@ -1613,8 +1616,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None): await memory_system.initialize() # 根据配置启动海马体采样 - sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate') - if sampling_mode in ['hippocampus', 'all']: + sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate") + if sampling_mode in ["hippocampus", "all"]: memory_system.start_hippocampus_sampling() return memory_system diff --git a/src/chat/message_manager/adaptive_stream_manager.py b/src/chat/message_manager/adaptive_stream_manager.py index b9fe6ab78..0242d7960 100644 --- a/src/chat/message_manager/adaptive_stream_manager.py +++ b/src/chat/message_manager/adaptive_stream_manager.py @@ -4,14 +4,13 @@ """ import asyncio -import psutil import time -from typing import Dict, List, Optional, Set, Tuple from dataclasses import dataclass, field from enum import Enum +import psutil + from src.common.logger import get_logger -from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("adaptive_stream_manager") @@ -71,16 +70,16 @@ class AdaptiveStreamManager: # 当前状态 self.current_limit = base_concurrent_limit - self.active_streams: Set[str] = set() - self.pending_streams: Set[str] = set() - self.stream_metrics: Dict[str, StreamMetrics] = {} + self.active_streams: set[str] = set() + self.pending_streams: set[str] = set() + self.stream_metrics: dict[str, StreamMetrics] = {} # 异步信号量 self.semaphore = 
asyncio.Semaphore(base_concurrent_limit) self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量 # 系统监控 - self.system_metrics: List[SystemMetrics] = [] + self.system_metrics: list[SystemMetrics] = [] self.last_adjustment_time = 0.0 # 统计信息 @@ -95,8 +94,8 @@ class AdaptiveStreamManager: } # 监控任务 - self.monitor_task: Optional[asyncio.Task] = None - self.adjustment_task: Optional[asyncio.Task] = None + self.monitor_task: asyncio.Task | None = None + self.adjustment_task: asyncio.Task | None = None self.is_running = False logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})") @@ -443,7 +442,7 @@ class AdaptiveStreamManager: if hasattr(metrics, key): setattr(metrics, key, value) - def get_stats(self) -> Dict: + def get_stats(self) -> dict: """获取统计信息""" stats = self.stats.copy() stats.update({ @@ -465,7 +464,7 @@ class AdaptiveStreamManager: # 全局自适应管理器实例 -_adaptive_manager: Optional[AdaptiveStreamManager] = None +_adaptive_manager: AdaptiveStreamManager | None = None def get_adaptive_stream_manager() -> AdaptiveStreamManager: @@ -485,4 +484,4 @@ async def init_adaptive_stream_manager(): async def shutdown_adaptive_stream_manager(): """关闭自适应流管理器""" manager = get_adaptive_stream_manager() - await manager.stop() \ No newline at end of file + await manager.stop() diff --git a/src/chat/message_manager/batch_database_writer.py b/src/chat/message_manager/batch_database_writer.py index de40efa51..cb8e87c3d 100644 --- a/src/chat/message_manager/batch_database_writer.py +++ b/src/chat/message_manager/batch_database_writer.py @@ -5,9 +5,9 @@ import asyncio import time -from typing import Any, Dict, List, Optional -from dataclasses import dataclass, field from collections import defaultdict +from dataclasses import dataclass, field +from typing import Any from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_models import ChatStreams @@ -21,7 +21,7 @@ logger = get_logger("batch_database_writer") class StreamUpdatePayload: """流更新数据结构""" stream_id: str - update_data: Dict[str, Any] + update_data: dict[str, Any] priority: int = 0 # 优先级,数字越大优先级越高 timestamp: float = field(default_factory=time.time) @@ -47,7 +47,7 @@ class BatchDatabaseWriter: # 运行状态 self.is_running = False - self.writer_task: Optional[asyncio.Task] = None + self.writer_task: asyncio.Task | None = None # 统计信息 self.stats = { @@ -60,7 +60,7 @@ class BatchDatabaseWriter: } # 按优先级分类的批次 - self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list) + self.priority_batches: dict[int, list[StreamUpdatePayload]] = defaultdict(list) logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)") @@ -98,7 +98,7 @@ class BatchDatabaseWriter: async def schedule_stream_update( self, stream_id: str, - update_data: Dict[str, Any], + update_data: dict[str, Any], priority: int = 0 ) -> bool: """ @@ -166,7 +166,7 @@ class BatchDatabaseWriter: await self._flush_all_batches() logger.info("批量写入循环结束") - async def _collect_batch(self) -> List[StreamUpdatePayload]: + async def _collect_batch(self) -> list[StreamUpdatePayload]: """收集一个批次的数据""" batch = [] deadline = time.time() + self.flush_interval @@ -189,7 +189,7 @@ class BatchDatabaseWriter: return batch - async def _write_batch(self, batch: List[StreamUpdatePayload]): + async def _write_batch(self, batch: list[StreamUpdatePayload]): """批量写入数据库""" if not batch: return @@ -228,7 +228,7 @@ class BatchDatabaseWriter: except Exception as single_e: logger.error(f"单个写入也失败: {single_e}") - 
async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]): + async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]): """批量写入数据库""" async with get_db_session() as session: for payload in payloads: @@ -268,7 +268,7 @@ class BatchDatabaseWriter: await session.commit() - async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]): + async def _direct_write(self, stream_id: str, update_data: dict[str, Any]): """直接写入数据库(降级方案)""" async with get_db_session() as session: if global_config.database.database_type == "sqlite": @@ -315,7 +315,7 @@ class BatchDatabaseWriter: if remaining_batch: await self._write_batch(remaining_batch) - def get_stats(self) -> Dict[str, Any]: + def get_stats(self) -> dict[str, Any]: """获取统计信息""" stats = self.stats.copy() stats["is_running"] = self.is_running @@ -324,7 +324,7 @@ class BatchDatabaseWriter: # 全局批量写入器实例 -_batch_writer: Optional[BatchDatabaseWriter] = None +_batch_writer: BatchDatabaseWriter | None = None def get_batch_writer() -> BatchDatabaseWriter: @@ -344,4 +344,4 @@ async def init_batch_writer(): async def shutdown_batch_writer(): """关闭批量写入器""" writer = get_batch_writer() - await writer.stop() \ No newline at end of file + await writer.stop() diff --git a/src/chat/message_manager/distribution_manager.py b/src/chat/message_manager/distribution_manager.py index e0fc5899c..fffd699f1 100644 --- a/src/chat/message_manager/distribution_manager.py +++ b/src/chat/message_manager/distribution_manager.py @@ -117,7 +117,7 @@ class StreamLoopManager: # 使用自适应流管理器获取槽位 use_adaptive = False try: - from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority + from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager adaptive_manager = get_adaptive_stream_manager() if adaptive_manager.is_running: @@ -137,7 +137,7 @@ class StreamLoopManager: else: logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案") else: - logger.debug(f"自适应管理器未运行,使用原始方法") + logger.debug("自适应管理器未运行,使用原始方法") except Exception as e: logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}") diff --git a/src/chat/message_manager/stream_cache_manager.py b/src/chat/message_manager/stream_cache_manager.py index 19f0590d9..3e8cdebac 100644 --- a/src/chat/message_manager/stream_cache_manager.py +++ b/src/chat/message_manager/stream_cache_manager.py @@ -5,13 +5,13 @@ import asyncio import time -from typing import Dict, List, Optional, Set -from dataclasses import dataclass from collections import OrderedDict +from dataclasses import dataclass from maim_message import GroupInfo, UserInfo -from src.common.logger import get_logger + from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream +from src.common.logger import get_logger logger = get_logger("stream_cache_manager") @@ -52,14 +52,14 @@ class TieredStreamCache: # 三层缓存存储 self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据(LRU) - self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) - self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) + self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间) + self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间) # 统计信息 self.stats = StreamCacheStats() # 清理任务 - self.cleanup_task: Optional[asyncio.Task] = None + self.cleanup_task: asyncio.Task | None = None self.is_running = False logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, 
warm:{max_warm_size}, cold:{max_cold_size})") @@ -96,8 +96,8 @@ class TieredStreamCache: stream_id: str, platform: str, user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[Dict] = None, + group_info: GroupInfo | None = None, + data: dict | None = None, ) -> OptimizedChatStream: """获取或创建流 - 优化版本""" current_time = time.time() @@ -255,7 +255,7 @@ class TieredStreamCache: hot_to_demote = [] for stream_id, stream in self.hot_cache.items(): # 获取最后访问时间(简化:使用创建时间作为近似) - last_access = getattr(stream, 'last_active_time', stream.create_time) + last_access = getattr(stream, "last_active_time", stream.create_time) if current_time - last_access > self.hot_timeout: hot_to_demote.append(stream_id) @@ -341,7 +341,7 @@ class TieredStreamCache: logger.info("所有缓存已清空") - async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]: + async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None: """获取流的快照(不修改缓存状态)""" if stream_id in self.hot_cache: return self.hot_cache[stream_id].create_snapshot() @@ -351,13 +351,13 @@ class TieredStreamCache: return self.cold_storage[stream_id][0].create_snapshot() return None - def get_cached_stream_ids(self) -> Set[str]: + def get_cached_stream_ids(self) -> set[str]: """获取所有缓存的流ID""" return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys()) # 全局缓存管理器实例 -_cache_manager: Optional[TieredStreamCache] = None +_cache_manager: TieredStreamCache | None = None def get_stream_cache_manager() -> TieredStreamCache: @@ -377,4 +377,4 @@ async def init_stream_cache_manager(): async def shutdown_stream_cache_manager(): """关闭流缓存管理器""" manager = get_stream_cache_manager() - await manager.stop() \ No newline at end of file + await manager.stop() diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 61c0616b6..c0e68661a 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -313,11 +313,11 @@ class ChatStream: except Exception as e: logger.error(f"计算消息兴趣值失败: {e}", exc_info=True) # 异常情况下使用默认值 - if hasattr(db_message, 'interest_value'): + if hasattr(db_message, "interest_value"): db_message.interest_value = 0.3 - if hasattr(db_message, 'should_reply'): + if hasattr(db_message, "should_reply"): db_message.should_reply = False - if hasattr(db_message, 'should_act'): + if hasattr(db_message, "should_act"): db_message.should_act = False def _extract_reply_from_segment(self, segment) -> str | None: @@ -894,10 +894,10 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream": original_stream.saved = optimized_stream.saved # 复制上下文信息(如果存在) - if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context: + if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context: original_stream.stream_context = optimized_stream._stream_context - if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager: + if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager: original_stream.context_manager = optimized_stream._context_manager return original_stream diff --git a/src/chat/message_receive/optimized_chat_stream.py b/src/chat/message_receive/optimized_chat_stream.py index 438f4e65c..c9b32d6f8 100644 --- a/src/chat/message_receive/optimized_chat_stream.py +++ b/src/chat/message_receive/optimized_chat_stream.py @@ -3,17 +3,12 @@ 避免不必要的深拷贝开销,提升多流并发性能 """ -import asyncio -import copy -import 
hashlib import time -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import TYPE_CHECKING, Any from maim_message import GroupInfo, UserInfo from rich.traceback import install -from src.common.database.sqlalchemy_database_api import get_db_session -from src.common.database.sqlalchemy_models import ChatStreams from src.common.logger import get_logger from src.config.config import global_config @@ -28,7 +23,7 @@ logger = get_logger("optimized_chat_stream") class SharedContext: """共享上下文数据 - 只读数据结构""" - def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None): + def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None): self.stream_id = stream_id self.platform = platform self.user_info = user_info @@ -37,7 +32,7 @@ class SharedContext: self._frozen = True def __setattr__(self, name, value): - if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']: + if hasattr(self, "_frozen") and self._frozen and name not in ["_frozen"]: raise AttributeError(f"SharedContext is frozen, cannot modify {name}") super().__setattr__(name, value) @@ -46,7 +41,7 @@ class LocalChanges: """本地修改跟踪器""" def __init__(self): - self._changes: Dict[str, Any] = {} + self._changes: dict[str, Any] = {} self._dirty = False def set_change(self, key: str, value: Any): @@ -62,7 +57,7 @@ class LocalChanges: """是否有修改""" return self._dirty - def get_changes(self) -> Dict[str, Any]: + def get_changes(self) -> dict[str, Any]: """获取所有修改""" return self._changes.copy() @@ -80,8 +75,8 @@ class OptimizedChatStream: stream_id: str, platform: str, user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[Dict] = None, + group_info: GroupInfo | None = None, + data: dict | None = None, ): # 共享的只读数据 self._shared_context = SharedContext( @@ -129,42 +124,42 @@ class OptimizedChatStream: """修改用户信息时触发写时复制""" self._ensure_copy_on_write() # 由于SharedContext是frozen的,我们需要在本地修改中记录 - self._local_changes.set_change('user_info', value) + self._local_changes.set_change("user_info", value) @property - def group_info(self) -> Optional[GroupInfo]: - if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes: - return self._local_changes.get_change('group_info') + def group_info(self) -> GroupInfo | None: + if self._local_changes.has_changes() and "group_info" in self._local_changes._changes: + return self._local_changes.get_change("group_info") return self._shared_context.group_info @group_info.setter - def group_info(self, value: Optional[GroupInfo]): + def group_info(self, value: GroupInfo | None): """修改群组信息时触发写时复制""" self._ensure_copy_on_write() - self._local_changes.set_change('group_info', value) + self._local_changes.set_change("group_info", value) @property def create_time(self) -> float: - if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes: - return self._local_changes.get_change('create_time') + if self._local_changes.has_changes() and "create_time" in self._local_changes._changes: + return self._local_changes.get_change("create_time") return self._shared_context.create_time @property def last_active_time(self) -> float: - return self._local_changes.get_change('last_active_time', self.create_time) + return self._local_changes.get_change("last_active_time", self.create_time) @last_active_time.setter def last_active_time(self, value: float): - self._local_changes.set_change('last_active_time', value) + 
self._local_changes.set_change("last_active_time", value) self.saved = False @property def sleep_pressure(self) -> float: - return self._local_changes.get_change('sleep_pressure', 0.0) + return self._local_changes.get_change("sleep_pressure", 0.0) @sleep_pressure.setter def sleep_pressure(self, value: float): - self._local_changes.set_change('sleep_pressure', value) + self._local_changes.set_change("sleep_pressure", value) self.saved = False def _ensure_copy_on_write(self): @@ -176,14 +171,14 @@ class OptimizedChatStream: def _get_effective_user_info(self) -> UserInfo: """获取有效的用户信息""" - if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes: - return self._local_changes.get_change('user_info') + if self._local_changes.has_changes() and "user_info" in self._local_changes._changes: + return self._local_changes.get_change("user_info") return self._shared_context.user_info - def _get_effective_group_info(self) -> Optional[GroupInfo]: + def _get_effective_group_info(self) -> GroupInfo | None: """获取有效的群组信息""" - if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes: - return self._local_changes.get_change('group_info') + if self._local_changes.has_changes() and "group_info" in self._local_changes._changes: + return self._local_changes.get_change("group_info") return self._shared_context.group_info def update_active_time(self): @@ -199,6 +194,7 @@ class OptimizedChatStream: # 将MessageRecv转换为DatabaseMessages并设置到stream_context import json + from src.common.data_models.database_data_model import DatabaseMessages message_info = getattr(message, "message_info", {}) @@ -298,7 +294,7 @@ class OptimizedChatStream: self._create_stream_context() return self._context_manager - def to_dict(self) -> Dict[str, Any]: + def to_dict(self) -> dict[str, Any]: """转换为字典格式 - 考虑本地修改""" user_info = self._get_effective_user_info() group_info = self._get_effective_group_info() @@ -319,7 +315,7 @@ class OptimizedChatStream: } @classmethod - def from_dict(cls, data: Dict) -> "OptimizedChatStream": + def from_dict(cls, data: dict) -> "OptimizedChatStream": """从字典创建实例""" user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None @@ -481,8 +477,8 @@ def create_optimized_chat_stream( stream_id: str, platform: str, user_info: UserInfo, - group_info: Optional[GroupInfo] = None, - data: Optional[Dict] = None, + group_info: GroupInfo | None = None, + data: dict | None = None, ) -> OptimizedChatStream: """创建优化版聊天流实例""" return OptimizedChatStream( @@ -491,4 +487,4 @@ def create_optimized_chat_stream( user_info=user_info, group_info=group_info, data=data - ) \ No newline at end of file + ) diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 1b31cf48b..0fd30456d 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -15,7 +15,7 @@ from src.plugin_system.base.component_types import ActionActivationType, ActionI from src.plugin_system.core.global_announcement_manager import global_announcement_manager if TYPE_CHECKING: - from src.chat.message_receive.chat_stream import ChatStream + pass logger = get_logger("action_manager") diff --git a/src/chat/utils/prompt.py b/src/chat/utils/prompt.py index 5686e2c84..53f11f500 100644 --- a/src/chat/utils/prompt.py +++ b/src/chat/utils/prompt.py @@ -536,7 +536,7 @@ class Prompt: style = 
expr.get("style", "") if situation and style: formatted_expressions.append(f"- {situation}:{style}") - + if formatted_expressions: style_habits_str = "\n".join(formatted_expressions) expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}" diff --git a/src/chat/utils/typo_generator.py b/src/chat/utils/typo_generator.py index 07a8a9160..8c0ea8e45 100644 --- a/src/chat/utils/typo_generator.py +++ b/src/chat/utils/typo_generator.py @@ -9,8 +9,8 @@ import time from collections import defaultdict from pathlib import Path -import rjieba import orjson +import rjieba from pypinyin import Style, pinyin from src.common.logger import get_logger diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index db73d83d2..55a667c25 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -6,8 +6,8 @@ import time from collections import Counter from typing import Any -import rjieba import numpy as np +import rjieba from maim_message import UserInfo from src.chat.message_receive.chat_stream import get_chat_manager diff --git a/src/chat/utils/utils_video.py b/src/chat/utils/utils_video.py index 6a6fc6245..78ea3a11c 100644 --- a/src/chat/utils/utils_video.py +++ b/src/chat/utils/utils_video.py @@ -1,9 +1,17 @@ #!/usr/bin/env python3 +"""纯 inkfox 视频关键帧分析工具 + +仅依赖 `inkfox.video` 提供的 Rust 扩展能力: + - extract_keyframes_from_video + - get_system_info + +功能: + - 关键帧提取 (base64, timestamp) + - 批量 / 逐帧 LLM 描述 + - 自动模式 (<=3 帧批量,否则逐帧) """ -视频分析器模块 - Rust优化版本 -集成了Rust视频关键帧提取模块,提供高性能的视频分析功能 -支持SIMD优化、多线程处理和智能关键帧检测 -""" + +from __future__ import annotations import asyncio import base64 @@ -13,913 +21,301 @@ import os import tempfile import time from pathlib import Path +from typing import Any -import numpy as np from PIL import Image -from sqlalchemy import select +from sqlalchemy import exc as sa_exc # type: ignore +from sqlalchemy import insert, select, update # type: ignore -from src.common.database.sqlalchemy_models import Videos, get_db_session +from src.common.database.sqlalchemy_models import Videos, get_db_session # type: ignore from src.common.logger import get_logger from src.config.config import global_config, model_config from src.llm_models.utils_model import LLMRequest +# 简易并发控制:同一 hash 只处理一次 +_video_locks: dict[str, asyncio.Lock] = {} +_locks_guard = asyncio.Lock() + logger = get_logger("utils_video") -# Rust模块可用性检测 -RUST_VIDEO_AVAILABLE = False -try: - import rust_video # pyright: ignore[reportMissingImports] - - RUST_VIDEO_AVAILABLE = True - logger.info("✅ Rust 视频处理模块加载成功") -except ImportError as e: - logger.warning(f"⚠️ Rust 视频处理模块加载失败: {e}") - logger.warning("⚠️ 视频识别功能将自动禁用") -except Exception as e: - logger.error(f"❌ 加载Rust模块时发生错误: {e}") - RUST_VIDEO_AVAILABLE = False - -# 全局正在处理的视频哈希集合,用于防止重复处理 -processing_videos = set() -processing_lock = asyncio.Lock() -# 为每个视频hash创建独立的锁和事件 -video_locks = {} -video_events = {} -video_lock_manager = asyncio.Lock() +from inkfox import video class VideoAnalyzer: - """优化的视频分析器类""" + """基于 inkfox 的视频关键帧 + LLM 描述分析器""" - def __init__(self): - """初始化视频分析器""" - # 检查是否有任何可用的视频处理实现 - opencv_available = False + def __init__(self) -> None: + cfg = getattr(global_config, "video_analysis", object()) + self.max_frames: int = getattr(cfg, "max_frames", 20) + self.frame_quality: int = getattr(cfg, "frame_quality", 85) + self.max_image_size: int = getattr(cfg, "max_image_size", 600) + self.enable_frame_timing: bool = getattr(cfg, "enable_frame_timing", True) + self.use_simd: bool = getattr(cfg, "rust_use_simd", True) + 
self.threads: int = getattr(cfg, "rust_threads", 0) + self.ffmpeg_path: str = getattr(cfg, "ffmpeg_path", "ffmpeg") + self.analysis_mode: str = getattr(cfg, "analysis_mode", "auto") + self.frame_analysis_delay: float = 0.3 + + # 人格与提示模板 try: - import cv2 - - opencv_available = True - except ImportError: - pass - - if not RUST_VIDEO_AVAILABLE and not opencv_available: - logger.error("❌ 没有可用的视频处理实现,视频分析器将被禁用") - self.disabled = True - return - elif not RUST_VIDEO_AVAILABLE: - logger.warning("⚠️ Rust视频处理模块不可用,将使用Python降级实现") - elif not opencv_available: - logger.warning("⚠️ OpenCV不可用,仅支持Rust关键帧模式") - - self.disabled = False - - # 使用专用的视频分析配置 - try: - self.video_llm = LLMRequest( - model_set=model_config.model_task_config.video_analysis, request_type="video_analysis" - ) - logger.debug("✅ 使用video_analysis模型配置") - except (AttributeError, KeyError) as e: - # 如果video_analysis不存在,使用vlm配置 - self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm") - logger.warning(f"video_analysis配置不可用({e}),回退使用vlm配置") - - # 从配置文件读取参数,如果配置不存在则使用默认值 - config = global_config.video_analysis - - # 使用 getattr 统一获取配置参数,如果配置不存在则使用默认值 - self.max_frames = getattr(config, "max_frames", 6) - self.frame_quality = getattr(config, "frame_quality", 85) - self.max_image_size = getattr(config, "max_image_size", 600) - self.enable_frame_timing = getattr(config, "enable_frame_timing", True) - - # Rust模块相关配置 - self.rust_keyframe_threshold = getattr(config, "rust_keyframe_threshold", 2.0) - self.rust_use_simd = getattr(config, "rust_use_simd", True) - self.rust_block_size = getattr(config, "rust_block_size", 8192) - self.rust_threads = getattr(config, "rust_threads", 0) - self.ffmpeg_path = getattr(config, "ffmpeg_path", "ffmpeg") - - # 从personality配置中获取人格信息 - try: - personality_config = global_config.personality - self.personality_core = getattr(personality_config, "personality_core", "是一个积极向上的女大学生") - self.personality_side = getattr( - personality_config, "personality_side", "用一句话或几句话描述人格的侧面特点" - ) - except AttributeError: - # 如果没有personality配置,使用默认值 + persona = global_config.personality + self.personality_core = getattr(persona, "personality_core", "是一个积极向上的女大学生") + self.personality_side = getattr(persona, "personality_side", "用一句话或几句话描述人格的侧面特点") + except Exception: # pragma: no cover self.personality_core = "是一个积极向上的女大学生" self.personality_side = "用一句话或几句话描述人格的侧面特点" self.batch_analysis_prompt = getattr( - config, + cfg, "batch_analysis_prompt", - """请以第一人称的视角来观看这一个视频,你看到的这些是从视频中按时间顺序提取的关键帧。 - -你的核心人设是:{personality_core}。 -你的人格细节是:{personality_side}。 - -请提供详细的视频内容描述,涵盖以下方面: -1. 视频的整体内容和主题 -2. 主要人物、对象和场景描述 -3. 动作、情节和时间线发展 -4. 视觉风格和艺术特点 -5. 整体氛围和情感表达 -6. 
任何特殊的视觉效果或文字内容 - -请用中文回答,结果要详细准确。""", + """请以第一人称视角阅读这些按时间顺序提取的关键帧。\n核心:{personality_core}\n人格:{personality_side}\n请详细描述视频(主题/人物与场景/动作与时间线/视觉风格/情绪氛围/特殊元素)。""", ) - # 新增的线程池配置 - self.use_multiprocessing = getattr(config, "use_multiprocessing", True) - self.max_workers = getattr(config, "max_workers", 2) - self.frame_extraction_mode = getattr(config, "frame_extraction_mode", "fixed_number") - self.frame_interval_seconds = getattr(config, "frame_interval_seconds", 2.0) - - # 将配置文件中的模式映射到内部使用的模式名称 - config_mode = getattr(config, "analysis_mode", "auto") - if config_mode == "batch_frames": - self.analysis_mode = "batch" - elif config_mode == "frame_by_frame": - self.analysis_mode = "sequential" - elif config_mode == "auto": - self.analysis_mode = "auto" - else: - logger.warning(f"无效的分析模式: {config_mode},使用默认的auto模式") - self.analysis_mode = "auto" - - self.frame_analysis_delay = 0.3 # API调用间隔(秒) - self.frame_interval = 1.0 # 抽帧时间间隔(秒) - self.batch_size = 3 # 批处理时每批处理的帧数 - self.timeout = 60.0 # 分析超时时间(秒) - - if config: - logger.debug("✅ 从配置文件读取视频分析参数") - else: - logger.warning("配置文件中缺少video_analysis配置,使用默认值") - - # 系统提示词 - self.system_prompt = "你是一个专业的视频内容分析助手。请仔细观察用户提供的视频关键帧,详细描述视频内容。" - - logger.debug(f"✅ 视频分析器初始化完成,分析模式: {self.analysis_mode}, 线程池: {self.use_multiprocessing}") - - # 获取Rust模块系统信息 - self._log_system_info() - - def _log_system_info(self): - """记录系统信息""" - if not RUST_VIDEO_AVAILABLE: - logger.info("⚠️ Rust模块不可用,跳过系统信息获取") - return - try: - system_info = rust_video.get_system_info() - logger.debug(f"🔧 系统信息: 线程数={system_info.get('threads', '未知')}") - - # 记录CPU特性 - features = [] - if system_info.get("avx2_supported"): - features.append("AVX2") - if system_info.get("sse2_supported"): - features.append("SSE2") - if system_info.get("simd_supported"): - features.append("SIMD") - - if features: - logger.debug(f"🚀 CPU特性: {', '.join(features)}") - else: - logger.debug("⚠️ 未检测到SIMD支持") - - logger.debug(f"📦 Rust模块版本: {system_info.get('version', '未知')}") - - except Exception as e: - logger.warning(f"获取系统信息失败: {e}") - - def _calculate_video_hash(self, video_data: bytes) -> str: - """计算视频文件的hash值""" - hash_obj = hashlib.sha256() - hash_obj.update(video_data) - return hash_obj.hexdigest() - - async def _check_video_exists(self, video_hash: str) -> Videos | None: - """检查视频是否已经分析过""" - try: - async with get_db_session() as session: - if not session: - logger.warning("无法获取数据库会话,跳过视频存在性检查。") - return None - # 明确刷新会话以确保看到其他事务的最新提交 - await session.expire_all() - stmt = select(Videos).where(Videos.video_hash == video_hash) - result = await session.execute(stmt) - return result.scalar_one_or_none() - except Exception as e: - logger.warning(f"检查视频是否存在时出错: {e}") - return None - - async def _store_video_result( - self, video_hash: str, description: str, metadata: dict | None = None - ) -> Videos | None: - """存储视频分析结果到数据库""" - # 检查描述是否为错误信息,如果是则不保存 - if description.startswith("❌"): - logger.warning(f"⚠️ 检测到错误信息,不保存到数据库: {description[:50]}...") - return None - - try: - async with get_db_session() as session: - if not session: - logger.warning("无法获取数据库会话,跳过视频结果存储。") - return None - # 只根据video_hash查找 - stmt = select(Videos).where(Videos.video_hash == video_hash) - result = await session.execute(stmt) - existing_video = result.scalar_one_or_none() - - if existing_video: - # 如果已存在,更新描述和计数 - existing_video.description = description - existing_video.count += 1 - existing_video.timestamp = time.time() - if metadata: - existing_video.duration = metadata.get("duration") - existing_video.frame_count = 
metadata.get("frame_count") - existing_video.fps = metadata.get("fps") - existing_video.resolution = metadata.get("resolution") - existing_video.file_size = metadata.get("file_size") - await session.commit() - await session.refresh(existing_video) - logger.info(f"✅ 更新已存在的视频记录,hash: {video_hash[:16]}..., count: {existing_video.count}") - return existing_video - else: - video_record = Videos( - video_hash=video_hash, description=description, timestamp=time.time(), count=1 - ) - if metadata: - video_record.duration = metadata.get("duration") - video_record.frame_count = metadata.get("frame_count") - video_record.fps = metadata.get("fps") - video_record.resolution = metadata.get("resolution") - video_record.file_size = metadata.get("file_size") - - session.add(video_record) - await session.commit() - await session.refresh(video_record) - logger.info(f"✅ 新视频分析结果已保存到数据库,hash: {video_hash[:16]}...") - return video_record - except Exception as e: - logger.error(f"❌ 存储视频分析结果时出错: {e}") - return None - - def set_analysis_mode(self, mode: str): - """设置分析模式""" - if mode in ["batch", "sequential", "auto"]: - self.analysis_mode = mode - # logger.info(f"分析模式已设置为: {mode}") - else: - logger.warning(f"无效的分析模式: {mode}") - - async def extract_frames(self, video_path: str) -> list[tuple[str, float]]: - """提取视频帧 - 智能选择最佳实现""" - # 检查是否应该使用Rust实现 - if RUST_VIDEO_AVAILABLE and self.frame_extraction_mode == "keyframe": - # 优先尝试Rust关键帧提取 - try: - return await self._extract_frames_rust_advanced(video_path) - except Exception as e: - logger.warning(f"Rust高级接口失败: {e},尝试基础接口") - try: - return await self._extract_frames_rust(video_path) - except Exception as e2: - logger.warning(f"Rust基础接口也失败: {e2},降级到Python实现") - return await self._extract_frames_python_fallback(video_path) - else: - # 使用Python实现(支持time_interval和fixed_number模式) - if not RUST_VIDEO_AVAILABLE: - logger.info("🔄 Rust模块不可用,使用Python抽帧实现") - else: - logger.info(f"🔄 抽帧模式为 {self.frame_extraction_mode},使用Python抽帧实现") - return await self._extract_frames_python_fallback(video_path) - - async def _extract_frames_rust_advanced(self, video_path: str) -> list[tuple[str, float]]: - """使用 Rust 高级接口的帧提取""" - try: - logger.info("🔄 使用 Rust 高级接口提取关键帧...") - - # 创建 Rust 视频处理器,使用配置参数 - extractor = rust_video.VideoKeyframeExtractor( - ffmpeg_path=self.ffmpeg_path, - threads=self.rust_threads, - verbose=False, # 使用固定值,不需要配置 + self.video_llm = LLMRequest( + model_set=model_config.model_task_config.video_analysis, request_type="video_analysis" ) + except Exception: + self.video_llm = LLMRequest(model_set=model_config.model_task_config.vlm, request_type="vlm") - # 1. 
提取所有帧 - frames_data, width, height = extractor.extract_frames( + self._log_system() + + # ---- 系统信息 ---- + def _log_system(self) -> None: + try: + info = video.get_system_info() # type: ignore[attr-defined] + logger.info( + f"inkfox: threads={info.get('threads')} version={info.get('version')} simd={info.get('simd_supported')}" + ) + except Exception as e: # pragma: no cover + logger.debug(f"获取系统信息失败: {e}") + + # ---- 关键帧提取 ---- + async def extract_keyframes(self, video_path: str) -> list[tuple[str, float]]: + """提取关键帧并返回 (base64, timestamp_seconds) 列表""" + with tempfile.TemporaryDirectory() as tmp: + result = video.extract_keyframes_from_video( # type: ignore[attr-defined] video_path=video_path, - max_frames=self.max_frames * 3, # 提取更多帧用于关键帧检测 + output_dir=tmp, + max_keyframes=self.max_frames * 2, # 先多抓一点再截断 + max_save=self.max_frames, + ffmpeg_path=self.ffmpeg_path, + use_simd=self.use_simd, + threads=self.threads, + verbose=False, ) - - logger.info(f"提取到 {len(frames_data)} 帧,视频尺寸: {width}x{height}") - - # 2. 检测关键帧,使用配置参数 - keyframe_indices = extractor.extract_keyframes( - frames=frames_data, - threshold=self.rust_keyframe_threshold, - use_simd=self.rust_use_simd, - block_size=self.rust_block_size, - ) - - logger.info(f"检测到 {len(keyframe_indices)} 个关键帧") - - # 3. 转换选定的关键帧为 base64 - frames = [] - frame_count = 0 - - for idx in keyframe_indices[: self.max_frames]: - if idx < len(frames_data): - try: - frame = frames_data[idx] - frame_data = frame.get_data() - - # 将灰度数据转换为PIL图像 - frame_array = np.frombuffer(frame_data, dtype=np.uint8).reshape((frame.height, frame.width)) - pil_image = Image.fromarray( - frame_array, - mode="L", # 灰度模式 - ) - - # 转换为RGB模式以便保存为JPEG - pil_image = pil_image.convert("RGB") - - # 调整图像大小 - if max(pil_image.size) > self.max_image_size: - ratio = self.max_image_size / max(pil_image.size) - new_size = tuple(int(dim * ratio) for dim in pil_image.size) - pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS) - - # 转换为 base64 - buffer = io.BytesIO() - pil_image.save(buffer, format="JPEG", quality=self.frame_quality) - frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8") - - # 估算时间戳 - estimated_timestamp = frame.frame_number * (1.0 / 30.0) # 假设30fps - - frames.append((frame_base64, estimated_timestamp)) - frame_count += 1 - - logger.debug( - f"处理关键帧 {frame_count}: 帧号 {frame.frame_number}, 时间 {estimated_timestamp:.2f}s" - ) - - except Exception as e: - logger.error(f"处理关键帧 {idx} 失败: {e}") - continue - - logger.info(f"✅ Rust 高级提取完成: {len(frames)} 关键帧") + files = sorted(Path(tmp).glob("keyframe_*.jpg"))[: self.max_frames] + total_ms = getattr(result, "total_time_ms", 0) + frames: list[tuple[str, float]] = [] + for i, f in enumerate(files): + img = Image.open(f).convert("RGB") + if max(img.size) > self.max_image_size: + scale = self.max_image_size / max(img.size) + img = img.resize((int(img.width * scale), int(img.height * scale)), Image.Resampling.LANCZOS) + buf = io.BytesIO() + img.save(buf, format="JPEG", quality=self.frame_quality) + b64 = base64.b64encode(buf.getvalue()).decode() + ts = (i / max(1, len(files) - 1)) * (total_ms / 1000.0) if total_ms else float(i) + frames.append((b64, ts)) return frames - except Exception as e: - logger.error(f"❌ Rust 高级帧提取失败: {e}") - # 回退到基础方法 - logger.info("回退到基础 Rust 方法") - return await self._extract_frames_rust(video_path) + # ---- 批量分析 ---- + async def _analyze_batch(self, frames: list[tuple[str, float]], question: str | None) -> str: + from src.llm_models.payload_content.message import MessageBuilder + from 
+        from src.llm_models.utils_model import RequestType

-    async def _extract_frames_rust(self, video_path: str) -> list[tuple[str, float]]:
-        """使用 Rust 实现的帧提取"""
-        try:
-            logger.info("🔄 使用 Rust 模块提取关键帧...")
-
-            # 创建临时输出目录
-            with tempfile.TemporaryDirectory() as temp_dir:
-                # 使用便捷函数进行关键帧提取,使用配置参数
-                result = rust_video.extract_keyframes_from_video(
-                    video_path=video_path,
-                    output_dir=temp_dir,
-                    threshold=self.rust_keyframe_threshold,
-                    max_frames=self.max_frames * 2,  # 提取更多帧以便筛选
-                    max_save=self.max_frames,
-                    ffmpeg_path=self.ffmpeg_path,
-                    use_simd=self.rust_use_simd,
-                    threads=self.rust_threads,
-                    verbose=False,  # 使用固定值,不需要配置
-                )
-
-                logger.info(
-                    f"Rust 处理完成: 总帧数 {result.total_frames}, 关键帧 {result.keyframes_extracted}, 处理速度 {result.processing_fps:.1f} FPS"
-                )
-
-                # 转换保存的关键帧为 base64 格式
-                frames = []
-                temp_dir_path = Path(temp_dir)
-
-                # 获取所有保存的关键帧文件
-                keyframe_files = sorted(temp_dir_path.glob("keyframe_*.jpg"))
-
-                for i, keyframe_file in enumerate(keyframe_files):
-                    if len(frames) >= self.max_frames:
-                        break
-
-                    try:
-                        # 读取关键帧文件
-                        with open(keyframe_file, "rb") as f:
-                            image_data = f.read()
-
-                        # 转换为 PIL 图像并压缩
-                        pil_image = Image.open(io.BytesIO(image_data))
-
-                        # 调整图像大小
-                        if max(pil_image.size) > self.max_image_size:
-                            ratio = self.max_image_size / max(pil_image.size)
-                            new_size = tuple(int(dim * ratio) for dim in pil_image.size)
-                            pil_image = pil_image.resize(new_size, Image.Resampling.LANCZOS)
-
-                        # 转换为 base64
-                        buffer = io.BytesIO()
-                        pil_image.save(buffer, format="JPEG", quality=self.frame_quality)
-                        frame_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
-
-                        # 估算时间戳(基于帧索引和总时长)
-                        if result.total_frames > 0:
-                            # 假设关键帧在时间上均匀分布
-                            estimated_timestamp = (i * result.total_time_ms / 1000.0) / result.keyframes_extracted
-                        else:
-                            estimated_timestamp = i * 1.0  # 默认每秒一帧
-
-                        frames.append((frame_base64, estimated_timestamp))
-
-                        logger.debug(f"处理关键帧 {i + 1}: 估算时间 {estimated_timestamp:.2f}s")
-
-                    except Exception as e:
-                        logger.error(f"处理关键帧 {keyframe_file.name} 失败: {e}")
-                        continue
-
-                logger.info(f"✅ Rust 提取完成: {len(frames)} 关键帧")
-                return frames
-
-        except Exception as e:
-            logger.error(f"❌ Rust 帧提取失败: {e}")
-            raise e
-
-    async def _extract_frames_python_fallback(self, video_path: str) -> list[tuple[str, float]]:
-        """Python降级抽帧实现 - 支持多种抽帧模式"""
-        try:
-            # 导入旧版本分析器
-            from .utils_video_legacy import get_legacy_video_analyzer
-
-            logger.info("🔄 使用Python降级抽帧实现...")
-            legacy_analyzer = get_legacy_video_analyzer()
-
-            # 同步配置参数
-            legacy_analyzer.max_frames = self.max_frames
-            legacy_analyzer.frame_quality = self.frame_quality
-            legacy_analyzer.max_image_size = self.max_image_size
-            legacy_analyzer.frame_extraction_mode = self.frame_extraction_mode
-            legacy_analyzer.frame_interval_seconds = self.frame_interval_seconds
-            legacy_analyzer.use_multiprocessing = self.use_multiprocessing
-
-            # 使用旧版本的抽帧功能
-            frames = await legacy_analyzer.extract_frames(video_path)
-
-            logger.info(f"✅ Python降级抽帧完成: {len(frames)} 帧")
-            return frames
-
-        except Exception as e:
-            logger.error(f"❌ Python降级抽帧失败: {e}")
-            return []
-
-    async def analyze_frames_batch(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
-        """批量分析所有帧"""
-        logger.info(f"开始批量分析{len(frames)}帧")
-
-        if not frames:
-            return "❌ 没有可分析的帧"
-
-        # 构建提示词并格式化人格信息,要不然占位符的那个会爆炸
         prompt = self.batch_analysis_prompt.format(
             personality_core=self.personality_core, personality_side=self.personality_side
         )
+        if question:
+            prompt += f"\n用户关注: {question}"

-        if user_question:
-            prompt += f"\n\n用户问题: {user_question}"
+        desc = [
+            (f"第{i+1}帧 (时间: {ts:.2f}s)" if self.enable_frame_timing else f"第{i+1}帧")
+            for i, (_b, ts) in enumerate(frames)
+        ]
+        prompt += "\n帧列表: " + ", ".join(desc)

-        # 添加帧信息到提示词
-        frame_info = []
-        for i, (_frame_base64, timestamp) in enumerate(frames):
-            if self.enable_frame_timing:
-                frame_info.append(f"第{i + 1}帧 (时间: {timestamp:.2f}s)")
-            else:
-                frame_info.append(f"第{i + 1}帧")
+        message_builder = MessageBuilder().add_text_content(prompt)
+        for b64, _ in frames:
+            message_builder.add_image_content(image_format="jpeg", image_base64=b64)
+        messages = [message_builder.build()]

-        prompt += f"\n\n视频包含{len(frames)}帧图像:{', '.join(frame_info)}"
-        prompt += "\n\n请基于所有提供的帧图像进行综合分析,关注并描述视频的完整内容和故事发展。"
-
-        try:
-            # 使用多图片分析
-            response = await self._analyze_multiple_frames(frames, prompt)
-            logger.info("✅ 视频识别完成")
-            return response
-
-        except Exception as e:
-            logger.error(f"❌ 视频识别失败: {e}")
-            raise e
-
-    async def _analyze_multiple_frames(self, frames: list[tuple[str, float]], prompt: str) -> str:
-        """使用多图片分析方法"""
-        logger.info(f"开始构建包含{len(frames)}帧的分析请求")
-
-        # 导入MessageBuilder用于构建多图片消息
-        from src.llm_models.payload_content.message import MessageBuilder, RoleType
-        from src.llm_models.utils_model import RequestType
-
-        # 构建包含多张图片的消息
-        message_builder = MessageBuilder().set_role(RoleType.User).add_text_content(prompt)
-
-        # 添加所有帧图像
-        for _i, (frame_base64, _timestamp) in enumerate(frames):
-            message_builder.add_image_content("jpeg", frame_base64)
-            # logger.info(f"已添加第{i+1}帧到分析请求 (时间: {timestamp:.2f}s, 图片大小: {len(frame_base64)} chars)")
-
-        message = message_builder.build()
-        # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
-
-        # 获取模型信息和客户端
-        selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
-        if not selection_result:
-            raise RuntimeError("无法为视频分析选择可用模型。")
-        model_info, api_provider, client = selection_result
-        # logger.info(f"使用模型: {model_info.name} 进行多帧分析")
-
-        # 直接执行多图片请求
-        api_response = await self.video_llm._executor.execute_request(
-            api_provider=api_provider,
-            client=client,
-            request_type=RequestType.RESPONSE,
-            model_info=model_info,
-            message_list=[message],
-            temperature=None,
-            max_tokens=None,
+        # 使用封装好的高级策略执行请求,而不是直接调用内部方法
+        response, _ = await self.video_llm._strategy.execute_with_failover(
+            RequestType.RESPONSE,
+            raise_when_empty=False,  # 即使失败也返回默认值,避免程序崩溃
+            message_list=messages,
+            temperature=self.video_llm.model_for_task.temperature,
+            max_tokens=self.video_llm.model_for_task.max_tokens,
         )
-        logger.info(f"视频识别完成,响应长度: {len(api_response.content or '')} ")
-        return api_response.content or "❌ 未获得响应内容"
+        return response.content or "❌ 未获得响应"

-    async def analyze_frames_sequential(self, frames: list[tuple[str, float]], user_question: str = None) -> str:
-        """逐帧分析并汇总"""
-        logger.info(f"开始逐帧分析{len(frames)}帧")
-
-        frame_analyses = []
-
-        for i, (frame_base64, timestamp) in enumerate(frames):
+    # ---- 逐帧分析 ----
+    async def _analyze_sequential(self, frames: list[tuple[str, float]], question: str | None) -> str:
+        results: list[str] = []
+        for i, (b64, ts) in enumerate(frames):
+            prompt = f"分析第{i+1}帧" + (f" (时间: {ts:.2f}s)" if self.enable_frame_timing else "")
+            if question:
+                prompt += f"\n关注: {question}"
             try:
-                prompt = f"请分析这个视频的第{i + 1}帧"
-                if self.enable_frame_timing:
-                    prompt += f" (时间: {timestamp:.2f}s)"
-                prompt += "。描述你看到的内容,包括人物、动作、场景、文字等。"
-
-                if user_question:
-                    prompt += f"\n特别关注: {user_question}"
-
-                response, _ = await self.video_llm.generate_response_for_image(
-                    prompt=prompt, image_base64=frame_base64, image_format="jpeg"
+                text, _ = await self.video_llm.generate_response_for_image(
+                    prompt=prompt, image_base64=b64, image_format="jpeg"
                 )
-
-                frame_analyses.append(f"第{i + 1}帧 ({timestamp:.2f}s): {response}")
-                logger.debug(f"✅ 第{i + 1}帧分析完成")
-
-                # API调用间隔
-                if i < len(frames) - 1:
-                    await asyncio.sleep(self.frame_analysis_delay)
-
-            except Exception as e:
-                logger.error(f"❌ 第{i + 1}帧分析失败: {e}")
-                frame_analyses.append(f"第{i + 1}帧: 分析失败 - {e}")
-
-        # 生成汇总
-        logger.info("开始生成汇总分析")
-        summary_prompt = f"""基于以下各帧的分析结果,请提供一个完整的视频内容总结:
-
-{chr(10).join(frame_analyses)}
-
-请综合所有帧的信息,描述视频的整体内容、故事线、主要元素和特点。"""
-
-        if user_question:
-            summary_prompt += f"\n特别回答用户的问题: {user_question}"
-
+                results.append(f"第{i+1}帧: {text}")
+            except Exception as e:  # pragma: no cover
+                results.append(f"第{i+1}帧: 失败 {e}")
+            if i < len(frames) - 1:
+                await asyncio.sleep(self.frame_analysis_delay)
+        summary_prompt = "基于以下逐帧结果给出完整总结:\n\n" + "\n".join(results)
         try:
-            # 使用最后一帧进行汇总分析
-            if frames:
-                last_frame_base64, _ = frames[-1]
-                summary, _ = await self.video_llm.generate_response_for_image(
-                    prompt=summary_prompt, image_base64=last_frame_base64, image_format="jpeg"
-                )
-                logger.info("✅ 逐帧分析和汇总完成")
-                return summary
-            else:
-                return "❌ 没有可用于汇总的帧"
-        except Exception as e:
-            logger.error(f"❌ 汇总分析失败: {e}")
-            # 如果汇总失败,返回各帧分析结果
-            return f"视频逐帧分析结果:\n\n{chr(10).join(frame_analyses)}"
+            final, _ = await self.video_llm.generate_response_for_image(
+                prompt=summary_prompt, image_base64=frames[-1][0], image_format="jpeg"
+            )
+            return final
+        except Exception:  # pragma: no cover
+            return "\n".join(results)

-    async def analyze_video(self, video_path: str, user_question: str = None) -> tuple[bool, str]:
-        """分析视频的主要方法
-
-        Returns:
-            Tuple[bool, str]: (是否成功, 分析结果或错误信息)
-        """
-        if self.disabled:
-            error_msg = "❌ 视频分析功能已禁用:没有可用的视频处理实现"
-            logger.warning(error_msg)
-            return (False, error_msg)
-
-        try:
-            logger.info(f"开始分析视频: {os.path.basename(video_path)}")
-
-            # 提取帧
-            frames = await self.extract_frames(video_path)
-            if not frames:
-                error_msg = "❌ 无法从视频中提取有效帧"
-                return (False, error_msg)
-
-            # 根据模式选择分析方法
-            if self.analysis_mode == "auto":
-                # 智能选择:少于等于3帧用批量,否则用逐帧
-                mode = "batch" if len(frames) <= 3 else "sequential"
-                logger.info(f"自动选择分析模式: {mode} (基于{len(frames)}帧)")
-            else:
-                mode = self.analysis_mode
-
-            # 执行分析
-            if mode == "batch":
-                result = await self.analyze_frames_batch(frames, user_question)
-            else:  # sequential
-                result = await self.analyze_frames_sequential(frames, user_question)
-
-            logger.info("✅ 视频分析完成")
-            return (True, result)
-
-        except Exception as e:
-            error_msg = f"❌ 视频分析失败: {e!s}"
-            logger.error(error_msg)
-            return (False, error_msg)
+    # ---- 主入口 ----
+    async def analyze_video(self, video_path: str, question: str | None = None) -> tuple[bool, str]:
+        if not os.path.exists(video_path):
+            return False, "❌ 文件不存在"
+        frames = await self.extract_keyframes(video_path)
+        if not frames:
+            return False, "❌ 未提取到关键帧"
+        mode = self.analysis_mode
+        if mode == "auto":
+            mode = "batch" if len(frames) <= 20 else "sequential"
+        text = await (self._analyze_batch(frames, question) if mode == "batch" else self._analyze_sequential(frames, question))
+        return True, text

     async def analyze_video_from_bytes(
-        self, video_bytes: bytes, filename: str = None, user_question: str = None, prompt: str = None
+        self,
+        video_bytes: bytes,
+        filename: str | None = None,
+        prompt: str | None = None,
+        question: str | None = None,
     ) -> dict[str, str]:
-        """从字节数据分析视频
+        """从内存字节分析视频,兼容旧调用 (prompt / question 二选一) 返回 {"summary": str}."""
+        if not video_bytes:
+            return {"summary": "❌ 空视频数据"}
+        # 兼容参数:prompt 优先,其次 question
+        q = prompt if prompt is not None else question
+        video_hash = hashlib.sha256(video_bytes).hexdigest()

-        Args:
-            video_bytes: 视频字节数据
-            filename: 文件名(可选,仅用于日志)
-            user_question: 用户问题(旧参数名,保持兼容性)
-            prompt: 提示词(新参数名,与系统调用保持一致)
+        # 查缓存(第一次,未加锁)
+        cached = await self._get_cached(video_hash)
+        if cached:
+            logger.info(f"视频缓存命中(预检查) hash={video_hash[:16]}")
+            return {"summary": cached}

-        Returns:
-            Dict[str, str]: 包含分析结果的字典,格式为 {"summary": "分析结果"}
-        """
-        if self.disabled:
-            return {"summary": "❌ 视频分析功能已禁用:没有可用的视频处理实现"}
+        # 获取锁避免重复处理
+        async with _locks_guard:
+            lock = _video_locks.get(video_hash)
+            if lock is None:
+                lock = asyncio.Lock()
+                _video_locks[video_hash] = lock
+        async with lock:
+            # 双检缓存
+            cached2 = await self._get_cached(video_hash)
+            if cached2:
+                logger.info(f"视频缓存命中(锁后) hash={video_hash[:16]}")
+                return {"summary": cached2}

-        video_hash = None
-        video_event = None
-
-        try:
-            logger.info("开始从字节数据分析视频")
-
-            # 兼容性处理:如果传入了prompt参数,使用prompt;否则使用user_question
-            question = prompt if prompt is not None else user_question
-
-            # 检查视频数据是否有效
-            if not video_bytes:
-                return {"summary": "❌ 视频数据为空"}
-
-            # 计算视频hash值
-            video_hash = self._calculate_video_hash(video_bytes)
-            logger.info(f"视频hash: {video_hash}")
-
-            # 改进的并发控制:使用每个视频独立的锁和事件
-            async with video_lock_manager:
-                if video_hash not in video_locks:
-                    video_locks[video_hash] = asyncio.Lock()
-                    video_events[video_hash] = asyncio.Event()
-
-                video_lock = video_locks[video_hash]
-                video_event = video_events[video_hash]
-
-                # 尝试获取该视频的专用锁
-                if video_lock.locked():
-                    logger.info(f"⏳ 相同视频正在处理中,等待处理完成... (hash: {video_hash[:16]}...)")
-                    try:
-                        # 等待处理完成的事件信号,最多等待60秒
-                        await asyncio.wait_for(video_event.wait(), timeout=60.0)
-                        logger.info("✅ 等待结束,检查是否有处理结果")
-
-                        # 检查是否有结果了
-                        existing_video = await self._check_video_exists(video_hash)
-                        if existing_video:
-                            logger.info(f"✅ 找到了处理结果,直接返回 (id: {existing_video.id})")
-                            return {"summary": existing_video.description}
-                        else:
-                            logger.warning("⚠️ 等待完成但未找到结果,可能处理失败")
-                    except asyncio.TimeoutError:
-                        logger.warning("⚠️ 等待超时(60秒),放弃等待")
-
-            # 获取锁开始处理
-            async with video_lock:
-                logger.info(f"🔒 获得视频处理锁,开始处理 (hash: {video_hash[:16]}...)")
-
-                # 再次检查数据库(可能在等待期间已经有结果了)
-                existing_video = await self._check_video_exists(video_hash)
-                if existing_video:
-                    logger.info(f"✅ 获得锁后发现已有结果,直接返回 (id: {existing_video.id})")
-                    video_event.set()  # 通知其他等待者
-                    return {"summary": existing_video.description}
-
-                # 未找到已存在记录,开始新的分析
-                logger.info("未找到已存在的视频记录,开始新的分析")
-
-                # 创建临时文件进行分析
-                with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
-                    temp_file.write(video_bytes)
-                    temp_path = temp_file.name
-
-                try:
-                    # 检查临时文件是否创建成功
-                    if not os.path.exists(temp_path):
-                        video_event.set()  # 通知等待者
-                        return {"summary": "❌ 临时文件创建失败"}
-
-                    # 使用临时文件进行分析
-                    success, result = await self.analyze_video(temp_path, question)
-
-                finally:
-                    # 清理临时文件
-                    if os.path.exists(temp_path):
-                        os.unlink(temp_path)
-
-                # 保存分析结果到数据库(仅保存成功的结果)
-                if success and not result.startswith("❌"):
-                    metadata = {"filename": filename, "file_size": len(video_bytes), "analysis_timestamp": time.time()}
-                    await self._store_video_result(video_hash=video_hash, description=result, metadata=metadata)
-                    logger.info("✅ 分析结果已保存到数据库")
-                else:
-                    logger.warning("⚠️ 分析失败,不保存到数据库以便后续重试")
-
-                # 处理完成,通知等待者并清理资源
-                video_event.set()
-                async with video_lock_manager:
-                    # 清理资源
-                    video_locks.pop(video_hash, None)
-                    video_events.pop(video_hash, None)
-
-                return {"summary": result}
-
-        except Exception as e:
-            error_msg = f"❌ 从字节数据分析视频失败: {e!s}"
-            logger.error(error_msg)
-
-            # 不保存错误信息到数据库,允许后续重试
-            logger.info("💡 错误信息不保存到数据库,允许后续重试")
-
-            # 处理失败,通知等待者并清理资源
             try:
-                if video_hash and video_event:
-                    async with video_lock_manager:
-                        if video_hash in video_events:
-                            video_events[video_hash].set()
-                        video_locks.pop(video_hash, None)
-                        video_events.pop(video_hash, None)
-            except Exception as cleanup_e:
-                logger.error(f"❌ 清理锁资源失败: {cleanup_e}")
-
-            return {"summary": error_msg}
-
-    def is_supported_video(self, file_path: str) -> bool:
-        """检查是否为支持的视频格式"""
-        supported_formats = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"}
-        return Path(file_path).suffix.lower() in supported_formats
-
-    def get_processing_capabilities(self) -> dict[str, any]:
-        """获取处理能力信息"""
-        if not RUST_VIDEO_AVAILABLE:
-            return {"error": "Rust视频处理模块不可用", "available": False, "reason": "rust_video模块未安装或加载失败"}
+                with tempfile.NamedTemporaryFile(delete=False) as fp:
+                    fp.write(video_bytes)
+                    temp_path = fp.name
+                try:
+                    ok, summary = await self.analyze_video(temp_path, q)
+                    # 写入缓存(仅成功)
+                    if ok:
+                        await self._save_cache(video_hash, summary, len(video_bytes))
+                    return {"summary": summary}
+                finally:
+                    if os.path.exists(temp_path):
+                        try:
+                            os.remove(temp_path)
+                        except Exception:  # pragma: no cover
+                            pass
+            except Exception as e:  # pragma: no cover
+                return {"summary": f"❌ 处理失败: {e}"}

+    # ---- 缓存辅助 ----
+    async def _get_cached(self, video_hash: str) -> str | None:
         try:
-            system_info = rust_video.get_system_info()
+            async with get_db_session() as session:  # type: ignore
+                result = await session.execute(select(Videos).where(Videos.video_hash == video_hash))  # type: ignore
+                obj: Videos | None = result.scalar_one_or_none()  # type: ignore
+                if obj and obj.vlm_processed and obj.description:
+                    # 更新使用次数
+                    try:
+                        await session.execute(
+                            update(Videos)
+                            .where(Videos.id == obj.id)  # type: ignore
+                            .values(count=obj.count + 1 if obj.count is not None else 1)
+                        )
+                        await session.commit()
+                    except Exception:  # pragma: no cover
+                        await session.rollback()
+                    return obj.description
+        except Exception:  # pragma: no cover
+            pass
+        return None

-            # 创建一个临时的extractor来获取CPU特性
-            extractor = rust_video.VideoKeyframeExtractor(threads=0, verbose=False)
-            cpu_features = extractor.get_cpu_features()
-
-            capabilities = {
-                "system": {
-                    "threads": system_info.get("threads", 0),
-                    "rust_version": system_info.get("version", "unknown"),
-                },
-                "cpu_features": cpu_features,
-                "recommended_settings": self._get_recommended_settings(cpu_features),
-                "analysis_modes": ["auto", "batch", "sequential"],
-                "supported_formats": [".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".m4v", ".3gp", ".webm"],
-                "available": True,
-            }
-
-            return capabilities
-
-        except Exception as e:
-            logger.error(f"获取处理能力信息失败: {e}")
-            return {"error": str(e), "available": False}
-
-    def _get_recommended_settings(self, cpu_features: dict[str, bool]) -> dict[str, any]:
-        """根据CPU特性推荐最佳设置"""
-        settings = {
-            "use_simd": any(cpu_features.values()),
-            "block_size": 8192,
-            "threads": 0,  # 自动检测
-        }
-
-        # 根据CPU特性调整设置
-        if cpu_features.get("avx2", False):
-            settings["block_size"] = 16384  # AVX2支持更大的块
-            settings["optimization_level"] = "avx2"
-        elif cpu_features.get("sse2", False):
-            settings["block_size"] = 8192
-            settings["optimization_level"] = "sse2"
-        else:
-            settings["use_simd"] = False
-            settings["block_size"] = 4096
-            settings["optimization_level"] = "scalar"
-
-        return settings
+    async def _save_cache(self, video_hash: str, summary: str, file_size: int) -> None:
+        try:
+            async with get_db_session() as session:  # type: ignore
+                stmt = insert(Videos).values(  # type: ignore
+                    video_id="",
+                    video_hash=video_hash,
+                    description=summary,
+                    count=1,
+                    timestamp=time.time(),
+                    vlm_processed=True,
+                    duration=None,
+                    frame_count=None,
+                    fps=None,
+                    resolution=None,
+                    file_size=file_size,
+                )
+                try:
+                    await session.execute(stmt)
+                    await session.commit()
+                    logger.debug(f"视频缓存写入 success hash={video_hash}")
+                except sa_exc.IntegrityError:  # 可能并发已写入
+                    await session.rollback()
+                    logger.debug(f"视频缓存已存在 hash={video_hash}")
+        except Exception:  # pragma: no cover
+            logger.debug("视频缓存写入失败")


-# 全局实例
-_video_analyzer = None
+# ---- 外部接口 ----
+_INSTANCE: VideoAnalyzer | None = None


 def get_video_analyzer() -> VideoAnalyzer:
-    """获取视频分析器实例(单例模式)"""
-    global _video_analyzer
-    if _video_analyzer is None:
-        _video_analyzer = VideoAnalyzer()
-    return _video_analyzer
+    global _INSTANCE
+    if _INSTANCE is None:
+        _INSTANCE = VideoAnalyzer()
+    return _INSTANCE


 def is_video_analysis_available() -> bool:
-    """检查视频分析功能是否可用
+    return True

-    Returns:
-        bool: 如果有任何可用的视频处理实现则返回True
-    """
-    # 现在即使Rust模块不可用,也可以使用Python降级实现
+
+def get_video_analysis_status() -> dict[str, Any]:
     try:
-        import cv2
-
-        return True
-    except ImportError:
-        return False
-
-
-def get_video_analysis_status() -> dict[str, any]:
-    """获取视频分析功能的详细状态信息
-
-    Returns:
-        Dict[str, any]: 包含功能状态信息的字典
-    """
-    # 检查OpenCV是否可用
-    opencv_available = False
-    try:
-        import cv2
-
-        opencv_available = True
-    except ImportError:
-        pass
-
-    status = {
-        "available": opencv_available or RUST_VIDEO_AVAILABLE,
-        "implementations": {
-            "rust_keyframe": {
-                "available": RUST_VIDEO_AVAILABLE,
-                "description": "Rust智能关键帧提取",
-                "supported_modes": ["keyframe"],
-            },
-            "python_legacy": {
-                "available": opencv_available,
-                "description": "Python传统抽帧方法",
-                "supported_modes": ["fixed_number", "time_interval"],
-            },
-        },
-        "supported_modes": [],
+        info = video.get_system_info()  # type: ignore[attr-defined]
+    except Exception as e:  # pragma: no cover
+        return {"available": False, "error": str(e)}
+    inst = get_video_analyzer()
+    return {
+        "available": True,
+        "system": info,
+        "modes": ["auto", "batch", "sequential"],
+        "max_frames_default": inst.max_frames,
+        "implementation": "inkfox",
     }
-
-    # 汇总支持的模式
-    if RUST_VIDEO_AVAILABLE:
-        status["supported_modes"].extend(["keyframe"])
-    if opencv_available:
-        status["supported_modes"].extend(["fixed_number", "time_interval"])
-
-    if not status["available"]:
-        status.update({"error": "没有可用的视频处理实现", "solution": "请安装opencv-python或rust_video模块"})
-
-    return status
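The rewritten analyze_video_from_bytes above deduplicates concurrent requests for the same video with a per-hash asyncio.Lock and a second cache lookup after the lock is acquired. A minimal, self-contained sketch of that double-checked pattern; the names _cache and compute are illustrative placeholders, not from the patch:

import asyncio
import hashlib

_cache: dict[str, str] = {}
_locks: dict[str, asyncio.Lock] = {}
_locks_guard = asyncio.Lock()

async def analyze_once(data: bytes) -> str:
    key = hashlib.sha256(data).hexdigest()
    if (hit := _cache.get(key)) is not None:  # fast path, no lock taken
        return hit
    async with _locks_guard:  # creating the per-key lock must itself be serialized
        lock = _locks.setdefault(key, asyncio.Lock())
    async with lock:
        if (hit := _cache.get(key)) is not None:  # double-check: a waiter may find the result ready
            return hit
        result = await compute(data)  # stand-in for the real extraction + LLM call
        _cache[key] = result
        return result

async def compute(data: bytes) -> str:
    # placeholder for the expensive analysis
    await asyncio.sleep(0.1)
    return f"analyzed {len(data)} bytes"

Without the second lookup, every coroutine queued behind the lock would redo the analysis after the first one finishes; with it, only the first entrant pays.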
diff --git a/src/chat/utils/utils_video_legacy.py b/src/chat/utils/utils_video_legacy.py
index 46eb13857..7f5d0a35b 100644
--- a/src/chat/utils/utils_video_legacy.py
+++ b/src/chat/utils/utils_video_legacy.py
@@ -461,14 +461,11 @@ class LegacyVideoAnalyzer:
         # logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")

         # 获取模型信息和客户端
-        selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
-        if not selection_result:
-            raise RuntimeError("无法为视频分析选择可用模型 (legacy)。")
-        model_info, api_provider, client = selection_result
+        model_info, api_provider, client = self.video_llm._select_model()
         # logger.info(f"使用模型: {model_info.name} 进行多帧分析")

         # 直接执行多图片请求
-        api_response = await self.video_llm._executor.execute_request(
+        api_response = await self.video_llm._execute_request(
             api_provider=api_provider,
             client=client,
             request_type=RequestType.RESPONSE,
diff --git a/src/common/database/connection_pool_manager.py b/src/common/database/connection_pool_manager.py
index 622e02820..7bd1e5684 100644
--- a/src/common/database/connection_pool_manager.py
+++ b/src/common/database/connection_pool_manager.py
@@ -5,9 +5,8 @@
 import asyncio
 import time
-import weakref
 from contextlib import asynccontextmanager
-from typing import Any, Dict, Optional, Set
+from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker

@@ -69,7 +68,7 @@ class ConnectionPoolManager:
         self.max_idle = max_idle

         # 连接池
-        self._connections: Set[ConnectionInfo] = set()
+        self._connections: set[ConnectionInfo] = set()
         self._lock = asyncio.Lock()

         # 统计信息
@@ -83,7 +82,7 @@ class ConnectionPoolManager:
         }

         # 后台清理任务
-        self._cleanup_task: Optional[asyncio.Task] = None
+        self._cleanup_task: asyncio.Task | None = None
         self._should_cleanup = False

         logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
@@ -144,7 +143,7 @@ class ConnectionPoolManager:

             yield connection_info.session

-        except Exception as e:
+        except Exception:
             # 发生错误时回滚连接
             if connection_info and connection_info.session:
                 try:
@@ -157,7 +156,7 @@ class ConnectionPoolManager:
             if connection_info:
                 connection_info.mark_released()

-    async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
+    async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
         """获取可复用的连接"""
         async with self._lock:
             # 清理过期连接
@@ -231,7 +230,7 @@ class ConnectionPoolManager:
             self._connections.clear()
             logger.info("所有连接已关闭")

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         """获取连接池统计信息"""
         return {
             **self._stats,
@@ -244,7 +243,7 @@ class ConnectionPoolManager:


 # 全局连接池管理器实例
-_connection_pool_manager: Optional[ConnectionPoolManager] = None
+_connection_pool_manager: ConnectionPoolManager | None = None


 def get_connection_pool_manager() -> ConnectionPoolManager:
@@ -266,4 +265,4 @@ async def stop_connection_pool():
    global _connection_pool_manager
    if _connection_pool_manager:
        await _connection_pool_manager.stop()
-        _connection_pool_manager = None
\ No newline at end of file
+        _connection_pool_manager = None
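The connection_pool_manager.py hunks above are typing cleanups around an unchanged reuse/rollback/release flow. A condensed sketch of that flow; ConnectionInfo here is a simplified stand-in for the real class, and the session object is duck-typed:

import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any

@dataclass(eq=False)  # identity hash, so instances can live in a set
class ConnectionInfo:
    session: Any
    in_use: bool = False

    def mark_released(self) -> None:
        self.in_use = False

class TinyPool:
    def __init__(self) -> None:
        self._connections: set[ConnectionInfo] = set()
        self._lock = asyncio.Lock()

    async def _get_reusable_connection(self, session_factory) -> ConnectionInfo:
        async with self._lock:
            for info in self._connections:
                if not info.in_use:  # reuse an idle slot when possible
                    info.in_use = True
                    return info
            info = ConnectionInfo(session=session_factory(), in_use=True)
            self._connections.add(info)
            return info

    @asynccontextmanager
    async def get_session(self, session_factory):
        info = await self._get_reusable_connection(session_factory)
        try:
            yield info.session
        except Exception:
            await info.session.rollback()  # error path: roll back before re-raising
            raise
        finally:
            info.mark_released()  # success or failure, the slot returns to the pool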
diff --git a/src/common/database/database.py b/src/common/database/database.py
index 8cca5dda3..894ec5f2b 100644
--- a/src/common/database/database.py
+++ b/src/common/database/database.py
@@ -2,15 +2,16 @@
 import os

 from rich.traceback import install

+from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
+
+# 数据库批量调度器和连接池
+from src.common.database.db_batch_scheduler import get_db_batch_scheduler
+
 # SQLAlchemy相关导入
 from src.common.database.sqlalchemy_init import initialize_database_compat
 from src.common.database.sqlalchemy_models import get_db_session, get_engine
 from src.common.logger import get_logger

-# 数据库批量调度器和连接池
-from src.common.database.db_batch_scheduler import get_db_batch_scheduler
-from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
-
 install(extra_lines=3)

 _sql_engine = None
diff --git a/src/common/database/db_batch_scheduler.py b/src/common/database/db_batch_scheduler.py
index 4a3f18936..428ccaf4a 100644
--- a/src/common/database/db_batch_scheduler.py
+++ b/src/common/database/db_batch_scheduler.py
@@ -6,19 +6,19 @@
 import asyncio
 import time
 from collections import defaultdict, deque
-from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
+from collections.abc import Callable
 from contextlib import asynccontextmanager
+from dataclasses import dataclass
+from typing import Any, TypeVar

-from sqlalchemy import select, delete, insert, update
-from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy import delete, insert, select, update

 from src.common.database.sqlalchemy_database_api import get_db_session
 from src.common.logger import get_logger

 logger = get_logger("db_batch_scheduler")

-T = TypeVar('T')
+T = TypeVar("T")


 @dataclass
@@ -26,10 +26,10 @@ class BatchOperation:
     """批量操作基础类"""
     operation_type: str  # 'select', 'insert', 'update', 'delete'
     model_class: Any
-    conditions: Dict[str, Any]
-    data: Optional[Dict[str, Any]] = None
-    callback: Optional[Callable] = None
-    future: Optional[asyncio.Future] = None
+    conditions: dict[str, Any]
+    data: dict[str, Any] | None = None
+    callback: Callable | None = None
+    future: asyncio.Future | None = None
     timestamp: float = 0.0

     def __post_init__(self):
@@ -42,7 +42,7 @@ class BatchResult:
     """批量操作结果"""
     success: bool
     data: Any = None
-    error: Optional[str] = None
+    error: str | None = None


 class DatabaseBatchScheduler:
@@ -57,23 +57,23 @@ class DatabaseBatchScheduler:
         self.max_queue_size = max_queue_size

         # 操作队列,按操作类型和模型分类
-        self.operation_queues: Dict[str, deque] = defaultdict(deque)
+        self.operation_queues: dict[str, deque] = defaultdict(deque)

         # 调度控制
-        self._scheduler_task: Optional[asyncio.Task] = None
+        self._scheduler_task: asyncio.Task | None = None
         self._is_running = bool = False
         self._lock = asyncio.Lock()

         # 统计信息
         self.stats = {
-            'total_operations': 0,
-            'batched_operations': 0,
-            'cache_hits': 0,
-            'execution_time': 0.0
+            "total_operations": 0,
+            "batched_operations": 0,
+            "cache_hits": 0,
+            "execution_time": 0.0
         }

         # 简单的结果缓存(用于频繁的查询)
-        self._result_cache: Dict[str, Tuple[Any, float]] = {}
+        self._result_cache: dict[str, tuple[Any, float]] = {}
         self._cache_ttl = 5.0  # 5秒缓存

     async def start(self):
@@ -102,7 +102,7 @@ class DatabaseBatchScheduler:
         await self._flush_all_queues()
         logger.info("数据库批量调度器已停止")

-    def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str:
+    def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
         """生成缓存键"""
         # 简单的缓存键生成,实际可以根据需要优化
         key_parts = [
@@ -112,12 +112,12 @@ class DatabaseBatchScheduler:
         ]
         return "|".join(key_parts)

-    def _get_from_cache(self, cache_key: str) -> Optional[Any]:
+    def _get_from_cache(self, cache_key: str) -> Any | None:
         """从缓存获取结果"""
         if cache_key in self._result_cache:
             result, timestamp = self._result_cache[cache_key]
             if time.time() - timestamp < self._cache_ttl:
-                self.stats['cache_hits'] += 1
+                self.stats["cache_hits"] += 1
                 return result
             else:
                 # 清理过期缓存
@@ -131,7 +131,7 @@ class DatabaseBatchScheduler:
     async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
         """添加操作到队列"""
         # 检查是否可以立即返回缓存结果
-        if operation.operation_type == 'select':
+        if operation.operation_type == "select":
             cache_key = self._generate_cache_key(
                 operation.operation_type,
                 operation.model_class,
@@ -158,7 +158,7 @@ class DatabaseBatchScheduler:
                 await self._execute_operations([operation])
             else:
                 self.operation_queues[queue_key].append(operation)
-                self.stats['total_operations'] += 1
+                self.stats["total_operations"] += 1

         return future

@@ -193,7 +193,7 @@ class DatabaseBatchScheduler:
             if operations:
                 await self._execute_operations(list(operations))

-    async def _execute_operations(self, operations: List[BatchOperation]):
+    async def _execute_operations(self, operations: list[BatchOperation]):
         """执行批量操作"""
         if not operations:
             return
@@ -209,13 +209,13 @@ class DatabaseBatchScheduler:
         # 为每种操作类型创建批量执行任务
         tasks = []
         for op_type, ops in op_groups.items():
-            if op_type == 'select':
+            if op_type == "select":
                 tasks.append(self._execute_select_batch(ops))
-            elif op_type == 'insert':
+            elif op_type == "insert":
                 tasks.append(self._execute_insert_batch(ops))
-            elif op_type == 'update':
+            elif op_type == "update":
                 tasks.append(self._execute_update_batch(ops))
-            elif op_type == 'delete':
+            elif op_type == "delete":
                 tasks.append(self._execute_delete_batch(ops))

         # 并发执行所有操作
@@ -238,7 +238,7 @@ class DatabaseBatchScheduler:
                     operation.future.set_result(result)

                     # 缓存查询结果
-                    if operation.operation_type == 'select':
+                    if operation.operation_type == "select":
                         cache_key = self._generate_cache_key(
                             operation.operation_type,
                             operation.model_class,
@@ -246,7 +246,7 @@ class DatabaseBatchScheduler:
                         )
                         self._set_cache(cache_key, result)

-            self.stats['batched_operations'] += len(operations)
+            self.stats["batched_operations"] += len(operations)

         except Exception as e:
             logger.error(f"批量操作执行失败: {e}", exc_info="")
@@ -255,9 +255,9 @@ class DatabaseBatchScheduler:
                 if operation.future and not operation.future.done():
                     operation.future.set_exception(e)
         finally:
-            self.stats['execution_time'] += time.time() - start_time
+            self.stats["execution_time"] += time.time() - start_time

-    async def _execute_select_batch(self, operations: List[BatchOperation]):
+    async def _execute_select_batch(self, operations: list[BatchOperation]):
         """批量执行查询操作"""
         # 合并相似的查询条件
         merged_conditions = self._merge_select_conditions(operations)
@@ -302,7 +302,7 @@ class DatabaseBatchScheduler:

         return results if len(results) > 1 else results[0] if results else []

-    async def _execute_insert_batch(self, operations: List[BatchOperation]):
+    async def _execute_insert_batch(self, operations: list[BatchOperation]):
         """批量执行插入操作"""
         async with get_db_session() as session:
             try:
@@ -323,7 +323,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量插入失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    async def _execute_update_batch(self, operations: List[BatchOperation]):
+    async def _execute_update_batch(self, operations: list[BatchOperation]):
         """批量执行更新操作"""
         async with get_db_session() as session:
             try:
@@ -353,7 +353,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量更新失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    async def _execute_delete_batch(self, operations: List[BatchOperation]):
+    async def _execute_delete_batch(self, operations: list[BatchOperation]):
         """批量执行删除操作"""
         async with get_db_session() as session:
             try:
@@ -382,7 +382,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量删除失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]:
+    def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]:
         """合并相似的查询条件"""
         merged = {}

@@ -405,15 +405,15 @@ class DatabaseBatchScheduler:

             # 记录操作
             if condition_key not in merged:
-                merged[condition_key] = {'_operations': []}
-            if '_operations' not in merged[condition_key]:
-                merged[condition_key]['_operations'] = []
-            merged[condition_key]['_operations'].append(op)
+                merged[condition_key] = {"_operations": []}
+            if "_operations" not in merged[condition_key]:
+                merged[condition_key]["_operations"] = []
+            merged[condition_key]["_operations"].append(op)

         # 去重并构建最终条件
         final_merged = {}
         for condition_key, conditions in merged.items():
-            operations = conditions.pop('_operations')
+            operations = conditions.pop("_operations")
             # 去重
             for field_name, values in conditions.items():
@@ -423,13 +423,13 @@ class DatabaseBatchScheduler:

         return final_merged

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         """获取统计信息"""
         return {
             **self.stats,
-            'cache_size': len(self._result_cache),
-            'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
-            'is_running': self._is_running
+            "cache_size": len(self._result_cache),
+            "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
+            "is_running": self._is_running
         }


@@ -450,20 +450,20 @@ async def get_batch_session():


 # 便捷函数
-async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
+async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
     """批量查询"""
     operation = BatchOperation(
-        operation_type='select',
+        operation_type="select",
         model_class=model_class,
         conditions=conditions
     )
     return await db_batch_scheduler.add_operation(operation)


-async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
+async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
     """批量插入"""
     operation = BatchOperation(
-        operation_type='insert',
+        operation_type="insert",
         model_class=model_class,
         conditions={},
         data=data
@@ -471,10 +471,10 @@ async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
     return await db_batch_scheduler.add_operation(operation)


-async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
+async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
     """批量更新"""
     operation = BatchOperation(
-        operation_type='update',
+        operation_type="update",
         model_class=model_class,
         conditions=conditions,
         data=data
@@ -482,10 +482,10 @@ async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[
     return await db_batch_scheduler.add_operation(operation)


-async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
+async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
     """批量删除"""
     operation = BatchOperation(
-        operation_type='delete',
+        operation_type="delete",
         model_class=model_class,
         conditions=conditions
     )
@@ -494,4 +494,4 @@ async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:

 def get_db_batch_scheduler() -> DatabaseBatchScheduler:
     """获取数据库批量调度器实例"""
-    return db_batch_scheduler
\ No newline at end of file
+    return db_batch_scheduler
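The convenience wrappers at the bottom of db_batch_scheduler.py queue a BatchOperation and await the future that the scheduler resolves once the batch executes, so call sites stay one-liners. A hypothetical call site; the Videos model is borrowed from the video-cache hunks earlier in this patch, and the exact fields written are illustrative:

from src.common.database.db_batch_scheduler import batch_insert, batch_select
from src.common.database.sqlalchemy_models import Videos

async def lookup_video(video_hash: str):
    # select operations may be answered from the scheduler's 5-second result cache
    return await batch_select(Videos, {"video_hash": video_hash})

async def store_video(video_hash: str, description: str) -> int:
    # inserts are grouped with other pending inserts and flushed together
    return await batch_insert(Videos, {"video_hash": video_hash, "description": description})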
Path("logs") LOG_DIR.mkdir(exist_ok=True) # 全局handler实例,避免重复创建(可能为None表示禁用文件日志) -_file_handler: Optional[logging.Handler] = None -_console_handler: Optional[logging.Handler] = None +_file_handler: logging.Handler | None = None +_console_handler: logging.Handler | None = None # 动态 logger 元数据注册表 (name -> {alias:str|None, color:str|None}) _LOGGER_META_LOCK = threading.Lock() -_LOGGER_META: Dict[str, Dict[str, Optional[str]]] = {} +_LOGGER_META: dict[str, dict[str, str | None]] = {} -def _normalize_color(color: Optional[str]) -> Optional[str]: +def _normalize_color(color: str | None) -> str | None: """接受 ANSI 码 / #RRGGBB / rgb(r,g,b) / 颜色名(直接返回) -> ANSI 码. 不做复杂解析,只支持 #RRGGBB 转 24bit ANSI。 """ @@ -49,13 +49,13 @@ def _normalize_color(color: Optional[str]) -> Optional[str]: nums = color[color.find("(") + 1 : -1].split(",") r, g, b = (int(x) for x in nums[:3]) return f"\033[38;2;{r};{g};{b}m" - except Exception: # noqa: BLE001 + except Exception: return None # 其他情况直接返回,假设是短ANSI或名称(控制台渲染器不做翻译,仅输出) return color -def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Optional[str] = None): +def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None): """注册/更新 logger 元数据。""" if not name: return @@ -67,7 +67,7 @@ def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Opti meta["color"] = _normalize_color(color) -def get_logger_meta(name: str) -> Dict[str, Optional[str]]: +def get_logger_meta(name: str) -> dict[str, str | None]: with _LOGGER_META_LOCK: return _LOGGER_META.get(name, {"alias": None, "color": None}).copy() @@ -170,7 +170,7 @@ class TimestampedFileHandler(logging.Handler): try: self._compress_stale_logs() self._cleanup_old_files() - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志轮转] 轮转过程出错: {e}") def _compress_stale_logs(self): # sourcery skip: extract-method @@ -184,12 +184,12 @@ class TimestampedFileHandler(logging.Handler): continue # 压缩 try: - with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117 + with tarfile.open(tar_path, "w:gz") as tf: tf.add(f, arcname=f.name) f.unlink(missing_ok=True) - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志压缩] 压缩 {f.name} 失败: {e}") - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志压缩] 过程出错: {e}") def _cleanup_old_files(self): @@ -206,9 +206,9 @@ class TimestampedFileHandler(logging.Handler): mtime = datetime.fromtimestamp(f.stat().st_mtime) if mtime < cutoff: f.unlink(missing_ok=True) - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志清理] 删除 {f} 失败: {e}") - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志清理] 清理过程出错: {e}") def emit(self, record): @@ -850,7 +850,7 @@ class ModuleColoredConsoleRenderer: if logger_name: # 获取别名,如果没有别名则使用原名称 # 若上面条件不成立需要再次获取 meta - if 'meta' not in locals(): + if "meta" not in locals(): meta = get_logger_meta(logger_name) display_name = meta.get("alias") or DEFAULT_MODULE_ALIASES.get(logger_name, logger_name) @@ -1066,7 +1066,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger() binds: dict[str, Callable] = {} -def get_logger(name: str | None, *, color: Optional[str] = None, alias: Optional[str] = None) -> structlog.stdlib.BoundLogger: +def get_logger(name: str | None, *, color: str | None = None, alias: str | None = None) -> structlog.stdlib.BoundLogger: """获取/创建 structlog logger。 新增: @@ -1132,10 +1132,10 @@ def cleanup_old_logs(): tar_path = f.with_suffix(f.suffix + ".tar.gz") if tar_path.exists(): 
continue - with tarfile.open(tar_path, "w:gz") as tf: # noqa: SIM117 + with tarfile.open(tar_path, "w:gz") as tf: tf.add(f, arcname=f.name) f.unlink(missing_ok=True) - except Exception as e: # noqa: BLE001 + except Exception as e: logger = get_logger("logger") logger.warning(f"周期压缩日志时出错: {e}") @@ -1152,7 +1152,7 @@ def cleanup_old_logs(): log_file.unlink(missing_ok=True) deleted_count += 1 deleted_size += size - except Exception as e: # noqa: BLE001 + except Exception as e: logger = get_logger("logger") logger.warning(f"清理日志文件 {log_file} 时出错: {e}") if deleted_count: @@ -1160,7 +1160,7 @@ def cleanup_old_logs(): logger.info( f"清理 {deleted_count} 个过期日志 (≈{deleted_size / 1024 / 1024:.2f}MB), 保留策略={retention_days}天" ) - except Exception as e: # noqa: BLE001 + except Exception as e: logger = get_logger("logger") logger.error(f"清理旧日志文件时出错: {e}") @@ -1183,7 +1183,7 @@ def start_log_cleanup_task(): while True: try: cleanup_old_logs() - except Exception as e: # noqa: BLE001 + except Exception as e: print(f"[日志任务] 执行清理出错: {e}") # 再次等待到下一个午夜 time.sleep(max(1, seconds_until_next_midnight())) diff --git a/src/main.py b/src/main.py index bce9b7d6c..d3a2bc387 100644 --- a/src/main.py +++ b/src/main.py @@ -120,10 +120,10 @@ class MainSystem: logger.warning("未发现任何兴趣计算器组件") return - logger.info(f"发现的兴趣计算器组件:") + logger.info("发现的兴趣计算器组件:") for calc_name, calc_info in interest_calculators.items(): - enabled = getattr(calc_info, 'enabled', True) - default_enabled = getattr(calc_info, 'enabled_by_default', True) + enabled = getattr(calc_info, "enabled", True) + default_enabled = getattr(calc_info, "enabled_by_default", True) logger.info(f" - {calc_name}: 启用: {enabled}, 默认启用: {default_enabled}") # 初始化兴趣度管理器 @@ -136,8 +136,8 @@ class MainSystem: # 使用组件注册表获取组件类并注册 for calc_name, calc_info in interest_calculators.items(): - enabled = getattr(calc_info, 'enabled', True) - default_enabled = getattr(calc_info, 'enabled_by_default', True) + enabled = getattr(calc_info, "enabled", True) + default_enabled = getattr(calc_info, "enabled_by_default", True) if not enabled or not default_enabled: logger.info(f"兴趣计算器 {calc_name} 未启用,跳过") @@ -183,7 +183,7 @@ class MainSystem: async def _async_cleanup(self): """异步清理资源""" try: - + # 停止数据库服务 try: from src.common.database.database import stop_database @@ -343,8 +343,8 @@ MoFox_Bot(第三方修改版) # 初始化表情管理器 get_emoji_manager().initialize() logger.info("表情包管理器初始化成功") - - ''' + + """ # 初始化回复后关系追踪系统 try: from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system @@ -356,8 +356,8 @@ MoFox_Bot(第三方修改版) except Exception as e: logger.error(f"回复后关系追踪系统初始化失败: {e}") relationship_tracker = None - ''' - + """ + # 启动情绪管理器 await mood_manager.start() logger.info("情绪管理器初始化成功") @@ -487,10 +487,10 @@ MoFox_Bot(第三方修改版) # 关闭应用 (MessageServer可能没有shutdown方法) try: if self.app: - if hasattr(self.app, 'shutdown'): + if hasattr(self.app, "shutdown"): await self.app.shutdown() logger.info("应用已关闭") - elif hasattr(self.app, 'stop'): + elif hasattr(self.app, "stop"): await self.app.stop() logger.info("应用已停止") else: diff --git a/src/mood/mood_manager.py b/src/mood/mood_manager.py index c0516f1db..7974dba4d 100644 --- a/src/mood/mood_manager.py +++ b/src/mood/mood_manager.py @@ -2,7 +2,6 @@ import math import random import time -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.message import MessageRecv from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive from 
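For reference, the get_logger signature preserved in the logger.py hunks above takes keyword-only color and alias parameters. A small usage sketch; the module name and values are arbitrary examples, not from the patch:

from src.common.logger import get_logger

# "#RRGGBB" values are converted to 24-bit ANSI escapes by _normalize_color
logger = get_logger("video_analysis", alias="视频分析", color="#00BFFF")
logger.info("analyzer ready")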
diff --git a/src/person_info/relationship_builder.py b/src/person_info/relationship_builder.py
index 05c8ab25e..554af3260 100644
--- a/src/person_info/relationship_builder.py
+++ b/src/person_info/relationship_builder.py
@@ -5,7 +5,6 @@
 import time
 import traceback
 from typing import Any

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.chat_message_builder import (
     get_raw_msg_before_timestamp_with_chat,
     get_raw_msg_by_timestamp_with_chat,
diff --git a/src/person_info/relationship_fetcher.py b/src/person_info/relationship_fetcher.py
index 344c8428c..55954835b 100644
--- a/src/person_info/relationship_fetcher.py
+++ b/src/person_info/relationship_fetcher.py
@@ -5,7 +5,6 @@
 from typing import Any

 import orjson
 from json_repair import repair_json

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config
diff --git a/src/person_info/relationship_manager.py b/src/person_info/relationship_manager.py
index 23752c365..89673581c 100644
--- a/src/person_info/relationship_manager.py
+++ b/src/person_info/relationship_manager.py
@@ -4,8 +4,8 @@
 from datetime import datetime
 from difflib import SequenceMatcher
 from typing import Any

-import rjieba
 import orjson
+import rjieba
 from json_repair import repair_json
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity
diff --git a/src/plugin_system/__init__.py b/src/plugin_system/__init__.py
index 0f2509116..aecd91f2c 100644
--- a/src/plugin_system/__init__.py
+++ b/src/plugin_system/__init__.py
@@ -49,7 +49,6 @@ from .base import (
     ToolParamType,
     create_plus_command_adapter,
 )
-
 from .utils.dependency_config import configure_dependency_settings, get_dependency_config

 # 导入依赖管理模块
diff --git a/src/plugin_system/base/base_interest_calculator.py b/src/plugin_system/base/base_interest_calculator.py
index 7ee4d9004..f12db90fa 100644
--- a/src/plugin_system/base/base_interest_calculator.py
+++ b/src/plugin_system/base/base_interest_calculator.py
@@ -113,7 +113,7 @@ class BaseInterestCalculator(ABC):
         try:
             self._enabled = True
             return True
-        except Exception as e:
+        except Exception:
             self._enabled = False
             return False

@@ -170,7 +170,7 @@ class BaseInterestCalculator(ABC):
         if not self._enabled:
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
                 error_message="组件未启用"
             )
@@ -184,9 +184,9 @@ class BaseInterestCalculator(ABC):
         except Exception as e:
             result = InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
-                error_message=f"计算执行失败: {str(e)}",
+                error_message=f"计算执行失败: {e!s}",
                 calculation_time=time.time() - start_time
             )
         self._update_statistics(result)
@@ -201,7 +201,7 @@ class BaseInterestCalculator(ABC):

         Returns:
             InterestCalculatorInfo: 生成的兴趣计算器信息对象
         """
-        name = getattr(cls, 'component_name', cls.__name__.lower().replace('calculator', ''))
+        name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))
         if "." in name:
             logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
             raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
@@ -209,12 +209,12 @@ class BaseInterestCalculator(ABC):
         return InterestCalculatorInfo(
             name=name,
             component_type=ComponentType.INTEREST_CALCULATOR,
-            description=getattr(cls, 'component_description', cls.__doc__ or "兴趣度计算器"),
-            enabled_by_default=getattr(cls, 'enabled_by_default', True),
+            description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"),
+            enabled_by_default=getattr(cls, "enabled_by_default", True),
         )

     def __repr__(self) -> str:
         return (f"{self.__class__.__name__}("
                 f"name={self.component_name}, "
                 f"version={self.component_version}, "
-                f"enabled={self._enabled})")
\ No newline at end of file
+                f"enabled={self._enabled})")
diff --git a/src/plugin_system/base/base_plugin.py b/src/plugin_system/base/base_plugin.py
index 6b28b8868..37c6e5ed5 100644
--- a/src/plugin_system/base/base_plugin.py
+++ b/src/plugin_system/base/base_plugin.py
@@ -43,21 +43,21 @@ class BasePlugin(PluginBase):
             对应类型的ComponentInfo对象
         """
         if component_type == ComponentType.COMMAND:
-            if hasattr(component_class, 'get_command_info'):
+            if hasattr(component_class, "get_command_info"):
                 return component_class.get_command_info()
             else:
                 logger.warning(f"Command类 {component_class.__name__} 缺少 get_command_info 方法")
                 return None

         elif component_type == ComponentType.ACTION:
-            if hasattr(component_class, 'get_action_info'):
+            if hasattr(component_class, "get_action_info"):
                 return component_class.get_action_info()
             else:
                 logger.warning(f"Action类 {component_class.__name__} 缺少 get_action_info 方法")
                 return None

         elif component_type == ComponentType.INTEREST_CALCULATOR:
-            if hasattr(component_class, 'get_interest_calculator_info'):
+            if hasattr(component_class, "get_interest_calculator_info"):
                 return component_class.get_interest_calculator_info()
             else:
                 logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")
diff --git a/src/plugin_system/base/plugin_metadata.py b/src/plugin_system/base/plugin_metadata.py
index d39a384d9..638cbb15e 100644
--- a/src/plugin_system/base/plugin_metadata.py
+++ b/src/plugin_system/base/plugin_metadata.py
@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Set
+from typing import Any
+

 @dataclass
 class PluginMetadata:
@@ -11,15 +12,15 @@ class PluginMetadata:
     usage: str  # 插件使用方法

     # 以下为可选字段,参考自 _manifest.json 和 NoneBot 设计
-    type: Optional[str] = None  # 插件类别: "library", "application"
+    type: str | None = None  # 插件类别: "library", "application"

     # 从原 _manifest.json 迁移的字段
     version: str = "1.0.0"  # 插件版本
     author: str = ""  # 作者名称
-    license: Optional[str] = None  # 开源协议
-    repository_url: Optional[str] = None  # 仓库地址
-    keywords: List[str] = field(default_factory=list)  # 关键词
-    categories: List[str] = field(default_factory=list)  # 分类
+    license: str | None = None  # 开源协议
+    repository_url: str | None = None  # 仓库地址
+    keywords: list[str] = field(default_factory=list)  # 关键词
+    categories: list[str] = field(default_factory=list)  # 分类

     # 扩展字段
-    extra: Dict[str, Any] = field(default_factory=dict)  # 其他任意信息
\ No newline at end of file
+    extra: dict[str, Any] = field(default_factory=dict)  # 其他任意信息
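plugin_metadata.py above defines the dataclass behind the __plugin_meta__ blocks that several plugin __init__.py hunks in this patch touch. A hypothetical instance; only usage and the keyword fields below are visible in this hunk, so the leading required fields (name, description) are guesses for illustration:

from src.plugin_system.base.plugin_metadata import PluginMetadata

__plugin_meta__ = PluginMetadata(
    name="demo_plugin",        # hypothetical: field not shown in this hunk
    description="示例插件",     # hypothetical: field not shown in this hunk
    usage="加载后无需额外配置",
    type="application",        # "library" or "application", per the comment above
    version="1.0.0",
    author="Your Name",
    license="MIT",
    keywords=["demo"],
    categories=["example"],
    extra={"is_built_in": False},
)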
diff --git a/src/plugin_system/core/plugin_manager.py b/src/plugin_system/core/plugin_manager.py
index c146f9c00..c3ab2ab99 100644
--- a/src/plugin_system/core/plugin_manager.py
+++ b/src/plugin_system/core/plugin_manager.py
@@ -1,7 +1,6 @@
 import asyncio
 import importlib
 import os
-import traceback
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 from typing import Any, Optional
@@ -104,7 +103,7 @@ class PluginManager:
             return False, 1

         module = self.plugin_modules.get(plugin_name)
-
+
         if not module or not hasattr(module, "__plugin_meta__"):
             self.failed_plugins[plugin_name] = "插件模块中缺少 __plugin_meta__"
             logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少 __plugin_meta__")
@@ -288,7 +287,7 @@ class PluginManager:

         return loaded_count, failed_count

-    def _load_plugin_module_file(self, plugin_file: str) -> Optional[Any]:
+    def _load_plugin_module_file(self, plugin_file: str) -> Any | None:
         # sourcery skip: extract-method
         """加载单个插件模块文件

diff --git a/src/plugin_system/core/tool_use.py b/src/plugin_system/core/tool_use.py
index 0749a8472..e74ded8ab 100644
--- a/src/plugin_system/core/tool_use.py
+++ b/src/plugin_system/core/tool_use.py
@@ -2,7 +2,6 @@
 import inspect
 import time
 from typing import Any

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.common.cache_manager import tool_cache
 from src.common.logger import get_logger
diff --git a/src/plugin_system/utils/__init__.py b/src/plugin_system/utils/__init__.py
index 26bc638b7..3106fe84e 100644
--- a/src/plugin_system/utils/__init__.py
+++ b/src/plugin_system/utils/__init__.py
@@ -2,4 +2,4 @@
 插件系统工具模块
 提供插件开发和管理的实用工具

-"""
\ No newline at end of file
+"""
diff --git a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py
index 0aa447df7..81e568e3b 100644
--- a/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py
+++ b/src/plugins/built_in/affinity_flow_chatter/affinity_interest_calculator.py
@@ -52,7 +52,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         # 用户关系数据缓存
         self.user_relationships: dict[str, float] = {}  # user_id -> relationship_score

-        logger.info(f"[Affinity兴趣计算器] 初始化完成:")
+        logger.info("[Affinity兴趣计算器] 初始化完成:")
         logger.info(f"  - 权重配置: {self.score_weights}")
         logger.info(f"  - 回复阈值: {self.reply_threshold}")
         logger.info(f"  - 智能匹配: {self.use_smart_matching}")
@@ -69,9 +69,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         """执行AffinityFlow风格的兴趣值计算"""
         try:
             start_time = time.time()
-            message_id = getattr(message, 'message_id', '')
-            content = getattr(message, 'processed_plain_text', '')
-            user_id = getattr(message, 'user_info', {}).user_id if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id') else ''
+            message_id = getattr(message, "message_id", "")
+            content = getattr(message, "processed_plain_text", "")
+            user_id = getattr(message, "user_info", {}).user_id if hasattr(message, "user_info") and hasattr(message.user_info, "user_id") else ""

             logger.debug(f"[Affinity兴趣计算] 开始处理消息 {message_id}")
             logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
@@ -135,7 +135,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
                 error_message=str(e)
             )
@@ -206,9 +206,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):

     def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
         """计算提及分"""
-        is_mentioned = getattr(message, 'is_mentioned', False)
-        is_at = getattr(message, 'is_at', False)
-        processed_plain_text = getattr(message, 'processed_plain_text', '')
+        is_mentioned = getattr(message, "is_mentioned", False)
+        is_at = getattr(message, "is_at", False)
+        processed_plain_text = getattr(message, "processed_plain_text", "")

         if is_mentioned:
             if is_at:
@@ -238,7 +238,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         keywords = []

         # 尝试从 key_words 字段提取(存储的是JSON字符串)
-        key_words = getattr(message, 'key_words', '')
+        key_words = getattr(message, "key_words", "")
         if key_words:
             try:
                 import orjson
@@ -250,7 +250,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):

         # 如果没有 keywords,尝试从 key_words_lite 提取
         if not keywords:
-            key_words_lite = getattr(message, 'key_words_lite', '')
+            key_words_lite = getattr(message, "key_words_lite", "")
             if key_words_lite:
                 try:
                     import orjson
@@ -262,7 +262,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):

         # 如果还是没有,从消息内容中提取(降级方案)
         if not keywords:
-            content = getattr(message, 'processed_plain_text', '') or ''
+            content = getattr(message, "processed_plain_text", "") or ""
             keywords = self._extract_keywords_from_content(content)

         return keywords[:15]  # 返回前15个关键词
@@ -298,4 +298,4 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)

     # 是否使用智能兴趣匹配(作为类属性)
-    use_smart_matching = True
\ No newline at end of file
+    use_smart_matching = True
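The keyword logic in the affinity_interest_calculator.py hunks above falls through three sources in order: the key_words JSON field, then key_words_lite, then the raw message text. A standalone sketch of that fallback chain; the final whitespace tokenizer is a naive placeholder for the project's real extractor:

import orjson

def extract_keywords(message) -> list[str]:
    for attr in ("key_words", "key_words_lite"):
        raw = getattr(message, attr, "")
        if raw:
            try:
                parsed = orjson.loads(raw)  # both fields store JSON strings
                if parsed:
                    return list(parsed)[:15]
            except Exception:
                pass
    content = getattr(message, "processed_plain_text", "") or ""
    return content.split()[:15]  # naive fallback tokenizer, for illustration only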
diff --git a/src/plugins/built_in/affinity_flow_chatter/planner.py b/src/plugins/built_in/affinity_flow_chatter/planner.py
index bb77d7758..a7ffc9048 100644
--- a/src/plugins/built_in/affinity_flow_chatter/planner.py
+++ b/src/plugins/built_in/affinity_flow_chatter/planner.py
@@ -107,9 +107,9 @@ class ChatterActionPlanner:
             # 直接使用消息中已计算的标志,无需重复计算兴趣值
             for message in unread_messages:
                 try:
-                    message_interest = getattr(message, 'interest_value', 0.3)
-                    message_should_reply = getattr(message, 'should_reply', False)
-                    message_should_act = getattr(message, 'should_act', False)
+                    message_interest = getattr(message, "interest_value", 0.3)
+                    message_should_reply = getattr(message, "should_reply", False)
+                    message_should_act = getattr(message, "should_act", False)

                     # 确保interest_value不是None
                     if message_interest is None:
diff --git a/src/plugins/built_in/affinity_flow_chatter/plugin.py b/src/plugins/built_in/affinity_flow_chatter/plugin.py
index f214da010..63b682061 100644
--- a/src/plugins/built_in/affinity_flow_chatter/plugin.py
+++ b/src/plugins/built_in/affinity_flow_chatter/plugin.py
@@ -5,7 +5,7 @@
 from src.common.logger import get_logger
 from src.plugin_system.apis.plugin_register_api import register_plugin
 from src.plugin_system.base.base_plugin import BasePlugin
-from src.plugin_system.base.component_types import ComponentInfo, ComponentType, InterestCalculatorInfo
+from src.plugin_system.base.component_types import ComponentInfo

 logger = get_logger("affinity_chatter_plugin")
@@ -52,4 +52,3 @@ class AffinityChatterPlugin(BasePlugin):

         return components
-
\ No newline at end of file
diff --git a/src/plugins/built_in/core_actions/__init__.py b/src/plugins/built_in/core_actions/__init__.py
index 7107531a6..00f14f526 100644
--- a/src/plugins/built_in/core_actions/__init__.py
+++ b/src/plugins/built_in/core_actions/__init__.py
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "action_provider",
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/permission_management/__init__.py b/src/plugins/built_in/permission_management/__init__.py
index 8ba15b88d..c54c8d279 100644
--- a/src/plugins/built_in/permission_management/__init__.py
+++ b/src/plugins/built_in/permission_management/__init__.py
@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "permission",
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/plugin_management/__init__.py b/src/plugins/built_in/plugin_management/__init__.py
index c75d4668b..6b56066d0 100644
--- a/src/plugins/built_in/plugin_management/__init__.py
+++ b/src/plugins/built_in/plugin_management/__init__.py
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "plugin_management",
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/proactive_thinker/__init__.py b/src/plugins/built_in/proactive_thinker/__init__.py
index 7800a04dd..81359e471 100644
--- a/src/plugins/built_in/proactive_thinker/__init__.py
+++ b/src/plugins/built_in/proactive_thinker/__init__.py
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "functional"
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py
index ffd663d18..1899361c4 100644
--- a/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py
+++ b/src/plugins/built_in/proactive_thinker/proacive_thinker_event.py
@@ -65,10 +65,10 @@ class ColdStartTask(AsyncTask):
                 nickname = await person_api.get_person_value(person_id, "nickname")
                 user_nickname = nickname or f"用户{user_id}"
                 user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)
-
+
                 # 使用 get_or_create_stream 来安全地获取或创建流
                 stream = await self.chat_manager.get_or_create_stream(platform, user_info)
-
+
                 formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"
                 await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
                 logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")
diff --git a/src/plugins/built_in/social_toolkit_plugin/__init__.py b/src/plugins/built_in/social_toolkit_plugin/__init__.py
index 01b6fba64..92b89ca6f 100644
--- a/src/plugins/built_in/social_toolkit_plugin/__init__.py
+++ b/src/plugins/built_in/social_toolkit_plugin/__init__.py
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": "true",
         "plugin_type": "functional"
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/tts_plugin/__init__.py b/src/plugins/built_in/tts_plugin/__init__.py
index 591d27882..e2595960d 100644
--- a/src/plugins/built_in/tts_plugin/__init__.py
+++ b/src/plugins/built_in/tts_plugin/__init__.py
@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "audio_processor",
     }
-)
\ No newline at end of file
+)
diff --git a/src/plugins/built_in/web_search_tool/__init__.py b/src/plugins/built_in/web_search_tool/__init__.py
index 5c2024cae..588af2378 100644
--- a/src/plugins/built_in/web_search_tool/__init__.py
+++ b/src/plugins/built_in/web_search_tool/__init__.py
@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
     }
-)
\ No newline at end of file
+)