re-style: 格式化代码
This commit is contained in:
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
@@ -33,8 +34,8 @@ class ChatStream:
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[dict] = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
@@ -47,7 +48,7 @@ class ChatStream:
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
@@ -133,11 +134,11 @@ class ChatStream:
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
@@ -163,9 +164,10 @@ class ChatStream:
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
@@ -248,7 +250,7 @@ class ChatStream:
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
@@ -278,7 +280,7 @@ class ChatStream:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
@@ -391,8 +393,8 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
@@ -414,7 +416,7 @@ class ChatManager:
|
||||
await self.load_all_streams()
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
logger.error(f"聊天管理器启动失败: {e!s}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
@@ -424,7 +426,7 @@ class ChatManager:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
@@ -437,9 +439,7 @@ class ChatManager:
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
@@ -462,7 +462,7 @@ class ChatManager:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
@@ -572,7 +572,7 @@ class ChatManager:
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
@@ -582,13 +582,13 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream | None:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_name(self, stream_id: str) -> Optional[str]:
|
||||
def get_stream_name(self, stream_id: str) -> str | None:
|
||||
"""根据 stream_id 获取聊天流名称"""
|
||||
stream = self.get_stream(stream_id)
|
||||
if not stream:
|
||||
|
||||
Reference in New Issue
Block a user