This commit is contained in:
雅诺狐
2025-10-05 16:35:59 +08:00
52 changed files with 566 additions and 1186 deletions

View File

@@ -112,7 +112,7 @@ class InterestManager:
# 返回默认结果
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
message_id=getattr(message, "message_id", ""),
interest_value=0.3,
error_message="没有可用的兴趣值计算组件"
)
@@ -129,7 +129,7 @@ class InterestManager:
logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5")
return InterestCalculationResult(
success=True,
message_id=getattr(message, 'message_id', ''),
message_id=getattr(message, "message_id", ""),
interest_value=0.5, # 固定默认兴趣值
should_reply=False,
should_act=False,
@@ -140,9 +140,9 @@ class InterestManager:
logger.error(f"兴趣值计算异常: {e}")
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
message_id=getattr(message, "message_id", ""),
interest_value=0.3,
error_message=f"计算异常: {str(e)}"
error_message=f"计算异常: {e!s}"
)
async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
@@ -168,9 +168,9 @@ class InterestManager:
logger.error(f"兴趣值计算异常: {e}", exc_info=True)
return InterestCalculationResult(
success=False,
message_id=getattr(message, 'message_id', ''),
message_id=getattr(message, "message_id", ""),
interest_value=0.0,
error_message=f"计算异常: {str(e)}",
error_message=f"计算异常: {e!s}",
calculation_time=time.time() - start_time
)
@@ -245,4 +245,4 @@ def get_interest_manager() -> InterestManager:
global _interest_manager
if _interest_manager is None:
_interest_manager = InterestManager()
return _interest_manager
return _interest_manager

View File

@@ -1,4 +1,3 @@
# -*- coding: utf-8 -*-
"""
海马体双峰分布采样器
基于旧版海马体的采样策略,适配新版记忆系统
@@ -8,16 +7,15 @@
import asyncio
import random
import time
from datetime import datetime, timedelta
from typing import List, Optional, Tuple, Dict, Any
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Any
import numpy as np
import orjson
from src.chat.utils.chat_message_builder import (
get_raw_msg_by_timestamp,
build_readable_messages,
get_raw_msg_by_timestamp,
get_raw_msg_by_timestamp_with_chat,
)
from src.chat.utils.utils import translate_timestamp_to_human_readable
@@ -47,7 +45,7 @@ class HippocampusSampleConfig:
batch_size: int = 5 # 批处理大小
@classmethod
def from_global_config(cls) -> 'HippocampusSampleConfig':
def from_global_config(cls) -> "HippocampusSampleConfig":
"""从全局配置创建海马体采样配置"""
config = global_config.memory.hippocampus_distribution_config
return cls(
@@ -74,12 +72,12 @@ class HippocampusSampler:
self.is_running = False
# 记忆构建模型
self.memory_builder_model: Optional[LLMRequest] = None
self.memory_builder_model: LLMRequest | None = None
# 统计信息
self.sample_count = 0
self.success_count = 0
self.last_sample_results: List[Dict[str, Any]] = []
self.last_sample_results: list[dict[str, Any]] = []
async def initialize(self):
"""初始化采样器"""
@@ -101,7 +99,7 @@ class HippocampusSampler:
logger.error(f"❌ 海马体采样器初始化失败: {e}")
raise
def generate_time_samples(self) -> List[datetime]:
def generate_time_samples(self) -> list[datetime]:
"""生成双峰分布的时间采样点"""
# 计算每个分布的样本数
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
@@ -132,7 +130,7 @@ class HippocampusSampler:
# 按时间排序(从最早到最近)
return sorted(timestamps)
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
"""收集指定时间戳附近的消息样本"""
try:
# 随机时间窗口5-30分钟
@@ -190,7 +188,7 @@ class HippocampusSampler:
logger.error(f"收集消息样本失败: {e}")
return None
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
"""从消息样本构建记忆"""
if not messages or not self.memory_system or not self.memory_builder_model:
return None
@@ -262,7 +260,7 @@ class HippocampusSampler:
logger.error(f"海马体采样构建记忆失败: {e}")
return None
async def perform_sampling_cycle(self) -> Dict[str, Any]:
async def perform_sampling_cycle(self) -> dict[str, Any]:
"""执行一次完整的采样周期(优化版:批量融合构建)"""
if not self.should_sample():
return {"status": "skipped", "reason": "interval_not_met"}
@@ -363,7 +361,7 @@ class HippocampusSampler:
"duration": time.time() - start_time,
}
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
"""批量收集所有时间点的消息样本"""
collected_messages = []
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
@@ -394,7 +392,7 @@ class HippocampusSampler:
return collected_messages
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
"""融合和去重消息样本"""
if not collected_messages:
return []
@@ -450,7 +448,7 @@ class HippocampusSampler:
# 返回原始消息组作为备选
return collected_messages[:5] # 限制返回数量
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
"""合并时间间隔内的消息"""
if not messages:
return []
@@ -481,7 +479,7 @@ class HippocampusSampler:
return result_groups
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
"""批量构建记忆"""
if not fused_messages:
return {"memory_count": 0, "memories": []}
@@ -557,7 +555,7 @@ class HippocampusSampler:
logger.error(f"批量构建记忆失败: {e}")
return {"memory_count": 0, "error": str(e)}
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
"""构建融合后的对话文本"""
try:
# 添加批次标识
@@ -589,7 +587,7 @@ class HippocampusSampler:
logger.error(f"构建融合文本失败: {e}")
return ""
async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
"""备选方案:单独构建每个消息组"""
total_memories = []
total_count = 0
@@ -609,7 +607,7 @@ class HippocampusSampler:
"fallback_mode": True
}
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
"""处理单个时间戳采样(保留作为备选方法)"""
try:
# 收集消息样本
@@ -676,7 +674,7 @@ class HippocampusSampler:
self.is_running = False
logger.info("🛑 停止海马体后台采样任务")
def get_sampling_stats(self) -> Dict[str, Any]:
def get_sampling_stats(self) -> dict[str, Any]:
"""获取采样统计信息"""
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
@@ -713,7 +711,7 @@ class HippocampusSampler:
# 全局海马体采样器实例
_hippocampus_sampler: Optional[HippocampusSampler] = None
_hippocampus_sampler: HippocampusSampler | None = None
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
@@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl
"""初始化全局海马体采样器"""
sampler = get_hippocampus_sampler(memory_system)
await sampler.initialize()
return sampler
return sampler

View File

@@ -32,7 +32,7 @@ import time
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Type, TypeVar
from typing import Any, TypeVar
E = TypeVar("E", bound=Enum)
@@ -503,7 +503,7 @@ class MemoryBuilder:
logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
return MemoryType.CONTEXTUAL
def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E:
def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E:
"""解析枚举值,兼容数字/字符串表示"""
if isinstance(raw_value, enum_cls):
return raw_value

View File

@@ -215,8 +215,8 @@ class MemoryFusionEngine:
if not keywords1 or not keywords2:
return 0.0
set1 = set(k.lower() for k in keywords1)
set2 = set(k.lower() for k in keywords2)
set1 = set(k.lower() for k in keywords1) # noqa: C401
set2 = set(k.lower() for k in keywords2) # noqa: C401
intersection = set1 & set2
union = set1 | set2

View File

@@ -69,14 +69,11 @@ class MemoryManager:
# 初始化记忆系统
self.memory_system = await initialize_memory_system(llm_model)
# 设置全局实例
global_memory_manager = self.memory_system
self.is_initialized = True
logger.info(" 记忆系统初始化完成")
logger.info(" 记忆系统初始化完成")
except Exception as e:
logger.error(f"记忆系统初始化失败: {e}")
logger.error(f"记忆系统初始化失败: {e}")
# 如果系统初始化失败,创建一个空的管理器避免系统崩溃
self.memory_system = None
self.is_initialized = True # 标记为已初始化但系统不可用
@@ -439,7 +436,7 @@ class MemoryManager:
formatted_items = [self._format_object(item) for item in obj]
filtered = [item for item in formatted_items if item]
return self._clean_text("".join(filtered)) if filtered else ""
if isinstance(obj, (int, float)):
if isinstance(obj, int | float):
return str(obj)
text = self._truncate(str(obj).strip())
return self._clean_text(text)
@@ -449,12 +446,12 @@ class MemoryManager:
for key in keys:
if obj.get(key):
value = obj[key]
if isinstance(value, (dict, list)):
if isinstance(value, dict | list):
return self._clean_text(self._format_object(value))
return self._clean_text(value)
if isinstance(obj, list) and obj:
return self._clean_text(self._format_object(obj[0]))
if isinstance(obj, (str, int, float)):
if isinstance(obj, str | int | float):
return self._clean_text(obj)
return None
@@ -471,7 +468,7 @@ class MemoryManager:
try:
if self.memory_system:
await self.memory_system.shutdown()
logger.info(" 记忆系统已关闭")
logger.info(" 记忆系统已关闭")
except Exception as e:
logger.error(f"关闭记忆系统失败: {e}")

View File

@@ -19,6 +19,8 @@ from src.chat.memory_system.memory_builder import MemoryBuilder, MemoryExtractio
from src.chat.memory_system.memory_chunk import MemoryChunk
from src.chat.memory_system.memory_fusion import MemoryFusionEngine
from src.chat.memory_system.memory_query_planner import MemoryQueryPlanner
# 记忆采样模式枚举
class MemorySamplingMode(Enum):
"""记忆采样模式"""
@@ -31,9 +33,10 @@ from src.llm_models.utils_model import LLMRequest
if TYPE_CHECKING:
from src.chat.memory_system.memory_forgetting_engine import MemoryForgettingEngine
from src.chat.memory_system.vector_memory_storage_v2 import VectorMemoryStorage
from src.common.data_models.database_data_model import DatabaseMessages
logger = get_logger(__name__)
logger = get_logger("memory_system")
# 全局记忆作用域(共享记忆库)
GLOBAL_MEMORY_SCOPE = "global"
@@ -133,15 +136,15 @@ class MemorySystem:
self.status = MemorySystemStatus.INITIALIZING
# 核心组件(简化版)
self.memory_builder: MemoryBuilder = None
self.fusion_engine: MemoryFusionEngine = None
self.unified_storage = None # 统一存储系统
self.query_planner: MemoryQueryPlanner = None
self.memory_builder: MemoryBuilder | None = None
self.fusion_engine: MemoryFusionEngine | None = None
self.unified_storage: VectorMemoryStorage | None = None # 统一存储系统
self.query_planner: MemoryQueryPlanner | None = None
self.forgetting_engine: MemoryForgettingEngine | None = None
# LLM模型
self.value_assessment_model: LLMRequest = None
self.memory_extraction_model: LLMRequest = None
self.value_assessment_model: LLMRequest | None = None
self.memory_extraction_model: LLMRequest | None = None
# 统计信息
self.total_memories = 0
@@ -162,7 +165,6 @@ class MemorySystem:
async def initialize(self):
"""异步初始化记忆系统"""
try:
logger.info("正在初始化记忆系统...")
# 初始化LLM模型
fallback_task = getattr(self.llm_model, "model_for_task", None) if self.llm_model else None
@@ -249,7 +251,7 @@ class MemorySystem:
self.forgetting_engine = MemoryForgettingEngine(forgetting_config)
planner_task_config = getattr(model_config.model_task_config, "utils_small", None)
planner_task_config = model_config.model_task_config.utils_small
planner_model: LLMRequest | None = None
try:
planner_model = LLMRequest(model_set=planner_task_config, request_type="memory.query_planner")
@@ -269,10 +271,8 @@ class MemorySystem:
self.hippocampus_sampler = None
# 统一存储已经自动加载数据,无需额外加载
logger.info("✅ 简化版记忆系统初始化完成")
self.status = MemorySystemStatus.READY
logger.info("✅ 记忆系统初始化完成")
except Exception as e:
self.status = MemorySystemStatus.ERROR
@@ -479,7 +479,7 @@ class MemorySystem:
existing_id = self._memory_fingerprints.get(fingerprint_key)
if existing_id and existing_id not in new_memory_ids:
candidate_ids.add(existing_id)
except Exception as exc:
except Exception as exc: # noqa: PERF203
logger.debug("构建记忆指纹失败,跳过候选收集: %s", exc)
# 基于主体索引的候选(使用统一存储)
@@ -557,11 +557,11 @@ class MemorySystem:
context = dict(context or {})
# 获取配置的采样模式
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision')
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
current_mode = MemorySamplingMode(sampling_mode)
context['__sampling_mode'] = current_mode.value
context["__sampling_mode"] = current_mode.value
logger.debug(f"使用记忆采样模式: {current_mode.value}")
# 根据采样模式处理记忆
@@ -637,7 +637,7 @@ class MemorySystem:
# 检查信息价值阈值
value_score = await self._assess_information_value(conversation_text, normalized_context)
threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5)
if value_score < threshold:
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
@@ -843,7 +843,7 @@ class MemorySystem:
for i, (mem, score, details) in enumerate(scored_memories[:3], 1):
try:
summary = mem.content[:60] if hasattr(mem, "content") and mem.content else ""
except:
except Exception:
summary = ""
logger.info(
f" #{i} | final={details['final']:.3f} "
@@ -1440,8 +1440,8 @@ class MemorySystem:
context_keywords = context.get("keywords") or []
keyword_overlap = 0.0
if context_keywords:
memory_keywords = set(k.lower() for k in memory.keywords)
keyword_overlap = len(memory_keywords & set(k.lower() for k in context_keywords)) / max(
memory_keywords = {k.lower() for k in memory.keywords}
keyword_overlap = len(memory_keywords & {k.lower() for k in context_keywords}) / max(
len(context_keywords), 1
)
@@ -1489,7 +1489,7 @@ class MemorySystem:
"""启动海马体采样"""
if self.hippocampus_sampler:
asyncio.create_task(self.hippocampus_sampler.start_background_sampling())
logger.info("🚀 海马体后台采样已启动")
logger.info("海马体后台采样已启动")
else:
logger.warning("海马体采样器未初始化,无法启动采样")
@@ -1497,7 +1497,7 @@ class MemorySystem:
"""停止海马体采样"""
if self.hippocampus_sampler:
self.hippocampus_sampler.stop_background_sampling()
logger.info("🛑 海马体后台采样已停止")
logger.info("海马体后台采样已停止")
def get_system_stats(self) -> dict[str, Any]:
"""获取系统统计信息"""
@@ -1536,10 +1536,10 @@ class MemorySystem:
if self.unified_storage:
self.unified_storage.cleanup()
logger.info("简化记忆系统已关闭")
logger.info("简化记忆系统已关闭")
except Exception as e:
logger.error(f"记忆系统关闭失败: {e}", exc_info=True)
logger.error(f"记忆系统关闭失败: {e}", exc_info=True)
async def _rebuild_vector_storage_if_needed(self):
"""重建向量存储(如果需要)"""
@@ -1553,12 +1553,13 @@ class MemorySystem:
# 收集需要重建向量的记忆
memories_to_rebuild = []
for memory_id, memory in self.unified_storage.memory_cache.items():
# 检查记忆是否有有效的 display 文本
if memory.display and memory.display.strip():
memories_to_rebuild.append(memory)
elif memory.text_content and memory.text_content.strip():
memories_to_rebuild.append(memory)
if self.unified_storage:
for memory in self.unified_storage.memory_cache.values():
# 检查记忆是否有有效的 display 文本
if memory.display and memory.display.strip():
memories_to_rebuild.append(memory)
elif memory.text_content and memory.text_content.strip():
memories_to_rebuild.append(memory)
if not memories_to_rebuild:
logger.warning("没有找到可重建向量的记忆")
@@ -1583,14 +1584,16 @@ class MemorySystem:
logger.error(f"批量重建向量失败: {e}")
continue
# 保存重建的向量存储
await self.unified_storage.save_storage()
final_count = self.unified_storage.storage_stats.get("total_vectors", 0)
logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")
# 向量数据在 store_memories 中已保存,此处无需额外操作
if self.unified_storage:
storage_stats = self.unified_storage.get_storage_stats()
final_count = storage_stats.get("total_vectors", 0)
logger.info(f"✅ 向量存储重建完成,最终向量数量: {final_count}")
else:
logger.warning("向量存储重建完成,但无法获取最终向量数量,因为存储系统未初始化")
except Exception as e:
logger.error(f"向量存储重建失败: {e}", exc_info=True)
logger.error(f"向量存储重建失败: {e}", exc_info=True)
# 全局记忆系统实例
@@ -1613,8 +1616,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
await memory_system.initialize()
# 根据配置启动海马体采样
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
if sampling_mode in ['hippocampus', 'all']:
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate")
if sampling_mode in ["hippocampus", "all"]:
memory_system.start_hippocampus_sampling()
return memory_system

View File

@@ -4,14 +4,13 @@
"""
import asyncio
import psutil
import time
from typing import Dict, List, Optional, Set, Tuple
from dataclasses import dataclass, field
from enum import Enum
import psutil
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("adaptive_stream_manager")
@@ -71,16 +70,16 @@ class AdaptiveStreamManager:
# 当前状态
self.current_limit = base_concurrent_limit
self.active_streams: Set[str] = set()
self.pending_streams: Set[str] = set()
self.stream_metrics: Dict[str, StreamMetrics] = {}
self.active_streams: set[str] = set()
self.pending_streams: set[str] = set()
self.stream_metrics: dict[str, StreamMetrics] = {}
# 异步信号量
self.semaphore = asyncio.Semaphore(base_concurrent_limit)
self.priority_semaphore = asyncio.Semaphore(5) # 高优先级专用信号量
# 系统监控
self.system_metrics: List[SystemMetrics] = []
self.system_metrics: list[SystemMetrics] = []
self.last_adjustment_time = 0.0
# 统计信息
@@ -95,8 +94,8 @@ class AdaptiveStreamManager:
}
# 监控任务
self.monitor_task: Optional[asyncio.Task] = None
self.adjustment_task: Optional[asyncio.Task] = None
self.monitor_task: asyncio.Task | None = None
self.adjustment_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
@@ -443,7 +442,7 @@ class AdaptiveStreamManager:
if hasattr(metrics, key):
setattr(metrics, key, value)
def get_stats(self) -> Dict:
def get_stats(self) -> dict:
"""获取统计信息"""
stats = self.stats.copy()
stats.update({
@@ -465,7 +464,7 @@ class AdaptiveStreamManager:
# 全局自适应管理器实例
_adaptive_manager: Optional[AdaptiveStreamManager] = None
_adaptive_manager: AdaptiveStreamManager | None = None
def get_adaptive_stream_manager() -> AdaptiveStreamManager:
@@ -485,4 +484,4 @@ async def init_adaptive_stream_manager():
async def shutdown_adaptive_stream_manager():
"""关闭自适应流管理器"""
manager = get_adaptive_stream_manager()
await manager.stop()
await manager.stop()

View File

@@ -5,9 +5,9 @@
import asyncio
import time
from typing import Any, Dict, List, Optional
from dataclasses import dataclass, field
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
@@ -21,7 +21,7 @@ logger = get_logger("batch_database_writer")
class StreamUpdatePayload:
"""流更新数据结构"""
stream_id: str
update_data: Dict[str, Any]
update_data: dict[str, Any]
priority: int = 0 # 优先级,数字越大优先级越高
timestamp: float = field(default_factory=time.time)
@@ -47,7 +47,7 @@ class BatchDatabaseWriter:
# 运行状态
self.is_running = False
self.writer_task: Optional[asyncio.Task] = None
self.writer_task: asyncio.Task | None = None
# 统计信息
self.stats = {
@@ -60,7 +60,7 @@ class BatchDatabaseWriter:
}
# 按优先级分类的批次
self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)
self.priority_batches: dict[int, list[StreamUpdatePayload]] = defaultdict(list)
logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")
@@ -98,7 +98,7 @@ class BatchDatabaseWriter:
async def schedule_stream_update(
self,
stream_id: str,
update_data: Dict[str, Any],
update_data: dict[str, Any],
priority: int = 0
) -> bool:
"""
@@ -166,7 +166,7 @@ class BatchDatabaseWriter:
await self._flush_all_batches()
logger.info("批量写入循环结束")
async def _collect_batch(self) -> List[StreamUpdatePayload]:
async def _collect_batch(self) -> list[StreamUpdatePayload]:
"""收集一个批次的数据"""
batch = []
deadline = time.time() + self.flush_interval
@@ -189,7 +189,7 @@ class BatchDatabaseWriter:
return batch
async def _write_batch(self, batch: List[StreamUpdatePayload]):
async def _write_batch(self, batch: list[StreamUpdatePayload]):
"""批量写入数据库"""
if not batch:
return
@@ -228,7 +228,7 @@ class BatchDatabaseWriter:
except Exception as single_e:
logger.error(f"单个写入也失败: {single_e}")
async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
"""批量写入数据库"""
async with get_db_session() as session:
for payload in payloads:
@@ -268,7 +268,7 @@ class BatchDatabaseWriter:
await session.commit()
async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
"""直接写入数据库(降级方案)"""
async with get_db_session() as session:
if global_config.database.database_type == "sqlite":
@@ -315,7 +315,7 @@ class BatchDatabaseWriter:
if remaining_batch:
await self._write_batch(remaining_batch)
def get_stats(self) -> Dict[str, Any]:
def get_stats(self) -> dict[str, Any]:
"""获取统计信息"""
stats = self.stats.copy()
stats["is_running"] = self.is_running
@@ -324,7 +324,7 @@ class BatchDatabaseWriter:
# 全局批量写入器实例
_batch_writer: Optional[BatchDatabaseWriter] = None
_batch_writer: BatchDatabaseWriter | None = None
def get_batch_writer() -> BatchDatabaseWriter:
@@ -344,4 +344,4 @@ async def init_batch_writer():
async def shutdown_batch_writer():
"""关闭批量写入器"""
writer = get_batch_writer()
await writer.stop()
await writer.stop()

View File

@@ -117,7 +117,7 @@ class StreamLoopManager:
# 使用自适应流管理器获取槽位
use_adaptive = False
try:
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager
adaptive_manager = get_adaptive_stream_manager()
if adaptive_manager.is_running:
@@ -137,7 +137,7 @@ class StreamLoopManager:
else:
logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
else:
logger.debug(f"自适应管理器未运行,使用原始方法")
logger.debug("自适应管理器未运行,使用原始方法")
except Exception as e:
logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")

View File

@@ -5,13 +5,13 @@
import asyncio
import time
from typing import Dict, List, Optional, Set
from dataclasses import dataclass
from collections import OrderedDict
from dataclasses import dataclass
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
from src.common.logger import get_logger
logger = get_logger("stream_cache_manager")
@@ -52,14 +52,14 @@ class TieredStreamCache:
# 三层缓存存储
self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict() # 热数据LRU
self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 温数据(最后访问时间)
self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {} # 冷数据(最后访问时间)
# 统计信息
self.stats = StreamCacheStats()
# 清理任务
self.cleanup_task: Optional[asyncio.Task] = None
self.cleanup_task: asyncio.Task | None = None
self.is_running = False
logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
@@ -96,8 +96,8 @@ class TieredStreamCache:
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
) -> OptimizedChatStream:
"""获取或创建流 - 优化版本"""
current_time = time.time()
@@ -255,7 +255,7 @@ class TieredStreamCache:
hot_to_demote = []
for stream_id, stream in self.hot_cache.items():
# 获取最后访问时间(简化:使用创建时间作为近似)
last_access = getattr(stream, 'last_active_time', stream.create_time)
last_access = getattr(stream, "last_active_time", stream.create_time)
if current_time - last_access > self.hot_timeout:
hot_to_demote.append(stream_id)
@@ -341,7 +341,7 @@ class TieredStreamCache:
logger.info("所有缓存已清空")
async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
"""获取流的快照(不修改缓存状态)"""
if stream_id in self.hot_cache:
return self.hot_cache[stream_id].create_snapshot()
@@ -351,13 +351,13 @@ class TieredStreamCache:
return self.cold_storage[stream_id][0].create_snapshot()
return None
def get_cached_stream_ids(self) -> Set[str]:
def get_cached_stream_ids(self) -> set[str]:
"""获取所有缓存的流ID"""
return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())
# 全局缓存管理器实例
_cache_manager: Optional[TieredStreamCache] = None
_cache_manager: TieredStreamCache | None = None
def get_stream_cache_manager() -> TieredStreamCache:
@@ -377,4 +377,4 @@ async def init_stream_cache_manager():
async def shutdown_stream_cache_manager():
"""关闭流缓存管理器"""
manager = get_stream_cache_manager()
await manager.stop()
await manager.stop()

View File

@@ -313,11 +313,11 @@ class ChatStream:
except Exception as e:
logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
# 异常情况下使用默认值
if hasattr(db_message, 'interest_value'):
if hasattr(db_message, "interest_value"):
db_message.interest_value = 0.3
if hasattr(db_message, 'should_reply'):
if hasattr(db_message, "should_reply"):
db_message.should_reply = False
if hasattr(db_message, 'should_act'):
if hasattr(db_message, "should_act"):
db_message.should_act = False
def _extract_reply_from_segment(self, segment) -> str | None:
@@ -894,10 +894,10 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
original_stream.saved = optimized_stream.saved
# 复制上下文信息(如果存在)
if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
original_stream.stream_context = optimized_stream._stream_context
if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
original_stream.context_manager = optimized_stream._context_manager
return original_stream

View File

@@ -3,17 +3,12 @@
避免不必要的深拷贝开销,提升多流并发性能
"""
import asyncio
import copy
import hashlib
import time
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import TYPE_CHECKING, Any
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams
from src.common.logger import get_logger
from src.config.config import global_config
@@ -28,7 +23,7 @@ logger = get_logger("optimized_chat_stream")
class SharedContext:
"""共享上下文数据 - 只读数据结构"""
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None):
self.stream_id = stream_id
self.platform = platform
self.user_info = user_info
@@ -37,7 +32,7 @@ class SharedContext:
self._frozen = True
def __setattr__(self, name, value):
if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']:
if hasattr(self, "_frozen") and self._frozen and name not in ["_frozen"]:
raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
super().__setattr__(name, value)
@@ -46,7 +41,7 @@ class LocalChanges:
"""本地修改跟踪器"""
def __init__(self):
self._changes: Dict[str, Any] = {}
self._changes: dict[str, Any] = {}
self._dirty = False
def set_change(self, key: str, value: Any):
@@ -62,7 +57,7 @@ class LocalChanges:
"""是否有修改"""
return self._dirty
def get_changes(self) -> Dict[str, Any]:
def get_changes(self) -> dict[str, Any]:
"""获取所有修改"""
return self._changes.copy()
@@ -80,8 +75,8 @@ class OptimizedChatStream:
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
# 共享的只读数据
self._shared_context = SharedContext(
@@ -129,42 +124,42 @@ class OptimizedChatStream:
"""修改用户信息时触发写时复制"""
self._ensure_copy_on_write()
# 由于SharedContext是frozen的我们需要在本地修改中记录
self._local_changes.set_change('user_info', value)
self._local_changes.set_change("user_info", value)
@property
def group_info(self) -> Optional[GroupInfo]:
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
return self._local_changes.get_change('group_info')
def group_info(self) -> GroupInfo | None:
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
return self._local_changes.get_change("group_info")
return self._shared_context.group_info
@group_info.setter
def group_info(self, value: Optional[GroupInfo]):
def group_info(self, value: GroupInfo | None):
"""修改群组信息时触发写时复制"""
self._ensure_copy_on_write()
self._local_changes.set_change('group_info', value)
self._local_changes.set_change("group_info", value)
@property
def create_time(self) -> float:
if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
return self._local_changes.get_change('create_time')
if self._local_changes.has_changes() and "create_time" in self._local_changes._changes:
return self._local_changes.get_change("create_time")
return self._shared_context.create_time
@property
def last_active_time(self) -> float:
return self._local_changes.get_change('last_active_time', self.create_time)
return self._local_changes.get_change("last_active_time", self.create_time)
@last_active_time.setter
def last_active_time(self, value: float):
self._local_changes.set_change('last_active_time', value)
self._local_changes.set_change("last_active_time", value)
self.saved = False
@property
def sleep_pressure(self) -> float:
return self._local_changes.get_change('sleep_pressure', 0.0)
return self._local_changes.get_change("sleep_pressure", 0.0)
@sleep_pressure.setter
def sleep_pressure(self, value: float):
self._local_changes.set_change('sleep_pressure', value)
self._local_changes.set_change("sleep_pressure", value)
self.saved = False
def _ensure_copy_on_write(self):
@@ -176,14 +171,14 @@ class OptimizedChatStream:
def _get_effective_user_info(self) -> UserInfo:
"""获取有效的用户信息"""
if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
return self._local_changes.get_change('user_info')
if self._local_changes.has_changes() and "user_info" in self._local_changes._changes:
return self._local_changes.get_change("user_info")
return self._shared_context.user_info
def _get_effective_group_info(self) -> Optional[GroupInfo]:
def _get_effective_group_info(self) -> GroupInfo | None:
"""获取有效的群组信息"""
if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
return self._local_changes.get_change('group_info')
if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
return self._local_changes.get_change("group_info")
return self._shared_context.group_info
def update_active_time(self):
@@ -199,6 +194,7 @@ class OptimizedChatStream:
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
import json
from src.common.data_models.database_data_model import DatabaseMessages
message_info = getattr(message, "message_info", {})
@@ -298,7 +294,7 @@ class OptimizedChatStream:
self._create_stream_context()
return self._context_manager
def to_dict(self) -> Dict[str, Any]:
def to_dict(self) -> dict[str, Any]:
"""转换为字典格式 - 考虑本地修改"""
user_info = self._get_effective_user_info()
group_info = self._get_effective_group_info()
@@ -319,7 +315,7 @@ class OptimizedChatStream:
}
@classmethod
def from_dict(cls, data: Dict) -> "OptimizedChatStream":
def from_dict(cls, data: dict) -> "OptimizedChatStream":
"""从字典创建实例"""
user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
@@ -481,8 +477,8 @@ def create_optimized_chat_stream(
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[Dict] = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
) -> OptimizedChatStream:
"""创建优化版聊天流实例"""
return OptimizedChatStream(
@@ -491,4 +487,4 @@ def create_optimized_chat_stream(
user_info=user_info,
group_info=group_info,
data=data
)
)

View File

@@ -15,7 +15,7 @@ from src.plugin_system.base.component_types import ActionActivationType, ActionI
from src.plugin_system.core.global_announcement_manager import global_announcement_manager
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
pass
logger = get_logger("action_manager")

View File

@@ -536,7 +536,7 @@ class Prompt:
style = expr.get("style", "")
if situation and style:
formatted_expressions.append(f"- {situation}{style}")
if formatted_expressions:
style_habits_str = "\n".join(formatted_expressions)
expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"

View File

@@ -9,8 +9,8 @@ import time
from collections import defaultdict
from pathlib import Path
import rjieba
import orjson
import rjieba
from pypinyin import Style, pinyin
from src.common.logger import get_logger

View File

@@ -6,8 +6,8 @@ import time
from collections import Counter
from typing import Any
import rjieba
import numpy as np
import rjieba
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager

File diff suppressed because it is too large Load Diff

View File

@@ -461,14 +461,11 @@ class LegacyVideoAnalyzer:
# logger.info(f"✅ 多帧消息构建完成,包含{len(frames)}张图片")
# 获取模型信息和客户端
selection_result = self.video_llm._model_selector.select_best_available_model(set(), "response")
if not selection_result:
raise RuntimeError("无法为视频分析选择可用模型 (legacy)。")
model_info, api_provider, client = selection_result
model_info, api_provider, client = self.video_llm._select_model()
# logger.info(f"使用模型: {model_info.name} 进行多帧分析")
# 直接执行多图片请求
api_response = await self.video_llm._executor.execute_request(
api_response = await self.video_llm._execute_request(
api_provider=api_provider,
client=client,
request_type=RequestType.RESPONSE,