style: 统一代码风格并进行现代化改进
对整个代码库进行了一次全面的风格统一和现代化改进。主要变更包括:
- 将字符串字面量的引号从单引号 `'` 统一为双引号 `"`,包括 `getattr`/`hasattr` 等内置函数的字符串参数、字典键以及字符串形式的前向引用类型注解。
- 采用现代类型注解,例如将 `Optional[T]` 替换为 `T | None`,`List[T]` 替换为 `list[T]` 等。
- 移除 Python 3 中不再需要的源文件编码声明 `# -*- coding: utf-8 -*-`(Python 3 默认按 UTF-8 解码源文件,该声明只在 Python 2 下才有必要)。
- 清理了多余的空行、注释和未使用的导入,并按统一顺序重新排列了导入语句。
- 统一了文件末尾的换行符。
- 优化了部分日志输出和字符串格式化 (`f"{e!s}"`)。
这些改动旨在提升代码的可读性、一致性和可维护性,使其更符合现代 Python 编码规范。
This commit is contained in:
committed by
Windpicker-owo
parent
1fb2ab6450
commit
cd84373828
@@ -1,4 +1,3 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
海马体双峰分布采样器
|
||||
基于旧版海马体的采样策略,适配新版记忆系统
|
||||
@@ -8,16 +7,15 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from typing import List, Optional, Tuple, Dict, Any
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import orjson
|
||||
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
get_raw_msg_by_timestamp,
|
||||
build_readable_messages,
|
||||
get_raw_msg_by_timestamp,
|
||||
get_raw_msg_by_timestamp_with_chat,
|
||||
)
|
||||
from src.chat.utils.utils import translate_timestamp_to_human_readable
|
||||
@@ -47,7 +45,7 @@ class HippocampusSampleConfig:
|
||||
batch_size: int = 5 # 批处理大小
|
||||
|
||||
@classmethod
|
||||
def from_global_config(cls) -> 'HippocampusSampleConfig':
|
||||
def from_global_config(cls) -> "HippocampusSampleConfig":
|
||||
"""从全局配置创建海马体采样配置"""
|
||||
config = global_config.memory.hippocampus_distribution_config
|
||||
return cls(
|
||||
@@ -74,12 +72,12 @@ class HippocampusSampler:
|
||||
self.is_running = False
|
||||
|
||||
# 记忆构建模型
|
||||
self.memory_builder_model: Optional[LLMRequest] = None
|
||||
self.memory_builder_model: LLMRequest | None = None
|
||||
|
||||
# 统计信息
|
||||
self.sample_count = 0
|
||||
self.success_count = 0
|
||||
self.last_sample_results: List[Dict[str, Any]] = []
|
||||
self.last_sample_results: list[dict[str, Any]] = []
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化采样器"""
|
||||
@@ -101,7 +99,7 @@ class HippocampusSampler:
|
||||
logger.error(f"❌ 海马体采样器初始化失败: {e}")
|
||||
raise
|
||||
|
||||
def generate_time_samples(self) -> List[datetime]:
|
||||
def generate_time_samples(self) -> list[datetime]:
|
||||
"""生成双峰分布的时间采样点"""
|
||||
# 计算每个分布的样本数
|
||||
recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
|
||||
@@ -132,7 +130,7 @@ class HippocampusSampler:
|
||||
# 按时间排序(从最早到最近)
|
||||
return sorted(timestamps)
|
||||
|
||||
async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
|
||||
async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
|
||||
"""收集指定时间戳附近的消息样本"""
|
||||
try:
|
||||
# 随机时间窗口:5-30分钟
|
||||
@@ -190,7 +188,7 @@ class HippocampusSampler:
|
||||
logger.error(f"收集消息样本失败: {e}")
|
||||
return None
|
||||
|
||||
async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
|
||||
async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
|
||||
"""从消息样本构建记忆"""
|
||||
if not messages or not self.memory_system or not self.memory_builder_model:
|
||||
return None
|
||||
@@ -262,7 +260,7 @@ class HippocampusSampler:
|
||||
logger.error(f"海马体采样构建记忆失败: {e}")
|
||||
return None
|
||||
|
||||
async def perform_sampling_cycle(self) -> Dict[str, Any]:
|
||||
async def perform_sampling_cycle(self) -> dict[str, Any]:
|
||||
"""执行一次完整的采样周期(优化版:批量融合构建)"""
|
||||
if not self.should_sample():
|
||||
return {"status": "skipped", "reason": "interval_not_met"}
|
||||
@@ -363,7 +361,7 @@ class HippocampusSampler:
|
||||
"duration": time.time() - start_time,
|
||||
}
|
||||
|
||||
async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
|
||||
async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
|
||||
"""批量收集所有时间点的消息样本"""
|
||||
collected_messages = []
|
||||
max_concurrent = min(5, len(time_samples)) # 提高并发数到5
|
||||
@@ -394,7 +392,7 @@ class HippocampusSampler:
|
||||
|
||||
return collected_messages
|
||||
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
|
||||
async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
|
||||
"""融合和去重消息样本"""
|
||||
if not collected_messages:
|
||||
return []
|
||||
@@ -450,7 +448,7 @@ class HippocampusSampler:
|
||||
# 返回原始消息组作为备选
|
||||
return collected_messages[:5] # 限制返回数量
|
||||
|
||||
def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
|
||||
def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
|
||||
"""合并时间间隔内的消息"""
|
||||
if not messages:
|
||||
return []
|
||||
@@ -481,7 +479,7 @@ class HippocampusSampler:
|
||||
|
||||
return result_groups
|
||||
|
||||
async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
|
||||
async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
|
||||
"""批量构建记忆"""
|
||||
if not fused_messages:
|
||||
return {"memory_count": 0, "memories": []}
|
||||
@@ -557,7 +555,7 @@ class HippocampusSampler:
|
||||
logger.error(f"批量构建记忆失败: {e}")
|
||||
return {"memory_count": 0, "error": str(e)}
|
||||
|
||||
async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
|
||||
async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
|
||||
"""构建融合后的对话文本"""
|
||||
try:
|
||||
# 添加批次标识
|
||||
@@ -589,7 +587,7 @@ class HippocampusSampler:
|
||||
logger.error(f"构建融合文本失败: {e}")
|
||||
return ""
|
||||
|
||||
async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
|
||||
async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
|
||||
"""备选方案:单独构建每个消息组"""
|
||||
total_memories = []
|
||||
total_count = 0
|
||||
@@ -609,7 +607,7 @@ class HippocampusSampler:
|
||||
"fallback_mode": True
|
||||
}
|
||||
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
|
||||
async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
|
||||
"""处理单个时间戳采样(保留作为备选方法)"""
|
||||
try:
|
||||
# 收集消息样本
|
||||
@@ -676,7 +674,7 @@ class HippocampusSampler:
|
||||
self.is_running = False
|
||||
logger.info("🛑 停止海马体后台采样任务")
|
||||
|
||||
def get_sampling_stats(self) -> Dict[str, Any]:
|
||||
def get_sampling_stats(self) -> dict[str, Any]:
|
||||
"""获取采样统计信息"""
|
||||
success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
|
||||
|
||||
@@ -713,7 +711,7 @@ class HippocampusSampler:
|
||||
|
||||
|
||||
# 全局海马体采样器实例
|
||||
_hippocampus_sampler: Optional[HippocampusSampler] = None
|
||||
_hippocampus_sampler: HippocampusSampler | None = None
|
||||
|
||||
|
||||
def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
|
||||
@@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl
|
||||
"""初始化全局海马体采样器"""
|
||||
sampler = get_hippocampus_sampler(memory_system)
|
||||
await sampler.initialize()
|
||||
return sampler
|
||||
return sampler
|
||||
|
||||
@@ -32,7 +32,7 @@ import time
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Type, TypeVar
|
||||
from typing import Any, TypeVar
|
||||
|
||||
E = TypeVar("E", bound=Enum)
|
||||
|
||||
@@ -503,7 +503,7 @@ class MemoryBuilder:
|
||||
logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
|
||||
return MemoryType.CONTEXTUAL
|
||||
|
||||
def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E:
|
||||
def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E:
|
||||
"""解析枚举值,兼容数字/字符串表示"""
|
||||
if isinstance(raw_value, enum_cls):
|
||||
return raw_value
|
||||
|
||||
@@ -556,11 +556,11 @@ class MemorySystem:
|
||||
context = dict(context or {})
|
||||
|
||||
# 获取配置的采样模式
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision')
|
||||
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
|
||||
current_mode = MemorySamplingMode(sampling_mode)
|
||||
|
||||
|
||||
context['__sampling_mode'] = current_mode.value
|
||||
context["__sampling_mode"] = current_mode.value
|
||||
logger.debug(f"使用记忆采样模式: {current_mode.value}")
|
||||
|
||||
# 根据采样模式处理记忆
|
||||
@@ -636,7 +636,7 @@ class MemorySystem:
|
||||
|
||||
# 检查信息价值阈值
|
||||
value_score = await self._assess_information_value(conversation_text, normalized_context)
|
||||
threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
|
||||
threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5)
|
||||
|
||||
if value_score < threshold:
|
||||
logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
|
||||
@@ -1614,8 +1614,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
|
||||
await memory_system.initialize()
|
||||
|
||||
# 根据配置启动海马体采样
|
||||
sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
|
||||
if sampling_mode in ['hippocampus', 'all']:
|
||||
sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate")
|
||||
if sampling_mode in ["hippocampus", "all"]:
|
||||
memory_system.start_hippocampus_sampling()
|
||||
|
||||
return memory_system
|
||||
|
||||
Reference in New Issue
Block a user