部分类型注解修复,优化import顺序,删除无用API文件

This commit is contained in:
UnCLAS-Prommer
2025-07-12 00:34:49 +08:00
parent 3165a0f8df
commit b303a95f61
44 changed files with 405 additions and 1166 deletions

View File

@@ -1,23 +1,25 @@
import traceback
import os
import re
from typing import Dict, Any
from maim_message import UserInfo
from src.common.logger import get_logger
from src.config.config import global_config
from src.mood.mood_manager import mood_manager # 导入情绪管理器
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.chat_stream import get_chat_manager, ChatStream
from src.chat.message_receive.message import MessageRecv
from src.experimental.only_message_process import MessageProcessor
from src.chat.message_receive.storage import MessageStorage
from src.experimental.PFC.pfc_manager import PFCManager
from src.chat.heart_flow.heartflow_message_processor import HeartFCMessageReceiver
from src.chat.utils.prompt_builder import Prompt, global_prompt_manager
from src.config.config import global_config
from src.experimental.only_message_process import MessageProcessor
from src.experimental.PFC.pfc_manager import PFCManager
from src.plugin_system.core.component_registry import component_registry # 导入新插件系统
from src.plugin_system.base.base_command import BaseCommand
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from maim_message import UserInfo
from src.chat.message_receive.chat_stream import ChatStream
import re
# 定义日志配置
# 获取项目根目录假设本文件在src/chat/message_receive/下,根目录为上上上级目录)
@@ -184,8 +186,8 @@ class ChatBot:
get_chat_manager().register_message(message)
chat = await get_chat_manager().get_or_create_stream(
platform=message.message_info.platform,
user_info=user_info,
platform=message.message_info.platform, # type: ignore
user_info=user_info, # type: ignore
group_info=group_info,
)
@@ -195,8 +197,10 @@ class ChatBot:
await message.process()
# 过滤检查
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex(
message.raw_message, chat, user_info
if _check_ban_words(message.processed_plain_text, chat, user_info) or _check_ban_regex( # type: ignore
message.raw_message, # type: ignore
chat,
user_info, # type: ignore
):
return

View File

@@ -3,18 +3,17 @@ import hashlib
import time
import copy
from typing import Dict, Optional, TYPE_CHECKING
from ...common.database.database import db
from ...common.database.database_model import ChatStreams # 新增导入
from rich.traceback import install
from maim_message import GroupInfo, UserInfo
from src.common.logger import get_logger
from src.common.database.database import db
from src.common.database.database_model import ChatStreams # 新增导入
# 避免循环导入使用TYPE_CHECKING进行类型提示
if TYPE_CHECKING:
from .message import MessageRecv
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
@@ -28,7 +27,7 @@ class ChatMessageContext:
def __init__(self, message: "MessageRecv"):
self.message = message
def get_template_name(self) -> str:
def get_template_name(self) -> Optional[str]:
"""获取模板名称"""
if self.message.message_info.template_info and not self.message.message_info.template_info.template_default:
return self.message.message_info.template_info.template_name
@@ -41,10 +40,10 @@ class ChatMessageContext:
def check_types(self, types: list) -> bool:
# sourcery skip: invert-any-all, use-any, use-next
"""检查消息类型"""
if not self.message.message_info.format_info.accept_format:
if not self.message.message_info.format_info.accept_format: # type: ignore
return False
for t in types:
if t not in self.message.message_info.format_info.accept_format:
if t not in self.message.message_info.format_info.accept_format: # type: ignore
return False
return True
@@ -68,7 +67,7 @@ class ChatStream:
platform: str,
user_info: UserInfo,
group_info: Optional[GroupInfo] = None,
data: dict = None,
data: Optional[dict] = None,
):
self.stream_id = stream_id
self.platform = platform
@@ -77,7 +76,7 @@ class ChatStream:
self.create_time = data.get("create_time", time.time()) if data else time.time()
self.last_active_time = data.get("last_active_time", self.create_time) if data else self.create_time
self.saved = False
self.context: ChatMessageContext = None # 用于存储该聊天的上下文信息
self.context: ChatMessageContext = None # type: ignore # 用于存储该聊天的上下文信息
def to_dict(self) -> dict:
"""转换为字典格式"""
@@ -99,7 +98,7 @@ class ChatStream:
return cls(
stream_id=data["stream_id"],
platform=data["platform"],
user_info=user_info,
user_info=user_info, # type: ignore
group_info=group_info,
data=data,
)
@@ -163,8 +162,8 @@ class ChatManager:
def register_message(self, message: "MessageRecv"):
"""注册消息到聊天流"""
stream_id = self._generate_stream_id(
message.message_info.platform,
message.message_info.user_info,
message.message_info.platform, # type: ignore
message.message_info.user_info, # type: ignore
message.message_info.group_info,
)
self.last_messages[stream_id] = message
@@ -185,10 +184,7 @@ class ChatManager:
def get_stream_id(self, platform: str, id: str, is_group: bool = True) -> str:
"""获取聊天流ID"""
if is_group:
components = [platform, str(id)]
else:
components = [platform, str(id), "private"]
components = [platform, id] if is_group else [platform, id, "private"]
key = "_".join(components)
return hashlib.md5(key.encode()).hexdigest()

View File

@@ -1,17 +1,15 @@
import time
from abc import abstractmethod
from dataclasses import dataclass
from typing import Optional, Any, TYPE_CHECKING
import urllib3
from src.common.logger import get_logger
if TYPE_CHECKING:
from .chat_stream import ChatStream
from ..utils.utils_image import get_image_manager
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from abc import abstractmethod
from dataclasses import dataclass
from rich.traceback import install
from typing import Optional, Any
from maim_message import Seg, UserInfo, BaseMessageInfo, MessageBase
from src.common.logger import get_logger
from src.chat.utils.utils_image import get_image_manager
from .chat_stream import ChatStream
install(extra_lines=3)
@@ -27,7 +25,7 @@ urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
@dataclass
class Message(MessageBase):
chat_stream: "ChatStream" = None
chat_stream: "ChatStream" = None # type: ignore
reply: Optional["Message"] = None
processed_plain_text: str = ""
memorized_times: int = 0
@@ -55,7 +53,7 @@ class Message(MessageBase):
)
# 调用父类初始化
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None)
super().__init__(message_info=message_info, message_segment=message_segment, raw_message=None) # type: ignore
self.chat_stream = chat_stream
# 文本处理相关属性
@@ -66,6 +64,7 @@ class Message(MessageBase):
self.reply = reply
async def _process_message_segments(self, segment: Seg) -> str:
# sourcery skip: remove-unnecessary-else, swap-if-else-branches
"""递归处理消息段,转换为文字描述
Args:
@@ -78,13 +77,13 @@ class Message(MessageBase):
# 处理消息段列表
segments_text = []
for seg in segment.data:
processed = await self._process_message_segments(seg)
processed = await self._process_message_segments(seg) # type: ignore
if processed:
segments_text.append(processed)
return " ".join(segments_text)
else:
# 处理单个消息段
return await self._process_single_segment(segment)
return await self._process_single_segment(segment) # type: ignore
@abstractmethod
async def _process_single_segment(self, segment):
@@ -138,7 +137,7 @@ class MessageRecv(Message):
if segment.type == "text":
self.is_picid = False
self.is_emoji = False
return segment.data
return segment.data # type: ignore
elif segment.type == "image":
# 如果是base64图片数据
if isinstance(segment.data, str):
@@ -160,7 +159,7 @@ class MessageRecv(Message):
elif segment.type == "mention_bot":
self.is_picid = False
self.is_emoji = False
self.is_mentioned = float(segment.data)
self.is_mentioned = float(segment.data) # type: ignore
return ""
elif segment.type == "priority_info":
self.is_picid = False
@@ -186,7 +185,7 @@ class MessageRecv(Message):
"""生成详细文本,包含时间和用户信息"""
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}] {name}: {self.processed_plain_text}\n"
@@ -234,7 +233,7 @@ class MessageProcessBase(Message):
"""
try:
if seg.type == "text":
return seg.data
return seg.data # type: ignore
elif seg.type == "image":
# 如果是base64图片数据
if isinstance(seg.data, str):
@@ -250,7 +249,7 @@ class MessageProcessBase(Message):
if self.reply and hasattr(self.reply, "processed_plain_text"):
# print(f"self.reply.processed_plain_text: {self.reply.processed_plain_text}")
# print(f"reply: {self.reply}")
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]"
return f"[回复<{self.reply.message_info.user_info.user_nickname}:{self.reply.message_info.user_info.user_id}> 的消息:{self.reply.processed_plain_text}]" # type: ignore
return None
else:
return f"[{seg.type}:{str(seg.data)}]"
@@ -264,7 +263,7 @@ class MessageProcessBase(Message):
timestamp = self.message_info.time
user_info = self.message_info.user_info
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>"
name = f"<{self.message_info.platform}:{user_info.user_id}:{user_info.user_nickname}:{user_info.user_cardname}>" # type: ignore
return f"[{timestamp}]{name} 说:{self.processed_plain_text}\n"
@@ -313,7 +312,7 @@ class MessageSending(MessageProcessBase):
is_emoji: bool = False,
thinking_start_time: float = 0,
apply_set_reply_logic: bool = False,
reply_to: str = None,
reply_to: str = None, # type: ignore
):
# 调用父类初始化
super().__init__(
@@ -344,7 +343,7 @@ class MessageSending(MessageProcessBase):
self.message_segment = Seg(
type="seglist",
data=[
Seg(type="reply", data=self.reply.message_info.message_id),
Seg(type="reply", data=self.reply.message_info.message_id), # type: ignore
self.message_segment,
],
)
@@ -364,10 +363,10 @@ class MessageSending(MessageProcessBase):
) -> "MessageSending":
"""从思考状态消息创建发送状态消息"""
return cls(
message_id=thinking.message_info.message_id,
message_id=thinking.message_info.message_id, # type: ignore
chat_stream=thinking.chat_stream,
message_segment=message_segment,
bot_user_info=thinking.message_info.user_info,
bot_user_info=thinking.message_info.user_info, # type: ignore
reply=thinking.reply,
is_head=is_head,
is_emoji=is_emoji,
@@ -399,13 +398,11 @@ class MessageSet:
if not isinstance(message, MessageSending):
raise TypeError("MessageSet只能添加MessageSending类型的消息")
self.messages.append(message)
self.messages.sort(key=lambda x: x.message_info.time)
self.messages.sort(key=lambda x: x.message_info.time) # type: ignore
def get_message_by_index(self, index: int) -> Optional[MessageSending]:
"""通过索引获取消息"""
if 0 <= index < len(self.messages):
return self.messages[index]
return None
return self.messages[index] if 0 <= index < len(self.messages) else None
def get_message_by_time(self, target_time: float) -> Optional[MessageSending]:
"""获取最接近指定时间的消息"""
@@ -415,7 +412,7 @@ class MessageSet:
left, right = 0, len(self.messages) - 1
while left < right:
mid = (left + right) // 2
if self.messages[mid].message_info.time < target_time:
if self.messages[mid].message_info.time < target_time: # type: ignore
left = mid + 1
else:
right = mid

View File

@@ -1,21 +1,16 @@
# src/plugins/chat/message_sender.py
import asyncio
import time
from asyncio import Task
from typing import Union
from src.common.message.api import get_global_api
# from ...common.database import db # 数据库依赖似乎不需要了,注释掉
from .message import MessageSending, MessageThinking, MessageSet
from src.chat.message_receive.storage import MessageStorage
from ..utils.utils import truncate_message, calculate_typing_time, count_messages_between
from src.common.logger import get_logger
from rich.traceback import install
install(extra_lines=3)
from src.common.logger import get_logger
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message, calculate_typing_time, count_messages_between
from .message import MessageSending, MessageThinking, MessageSet
install(extra_lines=3)
logger = get_logger("sender")
@@ -79,9 +74,10 @@ class MessageContainer:
def count_thinking_messages(self) -> int:
"""计算当前容器中思考消息的数量"""
return sum(1 for msg in self.messages if isinstance(msg, MessageThinking))
return sum(isinstance(msg, MessageThinking) for msg in self.messages)
def get_timeout_sending_messages(self) -> list[MessageSending]:
# sourcery skip: merge-nested-ifs
"""获取所有超时的MessageSending对象思考时间超过20秒按thinking_start_time排序 - 从旧 sender 合并"""
current_time = time.time()
timeout_messages = []
@@ -230,9 +226,7 @@ class MessageManager:
f"[{message.chat_stream.stream_id}] 处理发送消息 {getattr(message.message_info, 'message_id', 'N/A')} 时出错: {e}"
)
logger.exception("详细错误信息:")
# 考虑是否移除出错的消息,防止无限循环
removed = container.remove_message(message)
if removed:
if container.remove_message(message):
logger.warning(f"[{message.chat_stream.stream_id}] 已移除处理出错的消息。")
async def _process_chat_messages(self, chat_id: str):
@@ -261,10 +255,7 @@ class MessageManager:
# --- 处理发送消息 ---
await self._handle_sending_message(container, message_earliest)
# --- 处理超时发送消息 (来自旧 sender) ---
# 在处理完最早的消息后,检查是否有超时的发送消息
timeout_sending_messages = container.get_timeout_sending_messages()
if timeout_sending_messages:
if timeout_sending_messages := container.get_timeout_sending_messages():
logger.debug(f"[{chat_id}] 发现 {len(timeout_sending_messages)} 条超时的发送消息")
for msg in timeout_sending_messages:
# 确保不是刚刚处理过的最早消息 (虽然理论上应该已被移除,但以防万一)
@@ -274,6 +265,7 @@ class MessageManager:
await self._handle_sending_message(container, msg) # 复用处理逻辑
async def _start_processor_loop(self):
# sourcery skip: list-comprehension, move-assign-in-block, use-named-expression
"""消息处理器主循环"""
while self._running:
tasks = []
@@ -282,10 +274,7 @@ class MessageManager:
# 创建 keys 的快照以安全迭代
chat_ids = list(self.containers.keys())
for chat_id in chat_ids:
# 为每个 chat_id 创建一个处理任务
tasks.append(asyncio.create_task(self._process_chat_messages(chat_id)))
tasks.extend(asyncio.create_task(self._process_chat_messages(chat_id)) for chat_id in chat_ids)
if tasks:
try:
# 等待当前批次的所有任务完成

View File

@@ -1,11 +1,10 @@
import re
from typing import Union
# from ...common.database.database import db # db is now Peewee's SqliteDatabase instance
from .message import MessageSending, MessageRecv
from .chat_stream import ChatStream
from ...common.database.database_model import Messages, RecalledMessages, Images # Import Peewee models
from src.common.database.database_model import Messages, RecalledMessages, Images
from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending, MessageRecv
logger = get_logger("message_storage")
@@ -44,7 +43,7 @@ class MessageStorage:
reply_to = ""
chat_info_dict = chat_stream.to_dict()
user_info_dict = message.message_info.user_info.to_dict()
user_info_dict = message.message_info.user_info.to_dict() # type: ignore
# message_id 现在是 TextField直接使用字符串值
msg_id = message.message_info.message_id
@@ -56,7 +55,7 @@ class MessageStorage:
Messages.create(
message_id=msg_id,
time=float(message.message_info.time),
time=float(message.message_info.time), # type: ignore
chat_id=chat_stream.stream_id,
# Flattened chat_info
reply_to=reply_to,
@@ -103,7 +102,7 @@ class MessageStorage:
try:
# Assuming input 'time' is a string timestamp that can be converted to float
current_time_float = float(time)
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute()
RecalledMessages.delete().where(RecalledMessages.time < (current_time_float - 300)).execute() # type: ignore
except Exception:
logger.exception("删除撤回消息失败")
@@ -115,22 +114,19 @@ class MessageStorage:
"""更新最新一条匹配消息的message_id"""
try:
if message.message_segment.type == "notify":
mmc_message_id = message.message_segment.data.get("echo")
qq_message_id = message.message_segment.data.get("actual_id")
mmc_message_id = message.message_segment.data.get("echo") # type: ignore
qq_message_id = message.message_segment.data.get("actual_id") # type: ignore
else:
logger.info(f"更新消息ID错误seg类型为{message.message_segment.type}")
return
if not qq_message_id:
logger.info("消息不存在message_id无法更新")
return
# 查询最新一条匹配消息
matched_message = (
if matched_message := (
Messages.select().where((Messages.message_id == mmc_message_id)).order_by(Messages.time.desc()).first()
)
if matched_message:
):
# 更新找到的消息记录
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute()
Messages.update(message_id=qq_message_id).where(Messages.id == matched_message.id).execute() # type: ignore
logger.debug(f"更新消息ID成功: {matched_message.message_id} -> {qq_message_id}")
else:
logger.debug("未找到匹配的消息")
@@ -155,10 +151,7 @@ class MessageStorage:
image_record = (
Images.select().where(Images.description == description).order_by(Images.timestamp.desc()).first()
)
if image_record:
return f"[picid:{image_record.image_id}]"
else:
return match.group(0) # 保持原样
return f"[picid:{image_record.image_id}]" if image_record else match.group(0)
except Exception:
return match.group(0)

View File

@@ -1,16 +1,17 @@
import asyncio
from typing import Dict, Optional # 重新导入类型
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.common.logger import get_logger
from src.chat.utils.utils import calculate_typing_time
from rich.traceback import install
import traceback
install(extra_lines=3)
from typing import Dict, Optional
from rich.traceback import install
from src.common.message.api import get_global_api
from src.common.logger import get_logger
from src.chat.message_receive.message import MessageSending, MessageThinking
from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.utils import truncate_message
from src.chat.utils.utils import calculate_typing_time
install(extra_lines=3)
logger = get_logger("sender")
@@ -86,10 +87,10 @@ class HeartFCSender:
"""
if not message.chat_stream:
logger.error("消息缺少 chat_stream无法发送")
raise Exception("消息缺少 chat_stream无法发送")
raise ValueError("消息缺少 chat_stream无法发送")
if not message.message_info or not message.message_info.message_id:
logger.error("消息缺少 message_info 或 message_id无法发送")
raise Exception("消息缺少 message_info 或 message_id无法发送")
raise ValueError("消息缺少 message_info 或 message_id无法发送")
chat_id = message.chat_stream.stream_id
message_id = message.message_info.message_id