re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -3,14 +3,14 @@
|
||||
提供统一的消息管理、上下文管理和流循环调度功能
|
||||
"""
|
||||
|
||||
from .message_manager import MessageManager, message_manager
|
||||
from .context_manager import SingleStreamContextManager
|
||||
from .distribution_manager import StreamLoopManager, stream_loop_manager
|
||||
from .message_manager import MessageManager, message_manager
|
||||
|
||||
__all__ = [
|
||||
"MessageManager",
|
||||
"message_manager",
|
||||
"SingleStreamContextManager",
|
||||
"StreamLoopManager",
|
||||
"message_manager",
|
||||
"stream_loop_manager",
|
||||
]
|
||||
|
||||
@@ -6,13 +6,14 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, List, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.chat.energy_system import energy_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
|
||||
logger = get_logger("context_manager")
|
||||
@@ -21,7 +22,7 @@ logger = get_logger("context_manager")
|
||||
class SingleStreamContextManager:
|
||||
"""单流上下文管理器 - 每个实例只管理一个 stream 的上下文"""
|
||||
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: Optional[int] = None):
|
||||
def __init__(self, stream_id: str, context: StreamContext, max_context_size: int | None = None):
|
||||
self.stream_id = stream_id
|
||||
self.context = context
|
||||
|
||||
@@ -66,7 +67,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""更新上下文中的消息
|
||||
|
||||
Args:
|
||||
@@ -84,7 +85,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"更新单流上下文消息失败 {self.stream_id}/{message_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_messages(self, limit: Optional[int] = None, include_unread: bool = True) -> List[DatabaseMessages]:
|
||||
def get_messages(self, limit: int | None = None, include_unread: bool = True) -> list[DatabaseMessages]:
|
||||
"""获取上下文消息
|
||||
|
||||
Args:
|
||||
@@ -117,7 +118,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流上下文消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_unread_messages(self) -> List[DatabaseMessages]:
|
||||
def get_unread_messages(self) -> list[DatabaseMessages]:
|
||||
"""获取未读消息"""
|
||||
try:
|
||||
return self.context.get_unread_messages()
|
||||
@@ -125,7 +126,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"获取单流未读消息失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return []
|
||||
|
||||
def mark_messages_as_read(self, message_ids: List[str]) -> bool:
|
||||
def mark_messages_as_read(self, message_ids: list[str]) -> bool:
|
||||
"""标记消息为已读"""
|
||||
try:
|
||||
if not hasattr(self.context, "mark_message_as_read"):
|
||||
@@ -168,7 +169,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"清空单流上下文失败 {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
def get_statistics(self) -> Dict[str, Any]:
|
||||
def get_statistics(self) -> dict[str, Any]:
|
||||
"""获取流统计信息"""
|
||||
try:
|
||||
current_time = time.time()
|
||||
@@ -285,7 +286,7 @@ class SingleStreamContextManager:
|
||||
logger.error(f"添加消息到单流上下文失败 (async) {self.stream_id}: {e}", exc_info=True)
|
||||
return False
|
||||
|
||||
async def update_message_async(self, message_id: str, updates: Dict[str, Any]) -> bool:
|
||||
async def update_message_async(self, message_id: str, updates: dict[str, Any]) -> bool:
|
||||
"""异步实现的 update_message:更新消息并在需要时 await 能量更新。"""
|
||||
try:
|
||||
self.context.update_message_info(message_id, **updates)
|
||||
@@ -327,7 +328,7 @@ class SingleStreamContextManager:
|
||||
"""更新流能量"""
|
||||
try:
|
||||
history_messages = self.context.get_history_messages(limit=self.max_context_size)
|
||||
messages: List[DatabaseMessages] = list(history_messages)
|
||||
messages: list[DatabaseMessages] = list(history_messages)
|
||||
|
||||
if include_unread:
|
||||
messages.extend(self.get_unread_messages())
|
||||
|
||||
@@ -5,12 +5,12 @@
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Any
|
||||
from typing import Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.energy_system import energy_manager
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
logger = get_logger("stream_loop_manager")
|
||||
@@ -19,13 +19,13 @@ logger = get_logger("stream_loop_manager")
|
||||
class StreamLoopManager:
|
||||
"""流循环管理器 - 每个流一个独立的无限循环任务"""
|
||||
|
||||
def __init__(self, max_concurrent_streams: Optional[int] = None):
|
||||
def __init__(self, max_concurrent_streams: int | None = None):
|
||||
# 流循环任务管理
|
||||
self.stream_loops: Dict[str, asyncio.Task] = {}
|
||||
self.stream_loops: dict[str, asyncio.Task] = {}
|
||||
self.loop_lock = asyncio.Lock()
|
||||
|
||||
# 统计信息
|
||||
self.stats: Dict[str, Any] = {
|
||||
self.stats: dict[str, Any] = {
|
||||
"active_streams": 0,
|
||||
"total_loops": 0,
|
||||
"total_process_cycles": 0,
|
||||
@@ -37,13 +37,13 @@ class StreamLoopManager:
|
||||
self.max_concurrent_streams = max_concurrent_streams or global_config.chat.max_concurrent_distributions
|
||||
|
||||
# 强制分发策略
|
||||
self.force_dispatch_unread_threshold: Optional[int] = getattr(
|
||||
self.force_dispatch_unread_threshold: int | None = getattr(
|
||||
global_config.chat, "force_dispatch_unread_threshold", 20
|
||||
)
|
||||
self.force_dispatch_min_interval: float = getattr(global_config.chat, "force_dispatch_min_interval", 0.1)
|
||||
|
||||
# Chatter管理器
|
||||
self.chatter_manager: Optional[ChatterManager] = None
|
||||
self.chatter_manager: ChatterManager | None = None
|
||||
|
||||
# 状态控制
|
||||
self.is_running = False
|
||||
@@ -212,7 +212,7 @@ class StreamLoopManager:
|
||||
|
||||
logger.info(f"流循环结束: {stream_id}")
|
||||
|
||||
async def _get_stream_context(self, stream_id: str) -> Optional[Any]:
|
||||
async def _get_stream_context(self, stream_id: str) -> Any | None:
|
||||
"""获取流上下文
|
||||
|
||||
Args:
|
||||
@@ -320,7 +320,7 @@ class StreamLoopManager:
|
||||
logger.debug(f"流 {stream_id} 使用默认间隔: {base_interval:.2f}s ({e})")
|
||||
return base_interval
|
||||
|
||||
def get_queue_status(self) -> Dict[str, Any]:
|
||||
def get_queue_status(self) -> dict[str, Any]:
|
||||
"""获取队列状态
|
||||
|
||||
Returns:
|
||||
@@ -374,14 +374,14 @@ class StreamLoopManager:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: Optional[int] = None) -> bool:
|
||||
def _needs_force_dispatch_for_context(self, context: Any, unread_count: int | None = None) -> bool:
|
||||
if not self.force_dispatch_unread_threshold or self.force_dispatch_unread_threshold <= 0:
|
||||
return False
|
||||
|
||||
count = unread_count if unread_count is not None else self._get_unread_count(context)
|
||||
return count > self.force_dispatch_unread_threshold
|
||||
|
||||
def get_performance_summary(self) -> Dict[str, Any]:
|
||||
def get_performance_summary(self) -> dict[str, Any]:
|
||||
"""获取性能摘要
|
||||
|
||||
Returns:
|
||||
|
||||
@@ -6,19 +6,20 @@
|
||||
import asyncio
|
||||
import random
|
||||
import time
|
||||
from typing import Dict, Optional, Any, TYPE_CHECKING, List
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.data_models.message_manager_data_model import MessageManagerStats, StreamStats
|
||||
from src.chat.chatter_manager import ChatterManager
|
||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.plugin_system.apis.chat_api import get_chat_manager
|
||||
|
||||
from .distribution_manager import stream_loop_manager
|
||||
from .sleep_manager.sleep_manager import SleepManager
|
||||
from .sleep_manager.wakeup_manager import WakeUpManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -32,7 +33,7 @@ class MessageManager:
|
||||
def __init__(self, check_interval: float = 5.0):
|
||||
self.check_interval = check_interval # 检查间隔(秒)
|
||||
self.is_running = False
|
||||
self.manager_task: Optional[asyncio.Task] = None
|
||||
self.manager_task: asyncio.Task | None = None
|
||||
|
||||
# 统计信息
|
||||
self.stats = MessageManagerStats()
|
||||
@@ -125,7 +126,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"更新消息 {message_id} 时发生错误: {e}")
|
||||
|
||||
async def bulk_update_messages(self, stream_id: str, updates: List[Dict[str, Any]]) -> int:
|
||||
async def bulk_update_messages(self, stream_id: str, updates: list[dict[str, Any]]) -> int:
|
||||
"""批量更新消息信息,降低更新频率"""
|
||||
if not updates:
|
||||
return 0
|
||||
@@ -214,7 +215,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"激活聊天流 {stream_id} 时发生错误: {e}")
|
||||
|
||||
def get_stream_stats(self, stream_id: str) -> Optional[StreamStats]:
|
||||
def get_stream_stats(self, stream_id: str) -> StreamStats | None:
|
||||
"""获取聊天流统计"""
|
||||
try:
|
||||
# 通过 ChatManager 获取 ChatStream
|
||||
@@ -243,7 +244,7 @@ class MessageManager:
|
||||
logger.error(f"获取聊天流 {stream_id} 统计时发生错误: {e}")
|
||||
return None
|
||||
|
||||
def get_manager_stats(self) -> Dict[str, Any]:
|
||||
def get_manager_stats(self) -> dict[str, Any]:
|
||||
"""获取管理器统计"""
|
||||
return {
|
||||
"total_streams": self.stats.total_streams,
|
||||
@@ -278,7 +279,7 @@ class MessageManager:
|
||||
except Exception as e:
|
||||
logger.error(f"清理不活跃聊天流时发生错误: {e}")
|
||||
|
||||
async def _check_and_handle_interruption(self, chat_stream: Optional[ChatStream] = None):
|
||||
async def _check_and_handle_interruption(self, chat_stream: ChatStream | None = None):
|
||||
"""检查并处理消息打断"""
|
||||
if not global_config.chat.interruption_enabled:
|
||||
return
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import asyncio
|
||||
import random
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
|
||||
from .notification_sender import NotificationSender
|
||||
from .sleep_state import SleepState, SleepContext
|
||||
from .sleep_state import SleepContext, SleepState
|
||||
from .time_checker import TimeChecker
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -92,7 +93,7 @@ class SleepManager:
|
||||
elif current_state == SleepState.WOKEN_UP:
|
||||
self._handle_woken_up(now, is_in_theoretical_sleep, wakeup_manager)
|
||||
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: Optional[str], wakeup_manager: Optional["WakeUpManager"]):
|
||||
def _handle_awake_to_sleep(self, now: datetime, activity: str | None, wakeup_manager: Optional["WakeUpManager"]):
|
||||
"""处理从“清醒”到“准备入睡”的状态转换。"""
|
||||
if activity:
|
||||
logger.info(f"进入理论休眠时间 '{activity}',开始进行睡眠决策...")
|
||||
@@ -181,7 +182,7 @@ class SleepManager:
|
||||
self,
|
||||
now: datetime,
|
||||
is_in_theoretical_sleep: bool,
|
||||
activity: Optional[str],
|
||||
activity: str | None,
|
||||
wakeup_manager: Optional["WakeUpManager"],
|
||||
):
|
||||
"""处理“正在睡觉”状态下的逻辑。"""
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from datetime import date, datetime
|
||||
from enum import Enum, auto
|
||||
from datetime import datetime, date
|
||||
from typing import Optional
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.manager.local_store_manager import local_storage
|
||||
@@ -29,10 +28,10 @@ class SleepContext:
|
||||
def __init__(self):
|
||||
"""初始化睡眠上下文,并从本地存储加载初始状态。"""
|
||||
self.current_state: SleepState = SleepState.AWAKE
|
||||
self.sleep_buffer_end_time: Optional[datetime] = None
|
||||
self.sleep_buffer_end_time: datetime | None = None
|
||||
self.total_delayed_minutes_today: float = 0.0
|
||||
self.last_sleep_check_date: Optional[date] = None
|
||||
self.re_sleep_attempt_time: Optional[datetime] = None
|
||||
self.last_sleep_check_date: date | None = None
|
||||
self.re_sleep_attempt_time: datetime | None = None
|
||||
self.load()
|
||||
|
||||
def save(self):
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Optional, List, Dict, Any
|
||||
import random
|
||||
from datetime import datetime, time, timedelta
|
||||
from typing import Any
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
@@ -37,11 +37,11 @@ class TimeChecker:
|
||||
return self._daily_sleep_offset, self._daily_wake_offset
|
||||
|
||||
@staticmethod
|
||||
def get_today_schedule() -> Optional[List[Dict[str, Any]]]:
|
||||
def get_today_schedule() -> list[dict[str, Any]] | None:
|
||||
"""从全局 ScheduleManager 获取今天的日程安排。"""
|
||||
return schedule_manager.today_schedule
|
||||
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def is_in_theoretical_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
if global_config.sleep_system.sleep_by_schedule:
|
||||
if self.get_today_schedule():
|
||||
return self._is_in_schedule_sleep_time(now_time)
|
||||
@@ -50,7 +50,7 @@ class TimeChecker:
|
||||
else:
|
||||
return self._is_in_sleep_time(now_time)
|
||||
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_schedule_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否落在日程表的任何一个睡眠活动中"""
|
||||
sleep_keywords = ["休眠", "睡觉", "梦乡"]
|
||||
today_schedule = self.get_today_schedule()
|
||||
@@ -79,7 +79,7 @@ class TimeChecker:
|
||||
continue
|
||||
return False, None
|
||||
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, Optional[str]]:
|
||||
def _is_in_sleep_time(self, now_time: time) -> tuple[bool, str | None]:
|
||||
"""检查当前时间是否在固定的睡眠时间内(应用偏移量)"""
|
||||
try:
|
||||
start_time_str = global_config.sleep_system.fixed_sleep_time
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_manager.sleep_manager.wakeup_context import WakeUpContext
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .sleep_manager import SleepManager
|
||||
@@ -27,9 +28,9 @@ class WakeUpManager:
|
||||
"""
|
||||
self.sleep_manager = sleep_manager
|
||||
self.context = WakeUpContext() # 使用新的上下文管理器
|
||||
self.angry_chat_id: Optional[str] = None
|
||||
self.angry_chat_id: str | None = None
|
||||
self.last_decay_time = time.time()
|
||||
self._decay_task: Optional[asyncio.Task] = None
|
||||
self._decay_task: asyncio.Task | None = None
|
||||
self.is_running = False
|
||||
self.last_log_time = 0
|
||||
self.log_interval = 30
|
||||
@@ -104,9 +105,7 @@ class WakeUpManager:
|
||||
logger.debug(f"唤醒度衰减: {old_value:.1f} -> {self.context.wakeup_value:.1f}")
|
||||
self.context.save()
|
||||
|
||||
def add_wakeup_value(
|
||||
self, is_private_chat: bool, is_mentioned: bool = False, chat_id: Optional[str] = None
|
||||
) -> bool:
|
||||
def add_wakeup_value(self, is_private_chat: bool, is_mentioned: bool = False, chat_id: str | None = None) -> bool:
|
||||
"""
|
||||
增加唤醒度值
|
||||
|
||||
|
||||
Reference in New Issue
Block a user