style: unify code style and modernize

A comprehensive style-unification and modernization pass over the entire codebase. The main changes:

- Normalized string arguments to builtins such as `hasattr` from single quotes `'` to double quotes `"`.
- Adopted modern type annotations, e.g. `Optional[T]` replaced with `T | None`, `List[T]` with `list[T]`, and so on.
- Removed the no-longer-needed Python 2 compatibility declaration `# -*- coding: utf-8 -*-`.
- Cleaned up redundant blank lines, comments, and unused imports.
- Unified end-of-file newlines.
- Polished some log output and string formatting (`f"{e!s}"`).

These changes aim to improve readability, consistency, and maintainability, bringing the code in line with modern Python conventions. A minimal before/after sketch of these patterns follows.
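As a minimal sketch of the patterns listed above (hypothetical function and names, not taken from this repository; the PEP 604 `T | None` syntax assumes Python 3.10+):

# Before: legacy style
from typing import Dict, List, Optional

def lookup(table: Optional[Dict[str, List[int]]], key: str) -> Optional[List[int]]:
    if table is not None and hasattr(table, 'get'):
        try:
            return table.get(key)
        except Exception as e:
            print(f"lookup failed: {str(e)}")
    return None

# After: modernized style - PEP 585 builtin generics, PEP 604 unions,
# double-quoted strings, and the !s conversion instead of str(e)
def lookup(table: dict[str, list[int]] | None, key: str) -> list[int] | None:
    if table is not None and hasattr(table, "get"):
        try:
            return table.get(key)
        except Exception as e:
            print(f"lookup failed: {e!s}")
    return None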
minecraft1024a
2025-10-05 13:21:27 +08:00
parent ad613a180b
commit 2c74b472ab
47 changed files with 274 additions and 287 deletions

bot.py
View File

@@ -103,7 +103,7 @@ async def graceful_shutdown(main_system_instance):
     logger.info("正在优雅关闭麦麦...")

     # 停止MainSystem中的组件它会处理服务器等
-    if main_system_instance and hasattr(main_system_instance, 'shutdown'):
+    if main_system_instance and hasattr(main_system_instance, "shutdown"):
         logger.info("正在关闭MainSystem...")
         await main_system_instance.shutdown()
@@ -111,7 +111,7 @@ async def graceful_shutdown(main_system_instance):
     try:
         from src.chat.message_receive.chat_stream import get_chat_manager

         chat_manager = get_chat_manager()
-        if hasattr(chat_manager, '_stop_auto_save'):
+        if hasattr(chat_manager, "_stop_auto_save"):
             logger.info("正在停止聊天管理器...")
             chat_manager._stop_auto_save()
     except Exception as e:
@@ -120,7 +120,7 @@ async def graceful_shutdown(main_system_instance):
     # 停止情绪管理器
     try:
         from src.mood.mood_manager import mood_manager

-        if hasattr(mood_manager, 'stop'):
+        if hasattr(mood_manager, "stop"):
             logger.info("正在停止情绪管理器...")
             await mood_manager.stop()
     except Exception as e:
@@ -129,7 +129,7 @@ async def graceful_shutdown(main_system_instance):
     # 停止记忆系统
     try:
         from src.chat.memory_system.memory_manager import memory_manager

-        if hasattr(memory_manager, 'shutdown'):
+        if hasattr(memory_manager, "shutdown"):
             logger.info("正在停止记忆系统...")
             await memory_manager.shutdown()
     except Exception as e:

View File

@@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata(
     version="1.0.0",
     author="Your Name",
     license="MIT",
-)
\ No newline at end of file
+)

View File

@@ -7,4 +7,4 @@ __plugin_meta__ = PluginMetadata(
     version="1.0.0",
     author="Your Name",
     license="MIT",
-)
\ No newline at end of file
+)

View File

@@ -1 +1 @@
-# This file makes src/api a Python package.
\ No newline at end of file
+# This file makes src/api a Python package.

View File

@@ -3,10 +3,10 @@ from typing import Literal

 from fastapi import APIRouter, HTTPException, Query

-from src.config.config import global_config
-from src.plugin_system.apis import message_api, chat_api, person_api
 from src.chat.message_receive.chat_stream import get_chat_manager
 from src.common.logger import get_logger
+from src.config.config import global_config
+from src.plugin_system.apis import message_api, person_api

 logger = get_logger("HTTP消息API")
@@ -86,7 +86,7 @@ async def get_message_stats_by_chat(
             if group_by_user:
                 if user_id not in stats[chat_id]["user_stats"]:
                     stats[chat_id]["user_stats"][user_id] = 0
                 stats[chat_id]["user_stats"][user_id] += 1
-
+
         if not group_by_user:
@@ -120,7 +120,7 @@ async def get_message_stats_by_chat(
                     "nickname": nickname,
                     "count": count
                 }
-
+
             formatted_stats[chat_id] = formatted_data

         return formatted_stats
@@ -164,7 +164,7 @@ async def get_bot_message_stats_by_chat(
                 chat_name = stream.group_info.group_name
             elif stream.user_info and stream.user_info.user_nickname:
                 chat_name = stream.user_info.user_nickname
-
+
             formatted_stats[chat_id] = {
                 "chat_name": chat_name,
                 "count": count
@@ -174,4 +174,4 @@ async def get_bot_message_stats_by_chat(
         return stats
     except Exception as e:
-        raise HTTPException(status_code=500, detail=str(e))
\ No newline at end of file
+        raise HTTPException(status_code=500, detail=str(e))

View File

@@ -112,7 +112,7 @@ class InterestManager:
             # 返回默认结果
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.3,
                 error_message="没有可用的兴趣值计算组件"
             )
@@ -129,7 +129,7 @@ class InterestManager:
             logger.warning(f"兴趣值计算超时 ({timeout}s),消息 {getattr(message, 'message_id', '')} 使用默认兴趣值 0.5")
             return InterestCalculationResult(
                 success=True,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.5,  # 固定默认兴趣值
                 should_reply=False,
                 should_act=False,
@@ -140,9 +140,9 @@ class InterestManager:
             logger.error(f"兴趣值计算异常: {e}")
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.3,
-                error_message=f"计算异常: {str(e)}"
+                error_message=f"计算异常: {e!s}"
             )

     async def _async_calculate(self, message: "DatabaseMessages") -> InterestCalculationResult:
@@ -168,9 +168,9 @@ class InterestManager:
             logger.error(f"兴趣值计算异常: {e}", exc_info=True)
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
-                error_message=f"计算异常: {str(e)}",
+                error_message=f"计算异常: {e!s}",
                 calculation_time=time.time() - start_time
             )
@@ -245,4 +245,4 @@ def get_interest_manager() -> InterestManager:
     global _interest_manager
     if _interest_manager is None:
         _interest_manager = InterestManager()
-    return _interest_manager
\ No newline at end of file
+    return _interest_manager

View File

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """
 海马体双峰分布采样器
 基于旧版海马体的采样策略,适配新版记忆系统
@@ -8,16 +7,15 @@
 import asyncio
 import random
 import time
-from datetime import datetime, timedelta
-from typing import List, Optional, Tuple, Dict, Any
 from dataclasses import dataclass
+from datetime import datetime, timedelta
+from typing import Any

 import numpy as np
-import orjson

 from src.chat.utils.chat_message_builder import (
-    get_raw_msg_by_timestamp,
     build_readable_messages,
+    get_raw_msg_by_timestamp,
     get_raw_msg_by_timestamp_with_chat,
 )
 from src.chat.utils.utils import translate_timestamp_to_human_readable
@@ -47,7 +45,7 @@ class HippocampusSampleConfig:
     batch_size: int = 5  # 批处理大小

     @classmethod
-    def from_global_config(cls) -> 'HippocampusSampleConfig':
+    def from_global_config(cls) -> "HippocampusSampleConfig":
         """从全局配置创建海马体采样配置"""
         config = global_config.memory.hippocampus_distribution_config
         return cls(
@@ -74,12 +72,12 @@ class HippocampusSampler:
         self.is_running = False

         # 记忆构建模型
-        self.memory_builder_model: Optional[LLMRequest] = None
+        self.memory_builder_model: LLMRequest | None = None

         # 统计信息
         self.sample_count = 0
         self.success_count = 0
-        self.last_sample_results: List[Dict[str, Any]] = []
+        self.last_sample_results: list[dict[str, Any]] = []

     async def initialize(self):
         """初始化采样器"""
@@ -101,7 +99,7 @@ class HippocampusSampler:
             logger.error(f"❌ 海马体采样器初始化失败: {e}")
             raise

-    def generate_time_samples(self) -> List[datetime]:
+    def generate_time_samples(self) -> list[datetime]:
         """生成双峰分布的时间采样点"""
         # 计算每个分布的样本数
         recent_samples = max(1, int(self.config.total_samples * self.config.recent_weight))
@@ -132,7 +130,7 @@ class HippocampusSampler:
         # 按时间排序(从最早到最近)
         return sorted(timestamps)

-    async def collect_message_samples(self, target_timestamp: float) -> Optional[List[Dict[str, Any]]]:
+    async def collect_message_samples(self, target_timestamp: float) -> list[dict[str, Any]] | None:
         """收集指定时间戳附近的消息样本"""
         try:
             # 随机时间窗口5-30分钟
@@ -190,7 +188,7 @@ class HippocampusSampler:
             logger.error(f"收集消息样本失败: {e}")
             return None

-    async def build_memory_from_samples(self, messages: List[Dict[str, Any]], target_timestamp: float) -> Optional[str]:
+    async def build_memory_from_samples(self, messages: list[dict[str, Any]], target_timestamp: float) -> str | None:
         """从消息样本构建记忆"""
         if not messages or not self.memory_system or not self.memory_builder_model:
             return None
@@ -262,7 +260,7 @@ class HippocampusSampler:
             logger.error(f"海马体采样构建记忆失败: {e}")
             return None

-    async def perform_sampling_cycle(self) -> Dict[str, Any]:
+    async def perform_sampling_cycle(self) -> dict[str, Any]:
         """执行一次完整的采样周期(优化版:批量融合构建)"""
         if not self.should_sample():
             return {"status": "skipped", "reason": "interval_not_met"}
@@ -363,7 +361,7 @@ class HippocampusSampler:
             "duration": time.time() - start_time,
         }

-    async def _collect_all_message_samples(self, time_samples: List[datetime]) -> List[List[Dict[str, Any]]]:
+    async def _collect_all_message_samples(self, time_samples: list[datetime]) -> list[list[dict[str, Any]]]:
         """批量收集所有时间点的消息样本"""
         collected_messages = []
         max_concurrent = min(5, len(time_samples))  # 提高并发数到5
@@ -394,7 +392,7 @@ class HippocampusSampler:
         return collected_messages

-    async def _fuse_and_deduplicate_messages(self, collected_messages: List[List[Dict[str, Any]]]) -> List[List[Dict[str, Any]]]:
+    async def _fuse_and_deduplicate_messages(self, collected_messages: list[list[dict[str, Any]]]) -> list[list[dict[str, Any]]]:
         """融合和去重消息样本"""
         if not collected_messages:
             return []
@@ -450,7 +448,7 @@ class HippocampusSampler:
             # 返回原始消息组作为备选
             return collected_messages[:5]  # 限制返回数量

-    def _merge_adjacent_messages(self, messages: List[Dict[str, Any]], time_gap: int = 1800) -> List[List[Dict[str, Any]]]:
+    def _merge_adjacent_messages(self, messages: list[dict[str, Any]], time_gap: int = 1800) -> list[list[dict[str, Any]]]:
         """合并时间间隔内的消息"""
         if not messages:
             return []
@@ -481,7 +479,7 @@ class HippocampusSampler:
         return result_groups

-    async def _build_batch_memory(self, fused_messages: List[List[Dict[str, Any]]], time_samples: List[datetime]) -> Dict[str, Any]:
+    async def _build_batch_memory(self, fused_messages: list[list[dict[str, Any]]], time_samples: list[datetime]) -> dict[str, Any]:
         """批量构建记忆"""
         if not fused_messages:
             return {"memory_count": 0, "memories": []}
@@ -557,7 +555,7 @@ class HippocampusSampler:
             logger.error(f"批量构建记忆失败: {e}")
             return {"memory_count": 0, "error": str(e)}

-    async def _build_fused_conversation_text(self, fused_messages: List[List[Dict[str, Any]]]) -> str:
+    async def _build_fused_conversation_text(self, fused_messages: list[list[dict[str, Any]]]) -> str:
         """构建融合后的对话文本"""
         try:
             # 添加批次标识
@@ -589,7 +587,7 @@ class HippocampusSampler:
             logger.error(f"构建融合文本失败: {e}")
             return ""

-    async def _fallback_individual_build(self, fused_messages: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+    async def _fallback_individual_build(self, fused_messages: list[list[dict[str, Any]]]) -> dict[str, Any]:
         """备选方案:单独构建每个消息组"""
         total_memories = []
         total_count = 0
@@ -609,7 +607,7 @@ class HippocampusSampler:
             "fallback_mode": True
         }

-    async def process_sample_timestamp(self, target_timestamp: float) -> Optional[str]:
+    async def process_sample_timestamp(self, target_timestamp: float) -> str | None:
         """处理单个时间戳采样(保留作为备选方法)"""
         try:
             # 收集消息样本
@@ -676,7 +674,7 @@ class HippocampusSampler:
         self.is_running = False
         logger.info("🛑 停止海马体后台采样任务")

-    def get_sampling_stats(self) -> Dict[str, Any]:
+    def get_sampling_stats(self) -> dict[str, Any]:
         """获取采样统计信息"""
         success_rate = (self.success_count / self.sample_count * 100) if self.sample_count > 0 else 0
@@ -713,7 +711,7 @@ class HippocampusSampler:

 # 全局海马体采样器实例
-_hippocampus_sampler: Optional[HippocampusSampler] = None
+_hippocampus_sampler: HippocampusSampler | None = None


 def get_hippocampus_sampler(memory_system=None) -> HippocampusSampler:
@@ -728,4 +726,4 @@ async def initialize_hippocampus_sampler(memory_system=None) -> HippocampusSampl
     """初始化全局海马体采样器"""
     sampler = get_hippocampus_sampler(memory_system)
     await sampler.initialize()
-    return sampler
\ No newline at end of file
+    return sampler

View File

@@ -32,7 +32,7 @@ import time
 from dataclasses import dataclass
 from datetime import datetime
 from enum import Enum
-from typing import Any, Type, TypeVar
+from typing import Any, TypeVar

 E = TypeVar("E", bound=Enum)
@@ -503,7 +503,7 @@ class MemoryBuilder:
             logger.warning(f"无法解析未知的记忆类型 '{type_str}',回退到上下文类型")
             return MemoryType.CONTEXTUAL

-    def _parse_enum_value(self, enum_cls: Type[E], raw_value: Any, default: E, field_name: str) -> E:
+    def _parse_enum_value(self, enum_cls: type[E], raw_value: Any, default: E, field_name: str) -> E:
         """解析枚举值,兼容数字/字符串表示"""
         if isinstance(raw_value, enum_cls):
             return raw_value

View File

@@ -556,11 +556,11 @@ class MemorySystem:
         context = dict(context or {})

         # 获取配置的采样模式
-        sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'precision')
+        sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "precision")
         current_mode = MemorySamplingMode(sampling_mode)

-        context['__sampling_mode'] = current_mode.value
+        context["__sampling_mode"] = current_mode.value
         logger.debug(f"使用记忆采样模式: {current_mode.value}")

         # 根据采样模式处理记忆
@@ -636,7 +636,7 @@ class MemorySystem:
         # 检查信息价值阈值
         value_score = await self._assess_information_value(conversation_text, normalized_context)
-        threshold = getattr(global_config.memory, 'precision_memory_reply_threshold', 0.5)
+        threshold = getattr(global_config.memory, "precision_memory_reply_threshold", 0.5)

         if value_score < threshold:
             logger.debug(f"信息价值评分 {value_score:.2f} 低于阈值 {threshold},跳过记忆构建")
@@ -1614,8 +1614,8 @@ async def initialize_memory_system(llm_model: LLMRequest | None = None):
     await memory_system.initialize()

     # 根据配置启动海马体采样
-    sampling_mode = getattr(global_config.memory, 'memory_sampling_mode', 'immediate')
-    if sampling_mode in ['hippocampus', 'all']:
+    sampling_mode = getattr(global_config.memory, "memory_sampling_mode", "immediate")
+    if sampling_mode in ["hippocampus", "all"]:
         memory_system.start_hippocampus_sampling()

     return memory_system

View File

@@ -4,14 +4,13 @@
 """

 import asyncio
-import psutil
 import time
-from typing import Dict, List, Optional, Set, Tuple
 from dataclasses import dataclass, field
 from enum import Enum

+import psutil
+
 from src.common.logger import get_logger
-from src.chat.message_receive.chat_stream import ChatStream

 logger = get_logger("adaptive_stream_manager")
@@ -71,16 +70,16 @@ class AdaptiveStreamManager:
         # 当前状态
         self.current_limit = base_concurrent_limit
-        self.active_streams: Set[str] = set()
-        self.pending_streams: Set[str] = set()
-        self.stream_metrics: Dict[str, StreamMetrics] = {}
+        self.active_streams: set[str] = set()
+        self.pending_streams: set[str] = set()
+        self.stream_metrics: dict[str, StreamMetrics] = {}

         # 异步信号量
         self.semaphore = asyncio.Semaphore(base_concurrent_limit)
         self.priority_semaphore = asyncio.Semaphore(5)  # 高优先级专用信号量

         # 系统监控
-        self.system_metrics: List[SystemMetrics] = []
+        self.system_metrics: list[SystemMetrics] = []
         self.last_adjustment_time = 0.0

         # 统计信息
@@ -95,8 +94,8 @@ class AdaptiveStreamManager:
         }

         # 监控任务
-        self.monitor_task: Optional[asyncio.Task] = None
-        self.adjustment_task: Optional[asyncio.Task] = None
+        self.monitor_task: asyncio.Task | None = None
+        self.adjustment_task: asyncio.Task | None = None
         self.is_running = False

         logger.info(f"自适应流管理器初始化完成 (base_limit={base_concurrent_limit}, max_limit={max_concurrent_limit})")
@@ -443,7 +442,7 @@ class AdaptiveStreamManager:
             if hasattr(metrics, key):
                 setattr(metrics, key, value)

-    def get_stats(self) -> Dict:
+    def get_stats(self) -> dict:
         """获取统计信息"""
         stats = self.stats.copy()
         stats.update({
@@ -465,7 +464,7 @@ class AdaptiveStreamManager:

 # 全局自适应管理器实例
-_adaptive_manager: Optional[AdaptiveStreamManager] = None
+_adaptive_manager: AdaptiveStreamManager | None = None


 def get_adaptive_stream_manager() -> AdaptiveStreamManager:
@@ -485,4 +484,4 @@ async def init_adaptive_stream_manager():
 async def shutdown_adaptive_stream_manager():
     """关闭自适应流管理器"""
     manager = get_adaptive_stream_manager()
-    await manager.stop()
\ No newline at end of file
+    await manager.stop()

View File

@@ -5,9 +5,9 @@

 import asyncio
 import time
-from typing import Any, Dict, List, Optional
-from dataclasses import dataclass, field
 from collections import defaultdict
+from dataclasses import dataclass, field
+from typing import Any

 from src.common.database.sqlalchemy_database_api import get_db_session
 from src.common.database.sqlalchemy_models import ChatStreams
@@ -21,7 +21,7 @@ logger = get_logger("batch_database_writer")
 class StreamUpdatePayload:
     """流更新数据结构"""
     stream_id: str
-    update_data: Dict[str, Any]
+    update_data: dict[str, Any]
     priority: int = 0  # 优先级,数字越大优先级越高
     timestamp: float = field(default_factory=time.time)
@@ -47,7 +47,7 @@ class BatchDatabaseWriter:
         # 运行状态
         self.is_running = False
-        self.writer_task: Optional[asyncio.Task] = None
+        self.writer_task: asyncio.Task | None = None

         # 统计信息
         self.stats = {
@@ -60,7 +60,7 @@ class BatchDatabaseWriter:
         }

         # 按优先级分类的批次
-        self.priority_batches: Dict[int, List[StreamUpdatePayload]] = defaultdict(list)
+        self.priority_batches: dict[int, list[StreamUpdatePayload]] = defaultdict(list)

         logger.info(f"批量数据库写入器初始化完成 (batch_size={batch_size}, interval={flush_interval}s)")
@@ -98,7 +98,7 @@ class BatchDatabaseWriter:
     async def schedule_stream_update(
         self,
         stream_id: str,
-        update_data: Dict[str, Any],
+        update_data: dict[str, Any],
         priority: int = 0
     ) -> bool:
         """
@@ -166,7 +166,7 @@ class BatchDatabaseWriter:
                 await self._flush_all_batches()
         logger.info("批量写入循环结束")

-    async def _collect_batch(self) -> List[StreamUpdatePayload]:
+    async def _collect_batch(self) -> list[StreamUpdatePayload]:
         """收集一个批次的数据"""
         batch = []
         deadline = time.time() + self.flush_interval
@@ -189,7 +189,7 @@ class BatchDatabaseWriter:
         return batch

-    async def _write_batch(self, batch: List[StreamUpdatePayload]):
+    async def _write_batch(self, batch: list[StreamUpdatePayload]):
         """批量写入数据库"""
         if not batch:
             return
@@ -228,7 +228,7 @@ class BatchDatabaseWriter:
                 except Exception as single_e:
                     logger.error(f"单个写入也失败: {single_e}")

-    async def _batch_write_to_database(self, payloads: List[StreamUpdatePayload]):
+    async def _batch_write_to_database(self, payloads: list[StreamUpdatePayload]):
         """批量写入数据库"""
         async with get_db_session() as session:
             for payload in payloads:
@@ -268,7 +268,7 @@ class BatchDatabaseWriter:
             await session.commit()

-    async def _direct_write(self, stream_id: str, update_data: Dict[str, Any]):
+    async def _direct_write(self, stream_id: str, update_data: dict[str, Any]):
         """直接写入数据库(降级方案)"""
         async with get_db_session() as session:
             if global_config.database.database_type == "sqlite":
@@ -315,7 +315,7 @@ class BatchDatabaseWriter:
         if remaining_batch:
             await self._write_batch(remaining_batch)

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         """获取统计信息"""
         stats = self.stats.copy()
         stats["is_running"] = self.is_running
@@ -324,7 +324,7 @@ class BatchDatabaseWriter:

 # 全局批量写入器实例
-_batch_writer: Optional[BatchDatabaseWriter] = None
+_batch_writer: BatchDatabaseWriter | None = None


 def get_batch_writer() -> BatchDatabaseWriter:
@@ -344,4 +344,4 @@ async def init_batch_writer():
 async def shutdown_batch_writer():
     """关闭批量写入器"""
     writer = get_batch_writer()
-    await writer.stop()
\ No newline at end of file
+    await writer.stop()

View File

@@ -117,7 +117,7 @@ class StreamLoopManager:
         # 使用自适应流管理器获取槽位
         use_adaptive = False
         try:
-            from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager, StreamPriority
+            from src.chat.message_manager.adaptive_stream_manager import get_adaptive_stream_manager

             adaptive_manager = get_adaptive_stream_manager()
             if adaptive_manager.is_running:
@@ -137,7 +137,7 @@ class StreamLoopManager:
                 else:
                     logger.debug(f"自适应管理器拒绝槽位请求: {stream_id},尝试回退方案")
             else:
-                logger.debug(f"自适应管理器未运行,使用原始方法")
+                logger.debug("自适应管理器未运行,使用原始方法")
         except Exception as e:
             logger.debug(f"自适应管理器获取槽位失败,使用原始方法: {e}")

View File

@@ -5,13 +5,13 @@

 import asyncio
 import time
-from typing import Dict, List, Optional, Set
-from dataclasses import dataclass
 from collections import OrderedDict
+from dataclasses import dataclass

 from maim_message import GroupInfo, UserInfo

-from src.common.logger import get_logger
 from src.chat.message_receive.optimized_chat_stream import OptimizedChatStream, create_optimized_chat_stream
+from src.common.logger import get_logger

 logger = get_logger("stream_cache_manager")
@@ -52,14 +52,14 @@ class TieredStreamCache:
         # 三层缓存存储
         self.hot_cache: OrderedDict[str, OptimizedChatStream] = OrderedDict()  # 热数据LRU
-        self.warm_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # 温数据(最后访问时间)
-        self.cold_storage: Dict[str, tuple[OptimizedChatStream, float]] = {}  # 冷数据(最后访问时间)
+        self.warm_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 温数据(最后访问时间)
+        self.cold_storage: dict[str, tuple[OptimizedChatStream, float]] = {}  # 冷数据(最后访问时间)

         # 统计信息
         self.stats = StreamCacheStats()

         # 清理任务
-        self.cleanup_task: Optional[asyncio.Task] = None
+        self.cleanup_task: asyncio.Task | None = None
         self.is_running = False

         logger.info(f"分层流缓存管理器初始化完成 (hot:{max_hot_size}, warm:{max_warm_size}, cold:{max_cold_size})")
@@ -96,8 +96,8 @@ class TieredStreamCache:
         stream_id: str,
         platform: str,
         user_info: UserInfo,
-        group_info: Optional[GroupInfo] = None,
-        data: Optional[Dict] = None,
+        group_info: GroupInfo | None = None,
+        data: dict | None = None,
     ) -> OptimizedChatStream:
         """获取或创建流 - 优化版本"""
         current_time = time.time()
@@ -255,7 +255,7 @@ class TieredStreamCache:
         hot_to_demote = []
         for stream_id, stream in self.hot_cache.items():
             # 获取最后访问时间(简化:使用创建时间作为近似)
-            last_access = getattr(stream, 'last_active_time', stream.create_time)
+            last_access = getattr(stream, "last_active_time", stream.create_time)
             if current_time - last_access > self.hot_timeout:
                 hot_to_demote.append(stream_id)
@@ -341,7 +341,7 @@ class TieredStreamCache:
         logger.info("所有缓存已清空")

-    async def get_stream_snapshot(self, stream_id: str) -> Optional[OptimizedChatStream]:
+    async def get_stream_snapshot(self, stream_id: str) -> OptimizedChatStream | None:
         """获取流的快照(不修改缓存状态)"""
         if stream_id in self.hot_cache:
             return self.hot_cache[stream_id].create_snapshot()
@@ -351,13 +351,13 @@ class TieredStreamCache:
             return self.cold_storage[stream_id][0].create_snapshot()
         return None

-    def get_cached_stream_ids(self) -> Set[str]:
+    def get_cached_stream_ids(self) -> set[str]:
         """获取所有缓存的流ID"""
         return set(self.hot_cache.keys()) | set(self.warm_storage.keys()) | set(self.cold_storage.keys())


 # 全局缓存管理器实例
-_cache_manager: Optional[TieredStreamCache] = None
+_cache_manager: TieredStreamCache | None = None


 def get_stream_cache_manager() -> TieredStreamCache:
@@ -377,4 +377,4 @@ async def init_stream_cache_manager():
 async def shutdown_stream_cache_manager():
     """关闭流缓存管理器"""
     manager = get_stream_cache_manager()
-    await manager.stop()
\ No newline at end of file
+    await manager.stop()

View File

@@ -313,11 +313,11 @@ class ChatStream:
         except Exception as e:
             logger.error(f"计算消息兴趣值失败: {e}", exc_info=True)
             # 异常情况下使用默认值
-            if hasattr(db_message, 'interest_value'):
+            if hasattr(db_message, "interest_value"):
                 db_message.interest_value = 0.3
-            if hasattr(db_message, 'should_reply'):
+            if hasattr(db_message, "should_reply"):
                 db_message.should_reply = False
-            if hasattr(db_message, 'should_act'):
+            if hasattr(db_message, "should_act"):
                 db_message.should_act = False

     def _extract_reply_from_segment(self, segment) -> str | None:
@@ -894,10 +894,10 @@ def _convert_to_original_stream(self, optimized_stream) -> "ChatStream":
     original_stream.saved = optimized_stream.saved

     # 复制上下文信息(如果存在)
-    if hasattr(optimized_stream, '_stream_context') and optimized_stream._stream_context:
+    if hasattr(optimized_stream, "_stream_context") and optimized_stream._stream_context:
         original_stream.stream_context = optimized_stream._stream_context

-    if hasattr(optimized_stream, '_context_manager') and optimized_stream._context_manager:
+    if hasattr(optimized_stream, "_context_manager") and optimized_stream._context_manager:
         original_stream.context_manager = optimized_stream._context_manager

     return original_stream

View File

@@ -3,17 +3,12 @@
 避免不必要的深拷贝开销,提升多流并发性能
 """

-import asyncio
-import copy
-import hashlib
 import time
-from typing import TYPE_CHECKING, Any, Dict, Optional
+from typing import TYPE_CHECKING, Any

 from maim_message import GroupInfo, UserInfo
 from rich.traceback import install

-from src.common.database.sqlalchemy_database_api import get_db_session
-from src.common.database.sqlalchemy_models import ChatStreams
 from src.common.logger import get_logger
 from src.config.config import global_config
@@ -28,7 +23,7 @@ logger = get_logger("optimized_chat_stream")
 class SharedContext:
     """共享上下文数据 - 只读数据结构"""

-    def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None):
+    def __init__(self, stream_id: str, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None):
         self.stream_id = stream_id
         self.platform = platform
         self.user_info = user_info
@@ -37,7 +32,7 @@ class SharedContext:
         self._frozen = True

     def __setattr__(self, name, value):
-        if hasattr(self, '_frozen') and self._frozen and name not in ['_frozen']:
+        if hasattr(self, "_frozen") and self._frozen and name not in ["_frozen"]:
             raise AttributeError(f"SharedContext is frozen, cannot modify {name}")
         super().__setattr__(name, value)
@@ -46,7 +41,7 @@ class LocalChanges:
     """本地修改跟踪器"""

     def __init__(self):
-        self._changes: Dict[str, Any] = {}
+        self._changes: dict[str, Any] = {}
         self._dirty = False

     def set_change(self, key: str, value: Any):
@@ -62,7 +57,7 @@ class LocalChanges:
         """是否有修改"""
         return self._dirty

-    def get_changes(self) -> Dict[str, Any]:
+    def get_changes(self) -> dict[str, Any]:
         """获取所有修改"""
         return self._changes.copy()
@@ -80,8 +75,8 @@ class OptimizedChatStream:
         stream_id: str,
         platform: str,
         user_info: UserInfo,
-        group_info: Optional[GroupInfo] = None,
-        data: Optional[Dict] = None,
+        group_info: GroupInfo | None = None,
+        data: dict | None = None,
     ):
         # 共享的只读数据
         self._shared_context = SharedContext(
@@ -129,42 +124,42 @@ class OptimizedChatStream:
         """修改用户信息时触发写时复制"""
         self._ensure_copy_on_write()
         # 由于SharedContext是frozen的我们需要在本地修改中记录
-        self._local_changes.set_change('user_info', value)
+        self._local_changes.set_change("user_info", value)

     @property
-    def group_info(self) -> Optional[GroupInfo]:
-        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
-            return self._local_changes.get_change('group_info')
+    def group_info(self) -> GroupInfo | None:
+        if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
+            return self._local_changes.get_change("group_info")
         return self._shared_context.group_info

     @group_info.setter
-    def group_info(self, value: Optional[GroupInfo]):
+    def group_info(self, value: GroupInfo | None):
         """修改群组信息时触发写时复制"""
         self._ensure_copy_on_write()
-        self._local_changes.set_change('group_info', value)
+        self._local_changes.set_change("group_info", value)

     @property
     def create_time(self) -> float:
-        if self._local_changes.has_changes() and 'create_time' in self._local_changes._changes:
-            return self._local_changes.get_change('create_time')
+        if self._local_changes.has_changes() and "create_time" in self._local_changes._changes:
+            return self._local_changes.get_change("create_time")
         return self._shared_context.create_time

     @property
     def last_active_time(self) -> float:
-        return self._local_changes.get_change('last_active_time', self.create_time)
+        return self._local_changes.get_change("last_active_time", self.create_time)

     @last_active_time.setter
     def last_active_time(self, value: float):
-        self._local_changes.set_change('last_active_time', value)
+        self._local_changes.set_change("last_active_time", value)
         self.saved = False

     @property
     def sleep_pressure(self) -> float:
-        return self._local_changes.get_change('sleep_pressure', 0.0)
+        return self._local_changes.get_change("sleep_pressure", 0.0)

     @sleep_pressure.setter
     def sleep_pressure(self, value: float):
-        self._local_changes.set_change('sleep_pressure', value)
+        self._local_changes.set_change("sleep_pressure", value)
         self.saved = False

     def _ensure_copy_on_write(self):
@@ -176,14 +171,14 @@ class OptimizedChatStream:
     def _get_effective_user_info(self) -> UserInfo:
         """获取有效的用户信息"""
-        if self._local_changes.has_changes() and 'user_info' in self._local_changes._changes:
-            return self._local_changes.get_change('user_info')
+        if self._local_changes.has_changes() and "user_info" in self._local_changes._changes:
+            return self._local_changes.get_change("user_info")
         return self._shared_context.user_info

-    def _get_effective_group_info(self) -> Optional[GroupInfo]:
+    def _get_effective_group_info(self) -> GroupInfo | None:
         """获取有效的群组信息"""
-        if self._local_changes.has_changes() and 'group_info' in self._local_changes._changes:
-            return self._local_changes.get_change('group_info')
+        if self._local_changes.has_changes() and "group_info" in self._local_changes._changes:
+            return self._local_changes.get_change("group_info")
         return self._shared_context.group_info

     def update_active_time(self):
@@ -199,6 +194,7 @@ class OptimizedChatStream:
         # 将MessageRecv转换为DatabaseMessages并设置到stream_context
         import json
+
         from src.common.data_models.database_data_model import DatabaseMessages

         message_info = getattr(message, "message_info", {})
@@ -298,7 +294,7 @@ class OptimizedChatStream:
             self._create_stream_context()
         return self._context_manager

-    def to_dict(self) -> Dict[str, Any]:
+    def to_dict(self) -> dict[str, Any]:
         """转换为字典格式 - 考虑本地修改"""
         user_info = self._get_effective_user_info()
         group_info = self._get_effective_group_info()
@@ -319,7 +315,7 @@ class OptimizedChatStream:
         }

     @classmethod
-    def from_dict(cls, data: Dict) -> "OptimizedChatStream":
+    def from_dict(cls, data: dict) -> "OptimizedChatStream":
         """从字典创建实例"""
         user_info = UserInfo.from_dict(data.get("user_info", {})) if data.get("user_info") else None
         group_info = GroupInfo.from_dict(data.get("group_info", {})) if data.get("group_info") else None
@@ -481,8 +477,8 @@ def create_optimized_chat_stream(
     stream_id: str,
     platform: str,
     user_info: UserInfo,
-    group_info: Optional[GroupInfo] = None,
-    data: Optional[Dict] = None,
+    group_info: GroupInfo | None = None,
+    data: dict | None = None,
 ) -> OptimizedChatStream:
     """创建优化版聊天流实例"""
     return OptimizedChatStream(
@@ -491,4 +487,4 @@ def create_optimized_chat_stream(
         user_info=user_info,
         group_info=group_info,
         data=data
-    )
\ No newline at end of file
+    )

View File

@@ -15,7 +15,7 @@ from src.plugin_system.base.component_types import ActionActivationType, ActionI
 from src.plugin_system.core.global_announcement_manager import global_announcement_manager

 if TYPE_CHECKING:
-    from src.chat.message_receive.chat_stream import ChatStream
+    pass

 logger = get_logger("action_manager")

View File

@@ -536,7 +536,7 @@ class Prompt:
                 style = expr.get("style", "")
                 if situation and style:
                     formatted_expressions.append(f"- {situation}{style}")
-
+
             if formatted_expressions:
                 style_habits_str = "\n".join(formatted_expressions)
                 expression_habits_block = f"你可以参考以下的语言习惯,当情景合适就使用,但不要生硬使用,以合理的方式结合到你的回复中:\n{style_habits_str}"

View File

@@ -9,8 +9,8 @@ import time
 from collections import defaultdict
 from pathlib import Path

-import rjieba
 import orjson
+import rjieba
 from pypinyin import Style, pinyin

 from src.common.logger import get_logger

View File

@@ -6,8 +6,8 @@ import time
 from collections import Counter
 from typing import Any

-import rjieba
 import numpy as np
+import rjieba
 from maim_message import UserInfo

 from src.chat.message_receive.chat_stream import get_chat_manager

View File

@@ -5,9 +5,8 @@

 import asyncio
 import time
-import weakref
 from contextlib import asynccontextmanager
-from typing import Any, Dict, Optional, Set
+from typing import Any

 from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
@@ -69,7 +68,7 @@ class ConnectionPoolManager:
         self.max_idle = max_idle

         # 连接池
-        self._connections: Set[ConnectionInfo] = set()
+        self._connections: set[ConnectionInfo] = set()
         self._lock = asyncio.Lock()

         # 统计信息
@@ -83,7 +82,7 @@ class ConnectionPoolManager:
         }

         # 后台清理任务
-        self._cleanup_task: Optional[asyncio.Task] = None
+        self._cleanup_task: asyncio.Task | None = None
         self._should_cleanup = False

         logger.info(f"连接池管理器初始化完成 (最大池大小: {max_pool_size})")
@@ -144,7 +143,7 @@ class ConnectionPoolManager:
             yield connection_info.session

-        except Exception as e:
+        except Exception:
             # 发生错误时回滚连接
             if connection_info and connection_info.session:
                 try:
@@ -157,7 +156,7 @@ class ConnectionPoolManager:
             if connection_info:
                 connection_info.mark_released()

-    async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> Optional[ConnectionInfo]:
+    async def _get_reusable_connection(self, session_factory: async_sessionmaker[AsyncSession]) -> ConnectionInfo | None:
         """获取可复用的连接"""
         async with self._lock:
             # 清理过期连接
@@ -231,7 +230,7 @@ class ConnectionPoolManager:
         self._connections.clear()
         logger.info("所有连接已关闭")

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         """获取连接池统计信息"""
         return {
             **self._stats,
@@ -244,7 +243,7 @@ class ConnectionPoolManager:

 # 全局连接池管理器实例
-_connection_pool_manager: Optional[ConnectionPoolManager] = None
+_connection_pool_manager: ConnectionPoolManager | None = None


 def get_connection_pool_manager() -> ConnectionPoolManager:
@@ -266,4 +265,4 @@ async def stop_connection_pool():
     global _connection_pool_manager
     if _connection_pool_manager:
         await _connection_pool_manager.stop()
-        _connection_pool_manager = None
\ No newline at end of file
+        _connection_pool_manager = None

View File

@@ -2,15 +2,16 @@ import os

 from rich.traceback import install

+from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
+
+# 数据库批量调度器和连接池
+from src.common.database.db_batch_scheduler import get_db_batch_scheduler
+
 # SQLAlchemy相关导入
 from src.common.database.sqlalchemy_init import initialize_database_compat
 from src.common.database.sqlalchemy_models import get_db_session, get_engine
 from src.common.logger import get_logger

-# 数据库批量调度器和连接池
-from src.common.database.db_batch_scheduler import get_db_batch_scheduler
-from src.common.database.connection_pool_manager import start_connection_pool, stop_connection_pool
-
 install(extra_lines=3)

 _sql_engine = None

View File

@@ -6,19 +6,19 @@
import asyncio import asyncio
import time import time
from collections import defaultdict, deque from collections import defaultdict, deque
from dataclasses import dataclass from collections.abc import Callable
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from dataclasses import dataclass
from typing import Any, TypeVar
from sqlalchemy import select, delete, insert, update from sqlalchemy import delete, insert, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from src.common.database.sqlalchemy_database_api import get_db_session from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.logger import get_logger from src.common.logger import get_logger
logger = get_logger("db_batch_scheduler") logger = get_logger("db_batch_scheduler")
T = TypeVar('T') T = TypeVar("T")
@dataclass @dataclass
@@ -26,10 +26,10 @@ class BatchOperation:
"""批量操作基础类""" """批量操作基础类"""
operation_type: str # 'select', 'insert', 'update', 'delete' operation_type: str # 'select', 'insert', 'update', 'delete'
model_class: Any model_class: Any
conditions: Dict[str, Any] conditions: dict[str, Any]
data: Optional[Dict[str, Any]] = None data: dict[str, Any] | None = None
callback: Optional[Callable] = None callback: Callable | None = None
future: Optional[asyncio.Future] = None future: asyncio.Future | None = None
timestamp: float = 0.0 timestamp: float = 0.0
def __post_init__(self): def __post_init__(self):
@@ -42,7 +42,7 @@ class BatchResult:
"""批量操作结果""" """批量操作结果"""
success: bool success: bool
data: Any = None data: Any = None
error: Optional[str] = None error: str | None = None
class DatabaseBatchScheduler: class DatabaseBatchScheduler:
@@ -57,23 +57,23 @@ class DatabaseBatchScheduler:
self.max_queue_size = max_queue_size self.max_queue_size = max_queue_size
# 操作队列,按操作类型和模型分类 # 操作队列,按操作类型和模型分类
self.operation_queues: Dict[str, deque] = defaultdict(deque) self.operation_queues: dict[str, deque] = defaultdict(deque)
# 调度控制 # 调度控制
self._scheduler_task: Optional[asyncio.Task] = None self._scheduler_task: asyncio.Task | None = None
self._is_running = bool = False self._is_running = bool = False
self._lock = asyncio.Lock() self._lock = asyncio.Lock()
# 统计信息 # 统计信息
self.stats = { self.stats = {
'total_operations': 0, "total_operations": 0,
'batched_operations': 0, "batched_operations": 0,
'cache_hits': 0, "cache_hits": 0,
'execution_time': 0.0 "execution_time": 0.0
} }
# 简单的结果缓存(用于频繁的查询) # 简单的结果缓存(用于频繁的查询)
self._result_cache: Dict[str, Tuple[Any, float]] = {} self._result_cache: dict[str, tuple[Any, float]] = {}
self._cache_ttl = 5.0 # 5秒缓存 self._cache_ttl = 5.0 # 5秒缓存
async def start(self): async def start(self):
@@ -102,7 +102,7 @@ class DatabaseBatchScheduler:
await self._flush_all_queues() await self._flush_all_queues()
logger.info("数据库批量调度器已停止") logger.info("数据库批量调度器已停止")
def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: Dict[str, Any]) -> str: def _generate_cache_key(self, operation_type: str, model_class: Any, conditions: dict[str, Any]) -> str:
"""生成缓存键""" """生成缓存键"""
# 简单的缓存键生成,实际可以根据需要优化 # 简单的缓存键生成,实际可以根据需要优化
key_parts = [ key_parts = [
@@ -112,12 +112,12 @@ class DatabaseBatchScheduler:
] ]
return "|".join(key_parts) return "|".join(key_parts)
def _get_from_cache(self, cache_key: str) -> Optional[Any]: def _get_from_cache(self, cache_key: str) -> Any | None:
"""从缓存获取结果""" """从缓存获取结果"""
if cache_key in self._result_cache: if cache_key in self._result_cache:
result, timestamp = self._result_cache[cache_key] result, timestamp = self._result_cache[cache_key]
if time.time() - timestamp < self._cache_ttl: if time.time() - timestamp < self._cache_ttl:
self.stats['cache_hits'] += 1 self.stats["cache_hits"] += 1
return result return result
else: else:
# 清理过期缓存 # 清理过期缓存
@@ -131,7 +131,7 @@ class DatabaseBatchScheduler:
async def add_operation(self, operation: BatchOperation) -> asyncio.Future: async def add_operation(self, operation: BatchOperation) -> asyncio.Future:
"""添加操作到队列""" """添加操作到队列"""
# 检查是否可以立即返回缓存结果 # 检查是否可以立即返回缓存结果
if operation.operation_type == 'select': if operation.operation_type == "select":
cache_key = self._generate_cache_key( cache_key = self._generate_cache_key(
operation.operation_type, operation.operation_type,
operation.model_class, operation.model_class,
@@ -158,7 +158,7 @@ class DatabaseBatchScheduler:
await self._execute_operations([operation]) await self._execute_operations([operation])
else: else:
self.operation_queues[queue_key].append(operation) self.operation_queues[queue_key].append(operation)
self.stats['total_operations'] += 1 self.stats["total_operations"] += 1
return future return future
@@ -193,7 +193,7 @@ class DatabaseBatchScheduler:
if operations: if operations:
await self._execute_operations(list(operations)) await self._execute_operations(list(operations))
async def _execute_operations(self, operations: List[BatchOperation]): async def _execute_operations(self, operations: list[BatchOperation]):
"""执行批量操作""" """执行批量操作"""
if not operations: if not operations:
return return
@@ -209,13 +209,13 @@ class DatabaseBatchScheduler:
# 为每种操作类型创建批量执行任务 # 为每种操作类型创建批量执行任务
tasks = [] tasks = []
for op_type, ops in op_groups.items(): for op_type, ops in op_groups.items():
if op_type == 'select': if op_type == "select":
tasks.append(self._execute_select_batch(ops)) tasks.append(self._execute_select_batch(ops))
elif op_type == 'insert': elif op_type == "insert":
tasks.append(self._execute_insert_batch(ops)) tasks.append(self._execute_insert_batch(ops))
elif op_type == 'update': elif op_type == "update":
tasks.append(self._execute_update_batch(ops)) tasks.append(self._execute_update_batch(ops))
elif op_type == 'delete': elif op_type == "delete":
tasks.append(self._execute_delete_batch(ops)) tasks.append(self._execute_delete_batch(ops))
# 并发执行所有操作 # 并发执行所有操作
@@ -238,7 +238,7 @@ class DatabaseBatchScheduler:
operation.future.set_result(result) operation.future.set_result(result)
# 缓存查询结果 # 缓存查询结果
if operation.operation_type == 'select': if operation.operation_type == "select":
cache_key = self._generate_cache_key( cache_key = self._generate_cache_key(
operation.operation_type, operation.operation_type,
operation.model_class, operation.model_class,
@@ -246,7 +246,7 @@ class DatabaseBatchScheduler:
) )
self._set_cache(cache_key, result) self._set_cache(cache_key, result)
self.stats['batched_operations'] += len(operations) self.stats["batched_operations"] += len(operations)
except Exception as e: except Exception as e:
logger.error(f"批量操作执行失败: {e}", exc_info="") logger.error(f"批量操作执行失败: {e}", exc_info="")
@@ -255,9 +255,9 @@ class DatabaseBatchScheduler:
if operation.future and not operation.future.done(): if operation.future and not operation.future.done():
operation.future.set_exception(e) operation.future.set_exception(e)
finally: finally:
self.stats['execution_time'] += time.time() - start_time self.stats["execution_time"] += time.time() - start_time
async def _execute_select_batch(self, operations: List[BatchOperation]): async def _execute_select_batch(self, operations: list[BatchOperation]):
"""批量执行查询操作""" """批量执行查询操作"""
# 合并相似的查询条件 # 合并相似的查询条件
         merged_conditions = self._merge_select_conditions(operations)
@@ -302,7 +302,7 @@ class DatabaseBatchScheduler:
         return results if len(results) > 1 else results[0] if results else []

-    async def _execute_insert_batch(self, operations: List[BatchOperation]):
+    async def _execute_insert_batch(self, operations: list[BatchOperation]):
         """批量执行插入操作"""
         async with get_db_session() as session:
             try:
@@ -323,7 +323,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量插入失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    async def _execute_update_batch(self, operations: List[BatchOperation]):
+    async def _execute_update_batch(self, operations: list[BatchOperation]):
         """批量执行更新操作"""
         async with get_db_session() as session:
             try:
@@ -353,7 +353,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量更新失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    async def _execute_delete_batch(self, operations: List[BatchOperation]):
+    async def _execute_delete_batch(self, operations: list[BatchOperation]):
         """批量执行删除操作"""
         async with get_db_session() as session:
             try:
@@ -382,7 +382,7 @@ class DatabaseBatchScheduler:
                 logger.error(f"批量删除失败: {e}", exc_info=True)
                 return [0] * len(operations)

-    def _merge_select_conditions(self, operations: List[BatchOperation]) -> Dict[Tuple, List[BatchOperation]]:
+    def _merge_select_conditions(self, operations: list[BatchOperation]) -> dict[tuple, list[BatchOperation]]:
         """合并相似的查询条件"""
         merged = {}
@@ -405,15 +405,15 @@ class DatabaseBatchScheduler:
             # 记录操作
             if condition_key not in merged:
-                merged[condition_key] = {'_operations': []}
-            if '_operations' not in merged[condition_key]:
-                merged[condition_key]['_operations'] = []
-            merged[condition_key]['_operations'].append(op)
+                merged[condition_key] = {"_operations": []}
+            if "_operations" not in merged[condition_key]:
+                merged[condition_key]["_operations"] = []
+            merged[condition_key]["_operations"].append(op)

         # 去重并构建最终条件
         final_merged = {}
         for condition_key, conditions in merged.items():
-            operations = conditions.pop('_operations')
+            operations = conditions.pop("_operations")

             # 去重
             for field_name, values in conditions.items():
@@ -423,13 +423,13 @@ class DatabaseBatchScheduler:
         return final_merged

-    def get_stats(self) -> Dict[str, Any]:
+    def get_stats(self) -> dict[str, Any]:
         """获取统计信息"""
         return {
             **self.stats,
-            'cache_size': len(self._result_cache),
-            'queue_sizes': {k: len(v) for k, v in self.operation_queues.items()},
-            'is_running': self._is_running
+            "cache_size": len(self._result_cache),
+            "queue_sizes": {k: len(v) for k, v in self.operation_queues.items()},
+            "is_running": self._is_running
         }
@@ -450,20 +450,20 @@ async def get_batch_session():

 # 便捷函数
-async def batch_select(model_class: Any, conditions: Dict[str, Any]) -> Any:
+async def batch_select(model_class: Any, conditions: dict[str, Any]) -> Any:
     """批量查询"""
     operation = BatchOperation(
-        operation_type='select',
+        operation_type="select",
         model_class=model_class,
         conditions=conditions
     )
     return await db_batch_scheduler.add_operation(operation)

-async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
+async def batch_insert(model_class: Any, data: dict[str, Any]) -> int:
     """批量插入"""
     operation = BatchOperation(
-        operation_type='insert',
+        operation_type="insert",
         model_class=model_class,
         conditions={},
         data=data
@@ -471,10 +471,10 @@ async def batch_insert(model_class: Any, data: Dict[str, Any]) -> int:
     return await db_batch_scheduler.add_operation(operation)

-async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[str, Any]) -> int:
+async def batch_update(model_class: Any, conditions: dict[str, Any], data: dict[str, Any]) -> int:
     """批量更新"""
     operation = BatchOperation(
-        operation_type='update',
+        operation_type="update",
         model_class=model_class,
         conditions=conditions,
         data=data
@@ -482,10 +482,10 @@ async def batch_update(model_class: Any, conditions: Dict[str, Any], data: Dict[
     return await db_batch_scheduler.add_operation(operation)

-async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:
+async def batch_delete(model_class: Any, conditions: dict[str, Any]) -> int:
     """批量删除"""
     operation = BatchOperation(
-        operation_type='delete',
+        operation_type="delete",
         model_class=model_class,
         conditions=conditions
     )
@@ -494,4 +494,4 @@ async def batch_delete(model_class: Any, conditions: Dict[str, Any]) -> int:

 def get_db_batch_scheduler() -> DatabaseBatchScheduler:
     """获取数据库批量调度器实例"""
-    return db_batch_scheduler
\ No newline at end of file
+    return db_batch_scheduler
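The four `batch_*` helpers above are thin wrappers that enqueue a `BatchOperation` and await the scheduler's result. A minimal usage sketch; the import paths and the `Messages` model are assumptions for illustration, not taken from this commit:

```python
import asyncio

# Assumed locations; adjust to where the helpers and models actually live.
from src.common.database.db_batch_scheduler import batch_insert, batch_select
from src.common.database.sqlalchemy_models import Messages  # hypothetical mapped model


async def demo():
    # Each call only enqueues a BatchOperation; the scheduler flushes per type,
    # and equal select conditions from concurrent callers can be merged.
    new_id = await batch_insert(Messages, {"chat_id": "demo", "content": "hello"})
    rows = await batch_select(Messages, {"chat_id": "demo"})
    print(new_id, rows)


asyncio.run(demo())
```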

View File

@@ -15,8 +15,8 @@ from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_asyn
 from sqlalchemy.ext.declarative import declarative_base
 from sqlalchemy.orm import Mapped, mapped_column

-from src.common.logger import get_logger
 from src.common.database.connection_pool_manager import get_connection_pool_manager
+from src.common.logger import get_logger

 logger = get_logger("sqlalchemy_models")

View File

@@ -1,13 +1,13 @@
 # 使用基于时间戳的文件处理器,简单的轮转份数限制
 import logging
+import tarfile
 import threading
 import time
-import tarfile
 from collections.abc import Callable
 from datetime import datetime, timedelta
 from pathlib import Path
-from typing import Any, Optional, Dict
+from typing import Any

 import orjson
 import structlog
@@ -18,15 +18,15 @@ LOG_DIR = Path("logs")
 LOG_DIR.mkdir(exist_ok=True)

 # 全局handler实例避免重复创建可能为None表示禁用文件日志
-_file_handler: Optional[logging.Handler] = None
-_console_handler: Optional[logging.Handler] = None
+_file_handler: logging.Handler | None = None
+_console_handler: logging.Handler | None = None

 # 动态 logger 元数据注册表 (name -> {alias:str|None, color:str|None})
 _LOGGER_META_LOCK = threading.Lock()
-_LOGGER_META: Dict[str, Dict[str, Optional[str]]] = {}
+_LOGGER_META: dict[str, dict[str, str | None]] = {}


-def _normalize_color(color: Optional[str]) -> Optional[str]:
+def _normalize_color(color: str | None) -> str | None:
     """接受 ANSI 码 / #RRGGBB / rgb(r,g,b) / 颜色名(直接返回) -> ANSI 码.

     不做复杂解析,只支持 #RRGGBB 转 24bit ANSI。
     """
@@ -49,13 +49,13 @@ def _normalize_color(color: Optional[str]) -> Optional[str]:
             nums = color[color.find("(") + 1 : -1].split(",")
             r, g, b = (int(x) for x in nums[:3])
             return f"\033[38;2;{r};{g};{b}m"
-        except Exception:  # noqa: BLE001
+        except Exception:
             return None
     # 其他情况直接返回假设是短ANSI或名称控制台渲染器不做翻译仅输出
     return color

-def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Optional[str] = None):
+def _register_logger_meta(name: str, *, alias: str | None = None, color: str | None = None):
     """注册/更新 logger 元数据。"""
     if not name:
         return
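The `#RRGGBB` branch of `_normalize_color` boils down to slicing the hex pairs and emitting a 24-bit ANSI foreground code. A standalone sketch of just that mapping:

```python
def hex_to_ansi(color: str) -> str:
    # "#FF8800" -> r=255, g=136, b=0 -> ESC[38;2;R;G;Bm (24-bit foreground)
    r, g, b = (int(color[i : i + 2], 16) for i in (1, 3, 5))
    return f"\033[38;2;{r};{g};{b}m"


print(repr(hex_to_ansi("#FF8800")))  # '\x1b[38;2;255;136;0m'
```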
@@ -67,7 +67,7 @@ def _register_logger_meta(name: str, *, alias: Optional[str] = None, color: Opti
         meta["color"] = _normalize_color(color)


-def get_logger_meta(name: str) -> Dict[str, Optional[str]]:
+def get_logger_meta(name: str) -> dict[str, str | None]:
     with _LOGGER_META_LOCK:
         return _LOGGER_META.get(name, {"alias": None, "color": None}).copy()
@@ -170,7 +170,7 @@ class TimestampedFileHandler(logging.Handler):
         try:
             self._compress_stale_logs()
             self._cleanup_old_files()
-        except Exception as e:  # noqa: BLE001
+        except Exception as e:
             print(f"[日志轮转] 轮转过程出错: {e}")

     def _compress_stale_logs(self):  # sourcery skip: extract-method
@@ -184,12 +184,12 @@ class TimestampedFileHandler(logging.Handler):
                 continue
             # 压缩
             try:
-                with tarfile.open(tar_path, "w:gz") as tf:  # noqa: SIM117
+                with tarfile.open(tar_path, "w:gz") as tf:
                     tf.add(f, arcname=f.name)
                 f.unlink(missing_ok=True)
-            except Exception as e:  # noqa: BLE001
+            except Exception as e:
                 print(f"[日志压缩] 压缩 {f.name} 失败: {e}")
-        except Exception as e:  # noqa: BLE001
+        except Exception as e:
             print(f"[日志压缩] 过程出错: {e}")

     def _cleanup_old_files(self):
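The compress-then-delete idiom in `_compress_stale_logs` is plain stdlib `tarfile`; a self-contained sketch (the `logs/` path is illustrative):

```python
import tarfile
from pathlib import Path

for f in Path("logs").glob("*.log"):
    tar_path = f.with_suffix(f.suffix + ".tar.gz")
    if tar_path.exists():
        continue  # already archived earlier
    with tarfile.open(tar_path, "w:gz") as tf:
        tf.add(f, arcname=f.name)  # store only the bare file name in the archive
    f.unlink(missing_ok=True)  # drop the original once it is safely compressed
```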
@@ -206,9 +206,9 @@ class TimestampedFileHandler(logging.Handler):
                 mtime = datetime.fromtimestamp(f.stat().st_mtime)
                 if mtime < cutoff:
                     f.unlink(missing_ok=True)
-            except Exception as e:  # noqa: BLE001
+            except Exception as e:
                 print(f"[日志清理] 删除 {f} 失败: {e}")
-        except Exception as e:  # noqa: BLE001
+        except Exception as e:
             print(f"[日志清理] 清理过程出错: {e}")

     def emit(self, record):
@@ -850,7 +850,7 @@ class ModuleColoredConsoleRenderer:
         if logger_name:
             # 获取别名,如果没有别名则使用原名称
             # 若上面条件不成立需要再次获取 meta
-            if 'meta' not in locals():
+            if "meta" not in locals():
                 meta = get_logger_meta(logger_name)
             display_name = meta.get("alias") or DEFAULT_MODULE_ALIASES.get(logger_name, logger_name)
@@ -1066,7 +1066,7 @@ raw_logger: structlog.stdlib.BoundLogger = structlog.get_logger()
 binds: dict[str, Callable] = {}


-def get_logger(name: str | None, *, color: Optional[str] = None, alias: Optional[str] = None) -> structlog.stdlib.BoundLogger:
+def get_logger(name: str | None, *, color: str | None = None, alias: str | None = None) -> structlog.stdlib.BoundLogger:
     """获取/创建 structlog logger。

     新增:
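The widened signature takes keyword-only `color` and `alias`, which flow into `_register_logger_meta` above; the console renderer then shows the alias tinted with the normalized color. A short usage sketch (the name, color, and alias are illustrative):

```python
from src.common.logger import get_logger

# color accepts #RRGGBB / rgb(r,g,b) / a raw ANSI code / a color name;
# alias replaces the module name in console output.
logger = get_logger("memory_system", color="#FF8800", alias="记忆")
logger.info("记忆系统初始化完成")
```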
@@ -1132,10 +1132,10 @@ def cleanup_old_logs():
                 tar_path = f.with_suffix(f.suffix + ".tar.gz")
                 if tar_path.exists():
                     continue
-                with tarfile.open(tar_path, "w:gz") as tf:  # noqa: SIM117
+                with tarfile.open(tar_path, "w:gz") as tf:
                     tf.add(f, arcname=f.name)
                 f.unlink(missing_ok=True)
-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         logger = get_logger("logger")
         logger.warning(f"周期压缩日志时出错: {e}")
@@ -1152,7 +1152,7 @@ def cleanup_old_logs():
                 log_file.unlink(missing_ok=True)
                 deleted_count += 1
                 deleted_size += size
-            except Exception as e:  # noqa: BLE001
+            except Exception as e:
                 logger = get_logger("logger")
                 logger.warning(f"清理日志文件 {log_file} 时出错: {e}")
         if deleted_count:
@@ -1160,7 +1160,7 @@ def cleanup_old_logs():
             logger.info(
                 f"清理 {deleted_count} 个过期日志 (≈{deleted_size / 1024 / 1024:.2f}MB), 保留策略={retention_days}天"
             )
-    except Exception as e:  # noqa: BLE001
+    except Exception as e:
         logger = get_logger("logger")
         logger.error(f"清理旧日志文件时出错: {e}")
@@ -1183,7 +1183,7 @@ def start_log_cleanup_task():
     while True:
         try:
             cleanup_old_logs()
-        except Exception as e:  # noqa: BLE001
+        except Exception as e:
             print(f"[日志任务] 执行清理出错: {e}")
         # 再次等待到下一个午夜
         time.sleep(max(1, seconds_until_next_midnight()))

View File

@@ -120,10 +120,10 @@ class MainSystem:
             logger.warning("未发现任何兴趣计算器组件")
             return

-        logger.info(f"发现的兴趣计算器组件:")
+        logger.info("发现的兴趣计算器组件:")
         for calc_name, calc_info in interest_calculators.items():
-            enabled = getattr(calc_info, 'enabled', True)
-            default_enabled = getattr(calc_info, 'enabled_by_default', True)
+            enabled = getattr(calc_info, "enabled", True)
+            default_enabled = getattr(calc_info, "enabled_by_default", True)
             logger.info(f" - {calc_name}: 启用: {enabled}, 默认启用: {default_enabled}")

         # 初始化兴趣度管理器
@@ -136,8 +136,8 @@ class MainSystem:
         # 使用组件注册表获取组件类并注册
         for calc_name, calc_info in interest_calculators.items():
-            enabled = getattr(calc_info, 'enabled', True)
-            default_enabled = getattr(calc_info, 'enabled_by_default', True)
+            enabled = getattr(calc_info, "enabled", True)
+            default_enabled = getattr(calc_info, "enabled_by_default", True)

             if not enabled or not default_enabled:
                 logger.info(f"兴趣计算器 {calc_name} 未启用,跳过")
@@ -183,7 +183,7 @@ class MainSystem:
     async def _async_cleanup(self):
         """异步清理资源"""
         try:
             # 停止数据库服务
             try:
                 from src.common.database.database import stop_database
@@ -343,8 +343,8 @@ MoFox_Bot(第三方修改版)
         # 初始化表情管理器
         get_emoji_manager().initialize()
         logger.info("表情包管理器初始化成功")

-        '''
+        """
         # 初始化回复后关系追踪系统
         try:
             from src.plugins.built_in.affinity_flow_chatter.interest_scoring import chatter_interest_scoring_system
@@ -356,8 +356,8 @@ MoFox_Bot(第三方修改版)
         except Exception as e:
             logger.error(f"回复后关系追踪系统初始化失败: {e}")
             relationship_tracker = None
-        '''
+        """

         # 启动情绪管理器
         await mood_manager.start()
         logger.info("情绪管理器初始化成功")
@@ -487,10 +487,10 @@ MoFox_Bot(第三方修改版)
             # 关闭应用 (MessageServer可能没有shutdown方法)
             try:
                 if self.app:
-                    if hasattr(self.app, 'shutdown'):
+                    if hasattr(self.app, "shutdown"):
                         await self.app.shutdown()
                         logger.info("应用已关闭")
-                    elif hasattr(self.app, 'stop'):
+                    elif hasattr(self.app, "stop"):
                         await self.app.stop()
                         logger.info("应用已停止")
                     else:
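The shutdown path relies on `hasattr` duck typing rather than a shared interface. A self-contained sketch of the same pattern, with stand-in classes (not the project's real server types):

```python
import asyncio


class GracefulApp:
    async def shutdown(self):
        print("应用已关闭")


class StoppableApp:
    async def stop(self):
        print("应用已停止")


async def close(app):
    # Prefer shutdown(), fall back to stop(), otherwise do nothing.
    if hasattr(app, "shutdown"):
        await app.shutdown()
    elif hasattr(app, "stop"):
        await app.stop()


asyncio.run(close(GracefulApp()))
asyncio.run(close(StoppableApp()))
```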

View File

@@ -2,7 +2,6 @@ import math
 import random
 import time

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.message_receive.message import MessageRecv
 from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
 from src.chat.utils.prompt import Prompt, global_prompt_manager

View File

@@ -5,7 +5,6 @@ import time
 import traceback
 from typing import Any

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.chat_message_builder import (
     get_raw_msg_before_timestamp_with_chat,
     get_raw_msg_by_timestamp_with_chat,

View File

@@ -5,7 +5,6 @@ from typing import Any
 import orjson
 from json_repair import repair_json

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.common.logger import get_logger
 from src.config.config import global_config, model_config

View File

@@ -4,8 +4,8 @@ from datetime import datetime
 from difflib import SequenceMatcher
 from typing import Any

-import rjieba
 import orjson
+import rjieba
 from json_repair import repair_json
 from sklearn.feature_extraction.text import TfidfVectorizer
 from sklearn.metrics.pairwise import cosine_similarity

View File

@@ -49,7 +49,6 @@ from .base import (
     ToolParamType,
     create_plus_command_adapter,
 )
-
 from .utils.dependency_config import configure_dependency_settings, get_dependency_config

 # 导入依赖管理模块

View File

@@ -113,7 +113,7 @@ class BaseInterestCalculator(ABC):
         try:
             self._enabled = True
             return True
-        except Exception as e:
+        except Exception:
             self._enabled = False
             return False
@@ -170,7 +170,7 @@ class BaseInterestCalculator(ABC):
         if not self._enabled:
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
                 error_message="组件未启用"
             )
@@ -184,9 +184,9 @@ class BaseInterestCalculator(ABC):
         except Exception as e:
             result = InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
-                error_message=f"计算执行失败: {str(e)}",
+                error_message=f"计算执行失败: {e!s}",
                 calculation_time=time.time() - start_time
             )
             self._update_statistics(result)
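The `f"{e!s}"` spelling adopted here is the f-string shorthand for `str(e)`, so the rewrite drops an explicit call without changing the output. A quick check:

```python
e = ValueError("连接超时")
assert f"计算执行失败: {e!s}" == f"计算执行失败: {str(e)}"
print(f"计算执行失败: {e!s}")  # 计算执行失败: 连接超时
```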
@@ -201,7 +201,7 @@ class BaseInterestCalculator(ABC):
         Returns:
             InterestCalculatorInfo: 生成的兴趣计算器信息对象
         """
-        name = getattr(cls, 'component_name', cls.__name__.lower().replace('calculator', ''))
+        name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))

         if "." in name:
             logger.error(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
             raise ValueError(f"InterestCalculator名称 '{name}' 包含非法字符 '.',请使用下划线替代")
@@ -209,12 +209,12 @@ class BaseInterestCalculator(ABC):
         return InterestCalculatorInfo(
             name=name,
             component_type=ComponentType.INTEREST_CALCULATOR,
-            description=getattr(cls, 'component_description', cls.__doc__ or "兴趣度计算器"),
-            enabled_by_default=getattr(cls, 'enabled_by_default', True),
+            description=getattr(cls, "component_description", cls.__doc__ or "兴趣度计算器"),
+            enabled_by_default=getattr(cls, "enabled_by_default", True),
         )

     def __repr__(self) -> str:
         return (f"{self.__class__.__name__}("
                 f"name={self.component_name}, "
                 f"version={self.component_version}, "
                 f"enabled={self._enabled})")
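The fallback in this hunk derives a component name from the class name when `component_name` is absent. A tiny standalone check (the subclass name is hypothetical):

```python
class TopicHeatCalculator:  # hypothetical subclass without component_name
    pass


cls = TopicHeatCalculator
name = getattr(cls, "component_name", cls.__name__.lower().replace("calculator", ""))
print(name)  # topicheat
```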

View File

@@ -43,21 +43,21 @@ class BasePlugin(PluginBase):
             对应类型的ComponentInfo对象
         """
         if component_type == ComponentType.COMMAND:
-            if hasattr(component_class, 'get_command_info'):
+            if hasattr(component_class, "get_command_info"):
                 return component_class.get_command_info()
             else:
                 logger.warning(f"Command类 {component_class.__name__} 缺少 get_command_info 方法")
                 return None
         elif component_type == ComponentType.ACTION:
-            if hasattr(component_class, 'get_action_info'):
+            if hasattr(component_class, "get_action_info"):
                 return component_class.get_action_info()
             else:
                 logger.warning(f"Action类 {component_class.__name__} 缺少 get_action_info 方法")
                 return None
         elif component_type == ComponentType.INTEREST_CALCULATOR:
-            if hasattr(component_class, 'get_interest_calculator_info'):
+            if hasattr(component_class, "get_interest_calculator_info"):
                 return component_class.get_interest_calculator_info()
             else:
                 logger.warning(f"InterestCalculator类 {component_class.__name__} 缺少 get_interest_calculator_info 方法")

View File

@@ -1,5 +1,6 @@
 from dataclasses import dataclass, field
-from typing import Any, Dict, List, Optional, Set
+from typing import Any


 @dataclass
 class PluginMetadata:
@@ -11,15 +12,15 @@ class PluginMetadata:
     usage: str  # 插件使用方法

     # 以下为可选字段,参考自 _manifest.json 和 NoneBot 设计
-    type: Optional[str] = None  # 插件类别: "library", "application"
+    type: str | None = None  # 插件类别: "library", "application"

     # 从原 _manifest.json 迁移的字段
     version: str = "1.0.0"  # 插件版本
     author: str = ""  # 作者名称
-    license: Optional[str] = None  # 开源协议
-    repository_url: Optional[str] = None  # 仓库地址
-    keywords: List[str] = field(default_factory=list)  # 关键词
-    categories: List[str] = field(default_factory=list)  # 分类
+    license: str | None = None  # 开源协议
+    repository_url: str | None = None  # 仓库地址
+    keywords: list[str] = field(default_factory=list)  # 关键词
+    categories: list[str] = field(default_factory=list)  # 分类

     # 扩展字段
-    extra: Dict[str, Any] = field(default_factory=dict)  # 其他任意信息
+    extra: dict[str, Any] = field(default_factory=dict)  # 其他任意信息
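For context, a hedged sketch of declaring metadata against the modernized dataclass. Only the fields from `usage` onward appear in this hunk, so the leading `name`/`description` fields and the import path are assumptions:

```python
from src.plugin_system.base.plugin_metadata import PluginMetadata  # assumed path

__plugin_meta__ = PluginMetadata(
    name="demo_plugin",      # assumed required field
    description="演示插件",   # assumed required field
    usage="随插件系统自动加载",
    type="application",
    version="1.0.0",
    author="Your Name",
    license="MIT",
    keywords=["demo"],
    extra={"is_built_in": False},
)
```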

View File

@@ -1,7 +1,6 @@
 import asyncio
 import importlib
 import os
-import traceback
 from importlib.util import module_from_spec, spec_from_file_location
 from pathlib import Path
 from typing import Any, Optional
@@ -104,7 +103,7 @@ class PluginManager:
             return False, 1

         module = self.plugin_modules.get(plugin_name)
         if not module or not hasattr(module, "__plugin_meta__"):
             self.failed_plugins[plugin_name] = "插件模块中缺少 __plugin_meta__"
             logger.error(f"❌ 插件加载失败: {plugin_name} - 缺少 __plugin_meta__")
@@ -288,7 +287,7 @@ class PluginManager:
         return loaded_count, failed_count

-    def _load_plugin_module_file(self, plugin_file: str) -> Optional[Any]:
+    def _load_plugin_module_file(self, plugin_file: str) -> Any | None:
         # sourcery skip: extract-method
         """加载单个插件模块文件

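Worth noting for the `Optional[Any]` to `Any | None` rewrites above: PEP 604 unions (Python 3.10+) are purely notational and compare equal to their `typing` spellings at runtime:

```python
from typing import Optional

assert (int | None) == Optional[int]  # the two spellings denote the same union


def load_or_none(path: str) -> int | None:  # identical contract to Optional[int]
    return None
```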
View File

@@ -2,7 +2,6 @@ import inspect
 import time
 from typing import Any

-from src.chat.message_receive.chat_stream import get_chat_manager
 from src.chat.utils.prompt import Prompt, global_prompt_manager
 from src.common.cache_manager import tool_cache
 from src.common.logger import get_logger

View File

@@ -2,4 +2,4 @@
 插件系统工具模块

 提供插件开发和管理的实用工具
-"""
\ No newline at end of file
+"""

View File

@@ -52,7 +52,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         # 用户关系数据缓存
         self.user_relationships: dict[str, float] = {}  # user_id -> relationship_score

-        logger.info(f"[Affinity兴趣计算器] 初始化完成:")
+        logger.info("[Affinity兴趣计算器] 初始化完成:")
         logger.info(f" - 权重配置: {self.score_weights}")
         logger.info(f" - 回复阈值: {self.reply_threshold}")
         logger.info(f" - 智能匹配: {self.use_smart_matching}")
@@ -69,9 +69,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         """执行AffinityFlow风格的兴趣值计算"""
         try:
             start_time = time.time()
-            message_id = getattr(message, 'message_id', '')
-            content = getattr(message, 'processed_plain_text', '')
-            user_id = getattr(message, 'user_info', {}).user_id if hasattr(message, 'user_info') and hasattr(message.user_info, 'user_id') else ''
+            message_id = getattr(message, "message_id", "")
+            content = getattr(message, "processed_plain_text", "")
+            user_id = getattr(message, "user_info", {}).user_id if hasattr(message, "user_info") and hasattr(message.user_info, "user_id") else ""

             logger.debug(f"[Affinity兴趣计算] 开始处理消息 {message_id}")
             logger.debug(f"[Affinity兴趣计算] 消息内容: {content[:50]}...")
@@ -135,7 +135,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
             logger.error(f"Affinity兴趣值计算失败: {e}", exc_info=True)
             return InterestCalculationResult(
                 success=False,
-                message_id=getattr(message, 'message_id', ''),
+                message_id=getattr(message, "message_id", ""),
                 interest_value=0.0,
                 error_message=str(e)
             )
@@ -206,9 +206,9 @@ class AffinityInterestCalculator(BaseInterestCalculator):
     def _calculate_mentioned_score(self, message: "DatabaseMessages", bot_nickname: str) -> float:
         """计算提及分"""
-        is_mentioned = getattr(message, 'is_mentioned', False)
-        is_at = getattr(message, 'is_at', False)
-        processed_plain_text = getattr(message, 'processed_plain_text', '')
+        is_mentioned = getattr(message, "is_mentioned", False)
+        is_at = getattr(message, "is_at", False)
+        processed_plain_text = getattr(message, "processed_plain_text", "")

         if is_mentioned:
             if is_at:
@@ -238,7 +238,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         keywords = []

         # 尝试从 key_words 字段提取存储的是JSON字符串
-        key_words = getattr(message, 'key_words', '')
+        key_words = getattr(message, "key_words", "")
         if key_words:
             try:
                 import orjson
@@ -250,7 +250,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         # 如果没有 keywords尝试从 key_words_lite 提取
         if not keywords:
-            key_words_lite = getattr(message, 'key_words_lite', '')
+            key_words_lite = getattr(message, "key_words_lite", "")
             if key_words_lite:
                 try:
                     import orjson
@@ -262,7 +262,7 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         # 如果还是没有,从消息内容中提取(降级方案)
         if not keywords:
-            content = getattr(message, 'processed_plain_text', '') or ''
+            content = getattr(message, "processed_plain_text", "") or ""
             keywords = self._extract_keywords_from_content(content)

         return keywords[:15]  # 返回前15个关键词
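The tiered extraction above parses the stored JSON first and only falls back to the lite field and then to raw content. A self-contained sketch of the first tier (the sample value is made up):

```python
import orjson

key_words = '["天气", "出行"]'  # stored as a JSON string on the message row
keywords = []
if key_words:
    try:
        keywords = orjson.loads(key_words)
    except orjson.JSONDecodeError:
        pass  # malformed JSON -> fall back to the next tier
print(keywords[:15])  # ['天气', '出行']
```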
@@ -298,4 +298,4 @@ class AffinityInterestCalculator(BaseInterestCalculator):
         self.no_reply_count = min(self.no_reply_count + 1, self.max_no_reply_count)

     # 是否使用智能兴趣匹配(作为类属性)
-    use_smart_matching = True
\ No newline at end of file
+    use_smart_matching = True

View File

@@ -107,9 +107,9 @@ class ChatterActionPlanner:
         # 直接使用消息中已计算的标志,无需重复计算兴趣值
         for message in unread_messages:
             try:
-                message_interest = getattr(message, 'interest_value', 0.3)
-                message_should_reply = getattr(message, 'should_reply', False)
-                message_should_act = getattr(message, 'should_act', False)
+                message_interest = getattr(message, "interest_value", 0.3)
+                message_should_reply = getattr(message, "should_reply", False)
+                message_should_act = getattr(message, "should_act", False)

                 # 确保interest_value不是None
                 if message_interest is None:

View File

@@ -5,7 +5,7 @@
 from src.common.logger import get_logger
 from src.plugin_system.apis.plugin_register_api import register_plugin
 from src.plugin_system.base.base_plugin import BasePlugin
-from src.plugin_system.base.component_types import ComponentInfo, ComponentType, InterestCalculatorInfo
+from src.plugin_system.base.component_types import ComponentInfo

 logger = get_logger("affinity_chatter_plugin")
@@ -52,4 +52,3 @@ class AffinityChatterPlugin(BasePlugin):
         return components
-

View File

@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "action_provider",
     }
-)
\ No newline at end of file
+)

View File

@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "permission",
     }
-)
\ No newline at end of file
+)

View File

@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "plugin_management",
     }
-)
\ No newline at end of file
+)

View File

@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "functional"
     }
-)
\ No newline at end of file
+)

View File

@@ -65,10 +65,10 @@ class ColdStartTask(AsyncTask):
             nickname = await person_api.get_person_value(person_id, "nickname")
             user_nickname = nickname or f"用户{user_id}"
             user_info = UserInfo(platform=platform, user_id=str(user_id), user_nickname=user_nickname)

             # 使用 get_or_create_stream 来安全地获取或创建流
             stream = await self.chat_manager.get_or_create_stream(platform, user_info)
             formatted_stream_id = f"{stream.user_info.platform}:{stream.user_info.user_id}:private"

             await self.executor.execute(stream_id=formatted_stream_id, start_mode="cold_start")
             logger.info(f"【冷启动】已为用户 {chat_id} (昵称: {user_nickname}) 发送唤醒/问候消息。")

View File

@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": "true",
         "plugin_type": "functional"
     }
-)
\ No newline at end of file
+)

View File

@@ -14,4 +14,4 @@ __plugin_meta__ = PluginMetadata(
         "is_built_in": True,
         "plugin_type": "audio_processor",
     }
-)
\ No newline at end of file
+)

View File

@@ -13,4 +13,4 @@ __plugin_meta__ = PluginMetadata(
     extra={
         "is_built_in": True,
     }
-)
\ No newline at end of file
+)