修复ChatStream循环导入问题
This commit is contained in:
@@ -13,11 +13,13 @@
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.common.logger import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("chat_api")
|
||||
|
||||
|
||||
@@ -31,7 +33,7 @@ class ChatManager:
|
||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||
|
||||
@staticmethod
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有聊天流
|
||||
|
||||
@@ -48,6 +50,7 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||
@@ -58,7 +61,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有群聊聊天流
|
||||
|
||||
@@ -72,6 +75,7 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
|
||||
@@ -82,7 +86,7 @@ class ChatManager:
|
||||
return streams
|
||||
|
||||
@staticmethod
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
# sourcery skip: for-append-to-extend
|
||||
"""获取所有私聊聊天流
|
||||
|
||||
@@ -99,6 +103,7 @@ class ChatManager:
|
||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||
streams = []
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
streams.extend(
|
||||
stream for stream in get_chat_manager().streams.values()
|
||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
|
||||
@@ -111,7 +116,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_group_stream_by_group_id(
|
||||
group_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
|
||||
"""根据群ID获取聊天流
|
||||
|
||||
Args:
|
||||
@@ -132,6 +137,7 @@ class ChatManager:
|
||||
if not group_id:
|
||||
raise ValueError("group_id 不能为空")
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (
|
||||
stream.group_info
|
||||
@@ -148,7 +154,7 @@ class ChatManager:
|
||||
@staticmethod
|
||||
def get_private_stream_by_user_id(
|
||||
user_id: str, platform: str | None | SpecialTypes = "qq"
|
||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
||||
) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
|
||||
"""根据用户ID获取私聊流
|
||||
|
||||
Args:
|
||||
@@ -169,6 +175,7 @@ class ChatManager:
|
||||
if not user_id:
|
||||
raise ValueError("user_id 不能为空")
|
||||
try:
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
for stream in get_chat_manager().streams.values():
|
||||
if (
|
||||
not stream.group_info
|
||||
@@ -184,7 +191,7 @@ class ChatManager:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: "ChatStream") -> str:
|
||||
"""获取聊天流类型
|
||||
|
||||
Args:
|
||||
@@ -197,6 +204,7 @@ class ChatManager:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||
if not chat_stream:
|
||||
@@ -207,7 +215,7 @@ class ChatManager:
|
||||
return "unknown"
|
||||
|
||||
@staticmethod
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
|
||||
"""获取聊天流详细信息
|
||||
|
||||
Args:
|
||||
@@ -220,6 +228,7 @@ class ChatManager:
|
||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||
ValueError: 如果 chat_stream 为空
|
||||
"""
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
if not chat_stream:
|
||||
raise ValueError("chat_stream 不能为 None")
|
||||
if not isinstance(chat_stream, ChatStream):
|
||||
@@ -289,37 +298,37 @@ class ChatManager:
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
"""获取所有聊天流的便捷函数"""
|
||||
return ChatManager.get_all_streams(platform)
|
||||
|
||||
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
"""获取群聊聊天流的便捷函数"""
|
||||
return ChatManager.get_group_streams(platform)
|
||||
|
||||
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||
"""获取私聊聊天流的便捷函数"""
|
||||
return ChatManager.get_private_streams(platform)
|
||||
|
||||
|
||||
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
|
||||
"""根据群ID获取聊天流的便捷函数"""
|
||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||
|
||||
|
||||
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
||||
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
|
||||
"""根据用户ID获取私聊流的便捷函数"""
|
||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||
|
||||
|
||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
||||
def get_stream_type(chat_stream: "ChatStream") -> str:
|
||||
"""获取聊天流类型的便捷函数"""
|
||||
return ChatManager.get_stream_type(chat_stream)
|
||||
|
||||
|
||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
||||
def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
|
||||
"""获取聊天流信息的便捷函数"""
|
||||
return ChatManager.get_stream_info(chat_stream)
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import Any, TYPE_CHECKING
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||
from src.chat.utils.chat_message_builder import (
|
||||
build_readable_messages_with_id,
|
||||
get_raw_msg_before_timestamp_with_chat,
|
||||
@@ -15,6 +15,9 @@ from src.common.message_repository import get_user_messages_from_streams
|
||||
from src.config.config import global_config
|
||||
from src.config.official_configs import ContextGroup
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
logger = get_logger("cross_context_api")
|
||||
|
||||
|
||||
@@ -51,7 +54,7 @@ async def get_context_group(chat_id: str) -> ContextGroup | None:
|
||||
return None
|
||||
|
||||
|
||||
async def build_cross_context_normal(chat_stream: ChatStream, context_group: ContextGroup) -> str:
|
||||
async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
|
||||
"""
|
||||
构建跨群聊/私聊上下文 (Normal模式)。
|
||||
|
||||
@@ -124,7 +127,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
||||
|
||||
|
||||
async def build_cross_context_s4u(
|
||||
chat_stream: ChatStream,
|
||||
chat_stream: "ChatStream",
|
||||
target_user_info: dict[str, Any] | None,
|
||||
) -> str:
|
||||
"""
|
||||
|
||||
@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any
|
||||
|
||||
from rich.traceback import install
|
||||
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
from src.chat.utils.utils import process_llm_response
|
||||
from src.common.data_models.database_data_model import DatabaseMessages
|
||||
from src.common.logger import get_logger
|
||||
@@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from chat.replyer.default_generator import DefaultReplyer
|
||||
from src.chat.message_receive.chat_stream import ChatStream
|
||||
|
||||
install(extra_lines=3)
|
||||
|
||||
@@ -34,7 +34,7 @@ logger = get_logger("generator_api")
|
||||
|
||||
|
||||
async def get_replyer(
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_stream: "ChatStream | None" = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "replyer",
|
||||
) -> "DefaultReplyer | None":
|
||||
@@ -78,7 +78,7 @@ async def get_replyer(
|
||||
|
||||
|
||||
async def generate_reply(
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_stream: "ChatStream | None" = None,
|
||||
chat_id: str | None = None,
|
||||
action_data: dict[str, Any] | None = None,
|
||||
reply_to: str = "",
|
||||
@@ -189,7 +189,7 @@ async def generate_reply(
|
||||
|
||||
|
||||
async def rewrite_reply(
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_stream: "ChatStream | None" = None,
|
||||
reply_data: dict[str, Any] | None = None,
|
||||
chat_id: str | None = None,
|
||||
enable_splitter: bool = True,
|
||||
@@ -287,7 +287,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
||||
|
||||
|
||||
async def generate_response_custom(
|
||||
chat_stream: ChatStream | None = None,
|
||||
chat_stream: "ChatStream | None" = None,
|
||||
chat_id: str | None = None,
|
||||
request_type: str = "generator_api",
|
||||
prompt: str = "",
|
||||
|
||||
Reference in New Issue
Block a user