diff --git a/src/api/message_router.py b/src/api/message_router.py index f7a57bed7..5e707cc95 100644 --- a/src/api/message_router.py +++ b/src/api/message_router.py @@ -3,7 +3,6 @@ from typing import Literal from fastapi import APIRouter, Depends, HTTPException, Query -from src.chat.message_receive.chat_stream import get_chat_manager from src.common.logger import get_logger from src.common.security import get_api_key from src.config.config import global_config @@ -123,6 +122,7 @@ async def get_message_stats_by_chat( return stats # 获取聊天管理器以查询会话信息 + from src.chat.message_receive.chat_stream import get_chat_manager chat_manager = get_chat_manager() formatted_stats = {} # 遍历统计结果进行格式化 diff --git a/src/chat/message_receive/chat_stream.py b/src/chat/message_receive/chat_stream.py index 82f0ac659..e9b2e833a 100644 --- a/src/chat/message_receive/chat_stream.py +++ b/src/chat/message_receive/chat_stream.py @@ -8,10 +8,7 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo from src.common.data_models.database_data_model import DatabaseMessages -from src.common.data_models.message_manager_data_model import StreamContext -from src.plugin_system.base.component_types import ChatMode, ChatType from src.common.database.api.crud import CRUDBase -from src.common.database.api.specialized import get_or_create_chat_stream from src.common.database.compatibility import get_db_session from src.common.database.core.models import ChatStreams # 新增导入 from src.common.logger import get_logger @@ -43,6 +40,8 @@ class ChatStream: self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.saved = False + from src.common.data_models.message_manager_data_model import StreamContext + from src.plugin_system.base.component_types import ChatMode, ChatType self.context: StreamContext = StreamContext( stream_id=stream_id, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, @@ -407,6 +406,7 @@ class ChatManager: stream.group_info = group_info else: current_time = time.time() + from src.common.database.api.specialized import get_or_create_chat_stream model_instance, _ = await get_or_create_chat_stream( stream_id=stream_id, platform=platform, diff --git a/src/chat/message_receive/message_handler.py b/src/chat/message_receive/message_handler.py index 3b922660b..6434bd1d1 100644 --- a/src/chat/message_receive/message_handler.py +++ b/src/chat/message_receive/message_handler.py @@ -38,7 +38,6 @@ from typing import TYPE_CHECKING, Any from mofox_bus import MessageEnvelope, MessageRuntime from src.chat.message_manager import message_manager -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.message_receive.storage import MessageStorage from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.utils import is_mentioned_bot_in_message @@ -261,7 +260,8 @@ class MessageHandler: # 获取或创建聊天流 platform = message_info.get("platform", "unknown") - + + from src.chat.message_receive.chat_stream import get_chat_manager chat = await get_chat_manager().get_or_create_stream( platform=platform, user_info=user_info, # type: ignore @@ -281,6 +281,7 @@ class MessageHandler: message.chat_info.last_active_time = chat.last_active_time # 注册消息到聊天管理器 + from src.chat.message_receive.chat_stream import get_chat_manager get_chat_manager().register_message(message) # 检测是否提及机器人 diff --git a/src/chat/message_receive/storage.py b/src/chat/message_receive/storage.py index 4dee0745d..8e65245c7 100644 --- a/src/chat/message_receive/storage.py +++ b/src/chat/message_receive/storage.py @@ -3,7 +3,7 @@ import re import time import traceback from collections import deque -from typing import Optional +from typing import Optional, TYPE_CHECKING import orjson from sqlalchemy import desc, select, update @@ -13,9 +13,11 @@ from src.common.database.core import get_db_session from src.common.database.core.models import Images, Messages from src.common.logger import get_logger -from .chat_stream import ChatStream from .message import MessageSending +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("message_storage") @@ -479,7 +481,7 @@ class MessageStorage: return [] @staticmethod - async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None: + async def store_message(message: DatabaseMessages | MessageSending, chat_stream: "ChatStream", use_batch: bool = True) -> None: """ 存储消息到数据库 diff --git a/src/chat/planner_actions/action_modifier.py b/src/chat/planner_actions/action_modifier.py index 4cc2992f5..db6c84804 100644 --- a/src/chat/planner_actions/action_modifier.py +++ b/src/chat/planner_actions/action_modifier.py @@ -4,7 +4,7 @@ import random import time from typing import TYPE_CHECKING, Any, cast -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.common.logger import get_logger @@ -15,6 +15,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme if TYPE_CHECKING: from src.common.data_models.message_manager_data_model import StreamContext + from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("action_manager") @@ -31,7 +32,7 @@ class ActionModifier: """初始化动作处理器""" self.chat_id = chat_id # chat_stream 和 log_prefix 将在异步方法中初始化 - self.chat_stream: ChatStream | None = None + self.chat_stream: "ChatStream | None" = None self.log_prefix = f"[{chat_id}]" self.action_manager = action_manager diff --git a/src/chat/replyer/default_generator.py b/src/chat/replyer/default_generator.py index a3524dc83..baebbcfa5 100644 --- a/src/chat/replyer/default_generator.py +++ b/src/chat/replyer/default_generator.py @@ -9,10 +9,9 @@ import re import time import traceback from datetime import datetime, timedelta -from typing import Any, Literal +from typing import Any, Literal, TYPE_CHECKING from src.chat.express.expression_selector import expression_selector -from src.chat.message_receive.chat_stream import ChatStream from src.chat.message_receive.message import MessageSending, Seg, UserInfo from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.utils.chat_message_builder import ( @@ -38,6 +37,9 @@ from src.plugin_system.apis import llm_api from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.base.component_types import ActionInfo, EventType +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("replyer") # 用于存储后台任务的集合,防止被垃圾回收 @@ -236,7 +238,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear class DefaultReplyer: def __init__( self, - chat_stream: ChatStream, + chat_stream: "ChatStream", request_type: str = "replyer", ): self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) diff --git a/src/chat/replyer/replyer_manager.py b/src/chat/replyer/replyer_manager.py index 4f3f4f428..bc908d728 100644 --- a/src/chat/replyer/replyer_manager.py +++ b/src/chat/replyer/replyer_manager.py @@ -1,7 +1,11 @@ -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.replyer.default_generator import DefaultReplyer from src.common.logger import get_logger +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream logger = get_logger("ReplyerManager") @@ -11,7 +15,7 @@ class ReplyerManager: async def get_replyer( self, - chat_stream: ChatStream | None = None, + chat_stream: "ChatStream | None" = None, chat_id: str | None = None, request_type: str = "replyer", ) -> DefaultReplyer | None: diff --git a/src/chat/utils/utils.py b/src/chat/utils/utils.py index 2c2713cda..30217d79b 100644 --- a/src/chat/utils/utils.py +++ b/src/chat/utils/utils.py @@ -10,8 +10,6 @@ import numpy as np import rjieba from mofox_bus import UserInfo -from src.chat.message_receive.chat_stream import get_chat_manager - # MessageRecv 已被移除,现在使用 DatabaseMessages from src.common.logger import get_logger from src.common.message_repository import count_messages, find_messages @@ -780,6 +778,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None chat_target_info = None try: + from src.chat.message_receive.chat_stream import get_chat_manager if chat_stream := await get_chat_manager().get_stream(chat_id): if chat_stream.group_info: is_group_chat = True diff --git a/src/main.py b/src/main.py index ae62b18a5..7721e9c44 100644 --- a/src/main.py +++ b/src/main.py @@ -13,7 +13,6 @@ from rich.traceback import install from src.chat.emoji_system.emoji_manager import get_emoji_manager from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler -from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.common.core_sink_manager import ( CoreSinkManager, @@ -469,6 +468,7 @@ MoFox_Bot(第三方修改版) logger.info("情绪管理器初始化成功") # 启动聊天管理器的自动保存任务 + from src.chat.message_receive.chat_stream import get_chat_manager task = asyncio.create_task(get_chat_manager()._auto_save_task()) _background_tasks.add(task) task.add_done_callback(_background_tasks.discard) diff --git a/src/plugin_system/apis/chat_api.py b/src/plugin_system/apis/chat_api.py index 2f2d5d1df..71926384f 100644 --- a/src/plugin_system/apis/chat_api.py +++ b/src/plugin_system/apis/chat_api.py @@ -13,11 +13,13 @@ """ from enum import Enum -from typing import Any +from typing import Any, TYPE_CHECKING -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.common.logger import get_logger +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("chat_api") @@ -31,7 +33,7 @@ class ChatManager: """聊天管理器 - 专门负责聊天信息的查询和管理""" @staticmethod - def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: + def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: # sourcery skip: for-append-to-extend """获取所有聊天流 @@ -48,6 +50,7 @@ class ChatManager: raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: + from src.chat.message_receive.chat_stream import get_chat_manager streams.extend( stream for stream in get_chat_manager().streams.values() if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform @@ -58,7 +61,7 @@ class ChatManager: return streams @staticmethod - def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: + def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: # sourcery skip: for-append-to-extend """获取所有群聊聊天流 @@ -72,6 +75,7 @@ class ChatManager: raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: + from src.chat.message_receive.chat_stream import get_chat_manager streams.extend( stream for stream in get_chat_manager().streams.values() if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info @@ -82,7 +86,7 @@ class ChatManager: return streams @staticmethod - def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: + def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: # sourcery skip: for-append-to-extend """获取所有私聊聊天流 @@ -99,6 +103,7 @@ class ChatManager: raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") streams = [] try: + from src.chat.message_receive.chat_stream import get_chat_manager streams.extend( stream for stream in get_chat_manager().streams.values() if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info @@ -111,7 +116,7 @@ class ChatManager: @staticmethod def get_group_stream_by_group_id( group_id: str, platform: str | None | SpecialTypes = "qq" - ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast + ) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast """根据群ID获取聊天流 Args: @@ -132,6 +137,7 @@ class ChatManager: if not group_id: raise ValueError("group_id 不能为空") try: + from src.chat.message_receive.chat_stream import get_chat_manager for stream in get_chat_manager().streams.values(): if ( stream.group_info @@ -148,7 +154,7 @@ class ChatManager: @staticmethod def get_private_stream_by_user_id( user_id: str, platform: str | None | SpecialTypes = "qq" - ) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast + ) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast """根据用户ID获取私聊流 Args: @@ -169,6 +175,7 @@ class ChatManager: if not user_id: raise ValueError("user_id 不能为空") try: + from src.chat.message_receive.chat_stream import get_chat_manager for stream in get_chat_manager().streams.values(): if ( not stream.group_info @@ -184,7 +191,7 @@ class ChatManager: return None @staticmethod - def get_stream_type(chat_stream: ChatStream) -> str: + def get_stream_type(chat_stream: "ChatStream") -> str: """获取聊天流类型 Args: @@ -197,6 +204,7 @@ class ChatManager: TypeError: 如果 chat_stream 不是 ChatStream 类型 ValueError: 如果 chat_stream 为空 """ + from src.chat.message_receive.chat_stream import ChatStream if not isinstance(chat_stream, ChatStream): raise TypeError("chat_stream 必须是 ChatStream 类型") if not chat_stream: @@ -207,7 +215,7 @@ class ChatManager: return "unknown" @staticmethod - def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: + def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]: """获取聊天流详细信息 Args: @@ -220,6 +228,7 @@ class ChatManager: TypeError: 如果 chat_stream 不是 ChatStream 类型 ValueError: 如果 chat_stream 为空 """ + from src.chat.message_receive.chat_stream import ChatStream if not chat_stream: raise ValueError("chat_stream 不能为 None") if not isinstance(chat_stream, ChatStream): @@ -289,37 +298,37 @@ class ChatManager: # ============================================================================= -def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: +def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: """获取所有聊天流的便捷函数""" return ChatManager.get_all_streams(platform) -def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: +def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: """获取群聊聊天流的便捷函数""" return ChatManager.get_group_streams(platform) -def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: +def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]: """获取私聊聊天流的便捷函数""" return ChatManager.get_private_streams(platform) -def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: +def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None": """根据群ID获取聊天流的便捷函数""" return ChatManager.get_group_stream_by_group_id(group_id, platform) -def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: +def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None": """根据用户ID获取私聊流的便捷函数""" return ChatManager.get_private_stream_by_user_id(user_id, platform) -def get_stream_type(chat_stream: ChatStream) -> str: +def get_stream_type(chat_stream: "ChatStream") -> str: """获取聊天流类型的便捷函数""" return ChatManager.get_stream_type(chat_stream) -def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: +def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]: """获取聊天流信息的便捷函数""" return ChatManager.get_stream_info(chat_stream) diff --git a/src/plugin_system/apis/cross_context_api.py b/src/plugin_system/apis/cross_context_api.py index d97dcde19..b11a4e7d0 100644 --- a/src/plugin_system/apis/cross_context_api.py +++ b/src/plugin_system/apis/cross_context_api.py @@ -3,9 +3,9 @@ """ import time -from typing import Any +from typing import Any, TYPE_CHECKING -from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager +from src.chat.message_receive.chat_stream import get_chat_manager from src.chat.utils.chat_message_builder import ( build_readable_messages_with_id, get_raw_msg_before_timestamp_with_chat, @@ -15,6 +15,9 @@ from src.common.message_repository import get_user_messages_from_streams from src.config.config import global_config from src.config.official_configs import ContextGroup +if TYPE_CHECKING: + from src.chat.message_receive.chat_stream import ChatStream + logger = get_logger("cross_context_api") @@ -51,7 +54,7 @@ async def get_context_group(chat_id: str) -> ContextGroup | None: return None -async def build_cross_context_normal(chat_stream: ChatStream, context_group: ContextGroup) -> str: +async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str: """ 构建跨群聊/私聊上下文 (Normal模式)。 @@ -124,7 +127,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con async def build_cross_context_s4u( - chat_stream: ChatStream, + chat_stream: "ChatStream", target_user_info: dict[str, Any] | None, ) -> str: """ diff --git a/src/plugin_system/apis/generator_api.py b/src/plugin_system/apis/generator_api.py index 9c6fb0840..8f6185972 100644 --- a/src/plugin_system/apis/generator_api.py +++ b/src/plugin_system/apis/generator_api.py @@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any from rich.traceback import install -from src.chat.message_receive.chat_stream import ChatStream from src.chat.utils.utils import process_llm_response from src.common.data_models.database_data_model import DatabaseMessages from src.common.logger import get_logger @@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo if TYPE_CHECKING: from chat.replyer.default_generator import DefaultReplyer + from src.chat.message_receive.chat_stream import ChatStream install(extra_lines=3) @@ -34,7 +34,7 @@ logger = get_logger("generator_api") async def get_replyer( - chat_stream: ChatStream | None = None, + chat_stream: "ChatStream | None" = None, chat_id: str | None = None, request_type: str = "replyer", ) -> "DefaultReplyer | None": @@ -78,7 +78,7 @@ async def get_replyer( async def generate_reply( - chat_stream: ChatStream | None = None, + chat_stream: "ChatStream | None" = None, chat_id: str | None = None, action_data: dict[str, Any] | None = None, reply_to: str = "", @@ -189,7 +189,7 @@ async def generate_reply( async def rewrite_reply( - chat_stream: ChatStream | None = None, + chat_stream: "ChatStream | None" = None, reply_data: dict[str, Any] | None = None, chat_id: str | None = None, enable_splitter: bool = True, @@ -287,7 +287,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo: async def generate_response_custom( - chat_stream: ChatStream | None = None, + chat_stream: "ChatStream | None" = None, chat_id: str | None = None, request_type: str = "generator_api", prompt: str = "",