re-style: 格式化代码

This commit is contained in:
John Richard
2025-10-02 20:26:01 +08:00
committed by Windpicker-owo
parent 00ba07e0e1
commit a79253c714
263 changed files with 3781 additions and 3189 deletions

View File

@@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage
__all__ = [
"get_emoji_manager",
"get_chat_manager",
"MessageStorage",
"get_chat_manager",
"get_emoji_manager",
]

View File

@@ -1,25 +1,24 @@
import traceback
import os
import re
import traceback
from typing import Any
from typing import Dict, Any, Optional
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.message_manager import message_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
from src.plugin_system.base import BaseCommand, EventType
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.chat.utils.utils import is_mentioned_bot_in_message
# 导入反注入系统
from src.chat.antipromptinjector import initialize_anti_injector
from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.plugin_system.base import BaseCommand, EventType
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
@@ -220,7 +219,7 @@ class ChatBot:
logger.error(traceback.format_exc())
try:
await plus_command_instance.send_text(f"命令执行出错: {str(e)}")
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
@@ -288,7 +287,7 @@ class ChatBot:
logger.error(traceback.format_exc())
try:
await command_instance.send_text(f"命令执行出错: {str(e)}")
await command_instance.send_text(f"命令执行出错: {e!s}")
except Exception as send_error:
logger.error(f"发送错误消息失败: {send_error}")
@@ -341,7 +340,7 @@ class ChatBot:
except Exception as e:
logger.error(f"处理适配器响应时出错: {e}")
async def do_s4u(self, message_data: Dict[str, Any]):
async def do_s4u(self, message_data: dict[str, Any]):
message = MessageRecvS4U(message_data)
group_info = message.message_info.group_info
user_info = message.message_info.user_info
@@ -364,7 +363,7 @@ class ChatBot:
return
async def message_process(self, message_data: Dict[str, Any]) -> None:
async def message_process(self, message_data: dict[str, Any]) -> None:
"""处理转化后的统一格式消息"""
try:
# 首先处理可能的切片消息重组
@@ -462,7 +461,7 @@ class ChatBot:
# TODO:暂不可用
# 确认从接口发来的message是否有自定义的prompt模板信息
if message.message_info.template_info and not message.message_info.template_info.template_default:
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
template_items = message.message_info.template_info.template_items
async with global_prompt_manager.async_message_scope(template_group_name):
if isinstance(template_items, dict):

View File

@@ -1,17 +1,18 @@
import asyncio
import copy
import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from typing import TYPE_CHECKING
from src.common.logger import get_logger
from maim_message import GroupInfo, UserInfo
from rich.traceback import install
from sqlalchemy import select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.dialects.mysql import insert as mysql_insert
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
from src.common.logger import get_logger
from src.config.config import global_config # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
@@ -33,8 +34,8 @@ class ChatStream:
stream_id: str,
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: Optional[dict] = None,
group_info: GroupInfo | None = None,
data: dict | None = None,
):
self.stream_id = stream_id
self.platform = platform
@@ -47,7 +48,7 @@ class ChatStream:
# 使用StreamContext替代ChatMessageContext
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
# 创建StreamContext
self.stream_context: StreamContext = StreamContext(
@@ -133,11 +134,11 @@ class ChatStream:
# 恢复stream_context信息
if "stream_context_chat_type" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
if "stream_context_chat_mode" in data:
from src.plugin_system.base.component_types import ChatType, ChatMode
from src.plugin_system.base.component_types import ChatMode, ChatType
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
@@ -163,9 +164,10 @@ class ChatStream:
def set_context(self, message: "MessageRecv"):
"""设置聊天消息上下文"""
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
from src.common.data_models.database_data_model import DatabaseMessages
import json
from src.common.data_models.database_data_model import DatabaseMessages
# 安全获取message_info中的数据
message_info = getattr(message, "message_info", {})
user_info = getattr(message_info, "user_info", {})
@@ -248,7 +250,7 @@ class ChatStream:
f"interest_value: {db_message.interest_value}"
)
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
"""安全获取消息的actions字段"""
try:
actions = getattr(message, "actions", None)
@@ -278,7 +280,7 @@ class ChatStream:
logger.warning(f"获取actions字段失败: {e}")
return None
def _extract_reply_from_segment(self, segment) -> Optional[str]:
def _extract_reply_from_segment(self, segment) -> str | None:
"""从消息段中提取reply_to信息"""
try:
if hasattr(segment, "type") and segment.type == "seglist":
@@ -391,8 +393,8 @@ class ChatManager:
def __init__(self):
if not self._initialized:
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
# try:
# async with get_db_session() as session:
# db.connect(reuse_if_open=True)
@@ -414,7 +416,7 @@ class ChatManager:
await self.load_all_streams()
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
except Exception as e:
logger.error(f"聊天管理器启动失败: {str(e)}")
logger.error(f"聊天管理器启动失败: {e!s}")
async def _auto_save_task(self):
"""定期自动保存所有聊天流"""
@@ -424,7 +426,7 @@ class ChatManager:
await self._save_all_streams()
logger.info("聊天流自动保存完成")
except Exception as e:
logger.error(f"聊天流自动保存失败: {str(e)}")
logger.error(f"聊天流自动保存失败: {e!s}")
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
@@ -437,9 +439,7 @@ class ChatManager:
# logger.debug(f"注册消息到聊天流: {stream_id}")
@staticmethod
def _generate_stream_id(
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
) -> str:
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
"""生成聊天流唯一ID"""
if not user_info and not group_info:
raise ValueError("用户信息或群组信息必须提供")
@@ -462,7 +462,7 @@ class ChatManager:
return hashlib.md5(key.encode()).hexdigest()
async def get_or_create_stream(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream:
"""获取或创建聊天流
@@ -572,7 +572,7 @@ class ChatManager:
await self._save_stream(stream)
return stream
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
def get_stream(self, stream_id: str) -> ChatStream | None:
"""通过stream_id获取聊天流"""
stream = self.streams.get(stream_id)
if not stream:
@@ -582,13 +582,13 @@ class ChatManager:
return stream
def get_stream_by_info(
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
) -> Optional[ChatStream]:
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
) -> ChatStream | None:
"""通过信息获取聊天流"""
stream_id = self._generate_stream_id(platform, user_info, group_info)
return self.streams.get(stream_id)
def get_stream_name(self, stream_id: str) -> Optional[str]:
def get_stream_name(self, stream_id: str) -> str | None:
"""根据 stream_id 获取聊天流名称"""
stream = self.get_stream(stream_id)
if not stream:

View File

@@ -1,20 +1,19 @@
import base64
import time
from abc import abstractmethod, ABCMeta
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Optional, Any
from typing import Any, Optional
import urllib3
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils_image import get_image_manager
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
from src.chat.utils.utils_voice import get_voice_text
from src.common.logger import get_logger
from src.config.config import global_config
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3)
@@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta):
message_id: str,
chat_stream: "ChatStream",
user_info: UserInfo,
message_segment: Optional[Seg] = None,
timestamp: Optional[float] = None,
message_segment: Seg | None = None,
timestamp: float | None = None,
reply: Optional["MessageRecv"] = None,
processed_plain_text: str = "",
):
@@ -264,7 +263,7 @@ class MessageRecv(Message):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -278,7 +277,7 @@ class MessageRecv(Message):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv):
self.is_superchat = False
self.gift_info = None
self.gift_name = None
self.gift_count: Optional[str] = None
self.gift_count: str | None = None
self.superchat_info = None
self.superchat_price = None
self.superchat_message_text = None
@@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv):
logger.warning("视频消息中没有base64数据")
return "[收到视频消息,但数据异常]"
except Exception as e:
logger.error(f"视频处理失败: {str(e)}")
logger.error(f"视频处理失败: {e!s}")
import traceback
logger.error(f"错误详情: {traceback.format_exc()}")
@@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv):
logger.info("未启用视频识别")
return "[视频]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
return f"[处理失败的{segment.type}消息]"
@@ -471,10 +470,10 @@ class MessageProcessBase(Message):
message_id: str,
chat_stream: "ChatStream",
bot_user_info: UserInfo,
message_segment: Optional[Seg] = None,
message_segment: Seg | None = None,
reply: Optional["MessageRecv"] = None,
thinking_start_time: float = 0,
timestamp: Optional[float] = None,
timestamp: float | None = None,
):
# 调用父类初始化,传递时间戳
super().__init__(
@@ -533,9 +532,9 @@ class MessageProcessBase(Message):
return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
return f"[{seg.type}:{seg.data!s}]"
except Exception as e:
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}")
return f"[处理失败的{seg.type}消息]"
def _generate_detailed_text(self) -> str:
@@ -565,8 +564,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: Optional[str] = None,
selected_expressions:List[int] = None,
reply_to: str | None = None,
):
# 调用父类初始化
super().__init__(
@@ -638,11 +636,11 @@ class MessageSet:
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
def get_message_by_index(self, index: int) -> MessageSending | None:
"""通过索引获取消息"""
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
def get_message_by_time(self, target_time: float) -> MessageSending | None:
"""获取最接近指定时间的消息"""
if not self.messages:
return None

View File

@@ -1,15 +1,16 @@
import re
import json
import traceback
from typing import Union
import orjson
from sqlalchemy import select, desc, update
from sqlalchemy import desc, select, update
from src.common.database.sqlalchemy_models import Messages, Images, get_db_session
from src.common.database.sqlalchemy_database_api import get_db_session
from src.common.database.sqlalchemy_models import Images, Messages
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
from .message import MessageRecv, MessageSending
logger = get_logger("message_storage")
@@ -33,7 +34,7 @@ class MessageStorage:
return []
@staticmethod
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
"""存储消息到数据库"""
try:
# 过滤敏感信息的正则模式
@@ -292,6 +293,7 @@ class MessageStorage:
try:
async with get_db_session() as session:
from sqlalchemy import select, update
from src.common.database.sqlalchemy_models import Messages
# 查找需要修复的记录interest_value为0、null或很小的值

View File

@@ -3,12 +3,11 @@ import traceback
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
from src.chat.utils.utils import calculate_typing_time, truncate_message
from src.common.logger import get_logger
from src.common.message.api import get_global_api
install(extra_lines=3)
@@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
return True
except Exception as e:
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
traceback.print_exc()
raise e # 重新抛出其他异常