re-style: 格式化代码
This commit is contained in:
committed by
Windpicker-owo
parent
00ba07e0e1
commit
a79253c714
@@ -2,9 +2,8 @@ from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
|
||||
|
||||
__all__ = [
|
||||
"get_emoji_manager",
|
||||
"get_chat_manager",
|
||||
"MessageStorage",
|
||||
"get_chat_manager",
|
||||
"get_emoji_manager",
|
||||
]
|
||||
|
||||
@@ -1,25 +1,24 @@
|
||||
import traceback
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
|
||||
from typing import Dict, Any, Optional
|
||||
from maim_message import UserInfo
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
|
||||
# 导入反注入系统
|
||||
from src.chat.antipromptinjector import initialize_anti_injector
|
||||
from src.chat.message_manager import message_manager
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.prompt import Prompt, global_prompt_manager
|
||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
|
||||
from src.mood.mood_manager import mood_manager # 导入情绪管理器
|
||||
from src.plugin_system.base import BaseCommand, EventType
|
||||
from src.plugin_system.core import component_registry, event_manager, global_announcement_manager
|
||||
|
||||
# 获取项目根目录(假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
|
||||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../.."))
|
||||
@@ -220,7 +219,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await plus_command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await plus_command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -288,7 +287,7 @@ class ChatBot:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
try:
|
||||
await command_instance.send_text(f"命令执行出错: {str(e)}")
|
||||
await command_instance.send_text(f"命令执行出错: {e!s}")
|
||||
except Exception as send_error:
|
||||
logger.error(f"发送错误消息失败: {send_error}")
|
||||
|
||||
@@ -341,7 +340,7 @@ class ChatBot:
|
||||
except Exception as e:
|
||||
logger.error(f"处理适配器响应时出错: {e}")
|
||||
|
||||
async def do_s4u(self, message_data: Dict[str, Any]):
|
||||
async def do_s4u(self, message_data: dict[str, Any]):
|
||||
message = MessageRecvS4U(message_data)
|
||||
group_info = message.message_info.group_info
|
||||
user_info = message.message_info.user_info
|
||||
@@ -364,7 +363,7 @@ class ChatBot:
|
||||
|
||||
return
|
||||
|
||||
async def message_process(self, message_data: Dict[str, Any]) -> None:
|
||||
async def message_process(self, message_data: dict[str, Any]) -> None:
|
||||
"""处理转化后的统一格式消息"""
|
||||
try:
|
||||
# 首先处理可能的切片消息重组
|
||||
@@ -462,7 +461,7 @@ class ChatBot:
|
||||
# TODO:暂不可用
|
||||
# 确认从接口发来的message是否有自定义的prompt模板信息
|
||||
if message.message_info.template_info and not message.message_info.template_info.template_default:
|
||||
template_group_name: Optional[str] = message.message_info.template_info.template_name # type: ignore
|
||||
template_group_name: str | None = message.message_info.template_info.template_name # type: ignore
|
||||
template_items = message.message_info.template_info.template_items
|
||||
async with global_prompt_manager.async_message_scope(template_group_name):
|
||||
if isinstance(template_items, dict):
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import copy
|
||||
import hashlib
|
||||
import time
|
||||
import copy
|
||||
from typing import Dict, Optional, TYPE_CHECKING
|
||||
from rich.traceback import install
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from src.common.logger import get_logger
|
||||
from maim_message import GroupInfo, UserInfo
|
||||
from rich.traceback import install
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
from sqlalchemy.dialects.mysql import insert as mysql_insert
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
||||
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import ChatStreams # 新增导入
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config # 新增导入
|
||||
|
||||
# 避免循环导入,使用TYPE_CHECKING进行类型提示
|
||||
@@ -33,8 +34,8 @@ class ChatStream:
|
||||
stream_id: str,
|
||||
platform: str,
|
||||
user_info: UserInfo,
|
||||
group_info: Optional[GroupInfo] = None,
|
||||
data: Optional[dict] = None,
|
||||
group_info: GroupInfo | None = None,
|
||||
data: dict | None = None,
|
||||
):
|
||||
self.stream_id = stream_id
|
||||
self.platform = platform
|
||||
@@ -47,7 +48,7 @@ class ChatStream:
|
||||
|
||||
# 使用StreamContext替代ChatMessageContext
|
||||
from src.common.data_models.message_manager_data_model import StreamContext
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
# 创建StreamContext
|
||||
self.stream_context: StreamContext = StreamContext(
|
||||
@@ -133,11 +134,11 @@ class ChatStream:
|
||||
|
||||
# 恢复stream_context信息
|
||||
if "stream_context_chat_type" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_type = ChatType(data["stream_context_chat_type"])
|
||||
if "stream_context_chat_mode" in data:
|
||||
from src.plugin_system.base.component_types import ChatType, ChatMode
|
||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||
|
||||
instance.stream_context.chat_mode = ChatMode(data["stream_context_chat_mode"])
|
||||
|
||||
@@ -163,9 +164,10 @@ class ChatStream:
|
||||
def set_context(self, message: "MessageRecv"):
|
||||
"""设置聊天消息上下文"""
|
||||
# 将MessageRecv转换为DatabaseMessages并设置到stream_context
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
import json
|
||||
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
|
||||
# 安全获取message_info中的数据
|
||||
message_info = getattr(message, "message_info", {})
|
||||
user_info = getattr(message_info, "user_info", {})
|
||||
@@ -248,7 +250,7 @@ class ChatStream:
|
||||
f"interest_value: {db_message.interest_value}"
|
||||
)
|
||||
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> Optional[list]:
|
||||
def _safe_get_actions(self, message: "MessageRecv") -> list | None:
|
||||
"""安全获取消息的actions字段"""
|
||||
try:
|
||||
actions = getattr(message, "actions", None)
|
||||
@@ -278,7 +280,7 @@ class ChatStream:
|
||||
logger.warning(f"获取actions字段失败: {e}")
|
||||
return None
|
||||
|
||||
def _extract_reply_from_segment(self, segment) -> Optional[str]:
|
||||
def _extract_reply_from_segment(self, segment) -> str | None:
|
||||
"""从消息段中提取reply_to信息"""
|
||||
try:
|
||||
if hasattr(segment, "type") and segment.type == "seglist":
|
||||
@@ -391,8 +393,8 @@ class ChatManager:
|
||||
|
||||
def __init__(self):
|
||||
if not self._initialized:
|
||||
self.streams: Dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: Dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
self.streams: dict[str, ChatStream] = {} # stream_id -> ChatStream
|
||||
self.last_messages: dict[str, "MessageRecv"] = {} # stream_id -> last_message
|
||||
# try:
|
||||
# async with get_db_session() as session:
|
||||
# db.connect(reuse_if_open=True)
|
||||
@@ -414,7 +416,7 @@ class ChatManager:
|
||||
await self.load_all_streams()
|
||||
logger.info(f"聊天管理器已启动,已加载 {len(self.streams)} 个聊天流")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天管理器启动失败: {str(e)}")
|
||||
logger.error(f"聊天管理器启动失败: {e!s}")
|
||||
|
||||
async def _auto_save_task(self):
|
||||
"""定期自动保存所有聊天流"""
|
||||
@@ -424,7 +426,7 @@ class ChatManager:
|
||||
await self._save_all_streams()
|
||||
logger.info("聊天流自动保存完成")
|
||||
except Exception as e:
|
||||
logger.error(f"聊天流自动保存失败: {str(e)}")
|
||||
logger.error(f"聊天流自动保存失败: {e!s}")
|
||||
|
||||
def register_message(self, message: "MessageRecv"):
|
||||
"""注册消息到聊天流"""
|
||||
@@ -437,9 +439,7 @@ class ChatManager:
|
||||
# logger.debug(f"注册消息到聊天流: {stream_id}")
|
||||
|
||||
@staticmethod
|
||||
def _generate_stream_id(
|
||||
platform: str, user_info: Optional[UserInfo], group_info: Optional[GroupInfo] = None
|
||||
) -> str:
|
||||
def _generate_stream_id(platform: str, user_info: UserInfo | None, group_info: GroupInfo | None = None) -> str:
|
||||
"""生成聊天流唯一ID"""
|
||||
if not user_info and not group_info:
|
||||
raise ValueError("用户信息或群组信息必须提供")
|
||||
@@ -462,7 +462,7 @@ class ChatManager:
|
||||
return hashlib.md5(key.encode()).hexdigest()
|
||||
|
||||
async def get_or_create_stream(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream:
|
||||
"""获取或创建聊天流
|
||||
|
||||
@@ -572,7 +572,7 @@ class ChatManager:
|
||||
await self._save_stream(stream)
|
||||
return stream
|
||||
|
||||
def get_stream(self, stream_id: str) -> Optional[ChatStream]:
|
||||
def get_stream(self, stream_id: str) -> ChatStream | None:
|
||||
"""通过stream_id获取聊天流"""
|
||||
stream = self.streams.get(stream_id)
|
||||
if not stream:
|
||||
@@ -582,13 +582,13 @@ class ChatManager:
|
||||
return stream
|
||||
|
||||
def get_stream_by_info(
|
||||
self, platform: str, user_info: UserInfo, group_info: Optional[GroupInfo] = None
|
||||
) -> Optional[ChatStream]:
|
||||
self, platform: str, user_info: UserInfo, group_info: GroupInfo | None = None
|
||||
) -> ChatStream | None:
|
||||
"""通过信息获取聊天流"""
|
||||
stream_id = self._generate_stream_id(platform, user_info, group_info)
|
||||
return self.streams.get(stream_id)
|
||||
|
||||
def get_stream_name(self, stream_id: str) -> Optional[str]:
|
||||
def get_stream_name(self, stream_id: str) -> str | None:
|
||||
"""根据 stream_id 获取聊天流名称"""
|
||||
stream = self.get_stream(stream_id)
|
||||
if not stream:
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
import base64
|
||||
import time
|
||||
from abc import abstractmethod, ABCMeta
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Any
|
||||
from typing import Any, Optional
|
||||
|
||||
import urllib3
|
||||
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
|
||||
from maim_message import BaseMessageInfo, MessageBase, Seg, UserInfo
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils_image import get_image_manager
|
||||
from src.chat.utils.utils_video import get_video_analyzer, is_video_analysis_available
|
||||
from src.chat.utils.utils_voice import get_voice_text
|
||||
from src.common.logger import get_logger
|
||||
from src.config.config import global_config
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -41,8 +40,8 @@ class Message(MessageBase, metaclass=ABCMeta):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
message_segment: Seg | None = None,
|
||||
timestamp: float | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
processed_plain_text: str = "",
|
||||
):
|
||||
@@ -264,7 +263,7 @@ class MessageRecv(Message):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -278,7 +277,7 @@ class MessageRecv(Message):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -291,7 +290,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
self.is_superchat = False
|
||||
self.gift_info = None
|
||||
self.gift_name = None
|
||||
self.gift_count: Optional[str] = None
|
||||
self.gift_count: str | None = None
|
||||
self.superchat_info = None
|
||||
self.superchat_price = None
|
||||
self.superchat_message_text = None
|
||||
@@ -444,7 +443,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.warning("视频消息中没有base64数据")
|
||||
return "[收到视频消息,但数据异常]"
|
||||
except Exception as e:
|
||||
logger.error(f"视频处理失败: {str(e)}")
|
||||
logger.error(f"视频处理失败: {e!s}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"错误详情: {traceback.format_exc()}")
|
||||
@@ -458,7 +457,7 @@ class MessageRecvS4U(MessageRecv):
|
||||
logger.info("未启用视频识别")
|
||||
return "[视频]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {segment.type}, 数据: {segment.data}")
|
||||
return f"[处理失败的{segment.type}消息]"
|
||||
|
||||
|
||||
@@ -471,10 +470,10 @@ class MessageProcessBase(Message):
|
||||
message_id: str,
|
||||
chat_stream: "ChatStream",
|
||||
bot_user_info: UserInfo,
|
||||
message_segment: Optional[Seg] = None,
|
||||
message_segment: Seg | None = None,
|
||||
reply: Optional["MessageRecv"] = None,
|
||||
thinking_start_time: float = 0,
|
||||
timestamp: Optional[float] = None,
|
||||
timestamp: float | None = None,
|
||||
):
|
||||
# 调用父类初始化,传递时间戳
|
||||
super().__init__(
|
||||
@@ -533,9 +532,9 @@ class MessageProcessBase(Message):
|
||||
return f"[回复<{self.reply.message_info.user_info.user_nickname}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
|
||||
return None
|
||||
else:
|
||||
return f"[{seg.type}:{str(seg.data)}]"
|
||||
return f"[{seg.type}:{seg.data!s}]"
|
||||
except Exception as e:
|
||||
logger.error(f"处理消息段失败: {str(e)}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
logger.error(f"处理消息段失败: {e!s}, 类型: {seg.type}, 数据: {seg.data}")
|
||||
return f"[处理失败的{seg.type}消息]"
|
||||
|
||||
def _generate_detailed_text(self) -> str:
|
||||
@@ -565,8 +564,7 @@ class MessageSending(MessageProcessBase):
|
||||
is_emoji: bool = False,
|
||||
thinking_start_time: float = 0,
|
||||
apply_set_reply_logic: bool = False,
|
||||
reply_to: Optional[str] = None,
|
||||
selected_expressions:List[int] = None,
|
||||
reply_to: str | None = None,
|
||||
):
|
||||
# 调用父类初始化
|
||||
super().__init__(
|
||||
@@ -638,11 +636,11 @@ class MessageSet:
|
||||
self.messages.append(message)
|
||||
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
|
||||
|
||||
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
|
||||
def get_message_by_index(self, index: int) -> MessageSending | None:
|
||||
"""通过索引获取消息"""
|
||||
return self.messages[index] if 0 <= index < len(self.messages) else None
|
||||
|
||||
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
|
||||
def get_message_by_time(self, target_time: float) -> MessageSending | None:
|
||||
"""获取最接近指定时间的消息"""
|
||||
if not self.messages:
|
||||
return None
|
||||
|
||||
@@ -1,15 +1,16 @@
|
||||
import re
|
||||
import json
|
||||
import traceback
|
||||
from typing import Union
|
||||
|
||||
import orjson
|
||||
from sqlalchemy import select, desc, update
|
||||
from sqlalchemy import desc, select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages, Images, get_db_session
|
||||
from src.common.database.sqlalchemy_database_api import get_db_session
|
||||
from src.common.database.sqlalchemy_models import Images, Messages
|
||||
from src.common.logger import get_logger
|
||||
|
||||
from .chat_stream import ChatStream
|
||||
from .message import MessageSending, MessageRecv
|
||||
from .message import MessageRecv, MessageSending
|
||||
|
||||
logger = get_logger("message_storage")
|
||||
|
||||
@@ -33,7 +34,7 @@ class MessageStorage:
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def store_message(message: Union[MessageSending, MessageRecv], chat_stream: ChatStream) -> None:
|
||||
async def store_message(message: MessageSending | MessageRecv, chat_stream: ChatStream) -> None:
|
||||
"""存储消息到数据库"""
|
||||
try:
|
||||
# 过滤敏感信息的正则模式
|
||||
@@ -292,6 +293,7 @@ class MessageStorage:
|
||||
try:
|
||||
async with get_db_session() as session:
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from src.common.database.sqlalchemy_models import Messages
|
||||
|
||||
# 查找需要修复的记录:interest_value为0、null或很小的值
|
||||
|
||||
@@ -3,12 +3,11 @@ import traceback
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.common.message.api import get_global_api
|
||||
from src.common.logger import get_logger
|
||||
from src.chat.message_receive.message import MessageSending
|
||||
from src.chat.message_receive.storage import MessageStorage
|
||||
from src.chat.utils.utils import truncate_message
|
||||
from src.chat.utils.utils import calculate_typing_time
|
||||
from src.chat.utils.utils import calculate_typing_time, truncate_message
|
||||
from src.common.logger import get_logger
|
||||
from src.common.message.api import get_global_api
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -27,7 +26,7 @@ async def send_message(message: MessageSending, show_log=True) -> bool:
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {str(e)}")
|
||||
logger.error(f"发送消息 '{message_preview}' 发往平台'{message.message_info.platform}' 失败: {e!s}")
|
||||
traceback.print_exc()
|
||||
raise e # 重新抛出其他异常
|
||||
|
||||
|
||||
Reference in New Issue
Block a user