修复ChatStream循环导入问题

This commit is contained in:
Windpicker-owo
2025-11-25 20:29:48 +08:00
parent 6b3b2a8245
commit c268ea2fb2
12 changed files with 65 additions and 44 deletions

View File

@@ -3,7 +3,6 @@ from typing import Literal
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from src.chat.message_receive.chat_stream import get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.security import get_api_key from src.common.security import get_api_key
from src.config.config import global_config from src.config.config import global_config
@@ -123,6 +122,7 @@ async def get_message_stats_by_chat(
return stats return stats
# 获取聊天管理器以查询会话信息 # 获取聊天管理器以查询会话信息
from src.chat.message_receive.chat_stream import get_chat_manager
chat_manager = get_chat_manager() chat_manager = get_chat_manager()
formatted_stats = {} formatted_stats = {}
# 遍历统计结果进行格式化 # 遍历统计结果进行格式化

View File

@@ -8,10 +8,7 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
from src.common.database.api.crud import CRUDBase from src.common.database.api.crud import CRUDBase
from src.common.database.api.specialized import get_or_create_chat_stream
from src.common.database.compatibility import get_db_session from src.common.database.compatibility import get_db_session
from src.common.database.core.models import ChatStreams # 新增导入 from src.common.database.core.models import ChatStreams # 新增导入
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -43,6 +40,8 @@ class ChatStream:
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0 self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
self.saved = False self.saved = False
from src.common.data_models.message_manager_data_model import StreamContext
from src.plugin_system.base.component_types import ChatMode, ChatType
self.context: StreamContext = StreamContext( self.context: StreamContext = StreamContext(
stream_id=stream_id, stream_id=stream_id,
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE, chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
@@ -407,6 +406,7 @@ class ChatManager:
stream.group_info = group_info stream.group_info = group_info
else: else:
current_time = time.time() current_time = time.time()
from src.common.database.api.specialized import get_or_create_chat_stream
model_instance, _ = await get_or_create_chat_stream( model_instance, _ = await get_or_create_chat_stream(
stream_id=stream_id, stream_id=stream_id,
platform=platform, platform=platform,

View File

@@ -38,7 +38,6 @@ from typing import TYPE_CHECKING, Any
from mofox_bus import MessageEnvelope, MessageRuntime from mofox_bus import MessageEnvelope, MessageRuntime
from src.chat.message_manager import message_manager from src.chat.message_manager import message_manager
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.storage import MessageStorage from src.chat.message_receive.storage import MessageStorage
from src.chat.utils.prompt import global_prompt_manager from src.chat.utils.prompt import global_prompt_manager
from src.chat.utils.utils import is_mentioned_bot_in_message from src.chat.utils.utils import is_mentioned_bot_in_message
@@ -261,7 +260,8 @@ class MessageHandler:
# 获取或创建聊天流 # 获取或创建聊天流
platform = message_info.get("platform", "unknown") platform = message_info.get("platform", "unknown")
from src.chat.message_receive.chat_stream import get_chat_manager
chat = await get_chat_manager().get_or_create_stream( chat = await get_chat_manager().get_or_create_stream(
platform=platform, platform=platform,
user_info=user_info, # type: ignore user_info=user_info, # type: ignore
@@ -281,6 +281,7 @@ class MessageHandler:
message.chat_info.last_active_time = chat.last_active_time message.chat_info.last_active_time = chat.last_active_time
# 注册消息到聊天管理器 # 注册消息到聊天管理器
from src.chat.message_receive.chat_stream import get_chat_manager
get_chat_manager().register_message(message) get_chat_manager().register_message(message)
# 检测是否提及机器人 # 检测是否提及机器人

View File

@@ -3,7 +3,7 @@ import re
import time import time
import traceback import traceback
from collections import deque from collections import deque
from typing import Optional from typing import Optional, TYPE_CHECKING
import orjson import orjson
from sqlalchemy import desc, select, update from sqlalchemy import desc, select, update
@@ -13,9 +13,11 @@ from src.common.database.core import get_db_session
from src.common.database.core.models import Images, Messages from src.common.database.core.models import Images, Messages
from src.common.logger import get_logger from src.common.logger import get_logger
from .chat_stream import ChatStream
from .message import MessageSending from .message import MessageSending
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("message_storage") logger = get_logger("message_storage")
@@ -479,7 +481,7 @@ class MessageStorage:
return [] return []
@staticmethod @staticmethod
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None: async def store_message(message: DatabaseMessages | MessageSending, chat_stream: "ChatStream", use_batch: bool = True) -> None:
""" """
存储消息到数据库 存储消息到数据库

View File

@@ -4,7 +4,7 @@ import random
import time import time
from typing import TYPE_CHECKING, Any, cast from typing import TYPE_CHECKING, Any, cast
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.planner_actions.action_manager import ChatterActionManager from src.chat.planner_actions.action_manager import ChatterActionManager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -15,6 +15,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
if TYPE_CHECKING: if TYPE_CHECKING:
from src.common.data_models.message_manager_data_model import StreamContext from src.common.data_models.message_manager_data_model import StreamContext
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("action_manager") logger = get_logger("action_manager")
@@ -31,7 +32,7 @@ class ActionModifier:
"""初始化动作处理器""" """初始化动作处理器"""
self.chat_id = chat_id self.chat_id = chat_id
# chat_stream 和 log_prefix 将在异步方法中初始化 # chat_stream 和 log_prefix 将在异步方法中初始化
self.chat_stream: ChatStream | None = None self.chat_stream: "ChatStream | None" = None
self.log_prefix = f"[{chat_id}]" self.log_prefix = f"[{chat_id}]"
self.action_manager = action_manager self.action_manager = action_manager

View File

@@ -9,10 +9,9 @@ import re
import time import time
import traceback import traceback
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Literal from typing import Any, Literal, TYPE_CHECKING
from src.chat.express.expression_selector import expression_selector from src.chat.express.expression_selector import expression_selector
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageSending, Seg, UserInfo from src.chat.message_receive.message import MessageSending, Seg, UserInfo
from src.chat.message_receive.uni_message_sender import HeartFCSender from src.chat.message_receive.uni_message_sender import HeartFCSender
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
@@ -38,6 +37,9 @@ from src.plugin_system.apis import llm_api
from src.plugin_system.apis.permission_api import permission_api from src.plugin_system.apis.permission_api import permission_api
from src.plugin_system.base.component_types import ActionInfo, EventType from src.plugin_system.base.component_types import ActionInfo, EventType
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("replyer") logger = get_logger("replyer")
# 用于存储后台任务的集合,防止被垃圾回收 # 用于存储后台任务的集合,防止被垃圾回收
@@ -236,7 +238,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear
class DefaultReplyer: class DefaultReplyer:
def __init__( def __init__(
self, self,
chat_stream: ChatStream, chat_stream: "ChatStream",
request_type: str = "replyer", request_type: str = "replyer",
): ):
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type) self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)

View File

@@ -1,7 +1,11 @@
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.replyer.default_generator import DefaultReplyer from src.chat.replyer.default_generator import DefaultReplyer
from src.common.logger import get_logger from src.common.logger import get_logger
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("ReplyerManager") logger = get_logger("ReplyerManager")
@@ -11,7 +15,7 @@ class ReplyerManager:
async def get_replyer( async def get_replyer(
self, self,
chat_stream: ChatStream | None = None, chat_stream: "ChatStream | None" = None,
chat_id: str | None = None, chat_id: str | None = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> DefaultReplyer | None: ) -> DefaultReplyer | None:

View File

@@ -10,8 +10,6 @@ import numpy as np
import rjieba import rjieba
from mofox_bus import UserInfo from mofox_bus import UserInfo
from src.chat.message_receive.chat_stream import get_chat_manager
# MessageRecv 已被移除,现在使用 DatabaseMessages # MessageRecv 已被移除,现在使用 DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
from src.common.message_repository import count_messages, find_messages from src.common.message_repository import count_messages, find_messages
@@ -780,6 +778,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None
chat_target_info = None chat_target_info = None
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
if chat_stream := await get_chat_manager().get_stream(chat_id): if chat_stream := await get_chat_manager().get_stream(chat_id):
if chat_stream.group_info: if chat_stream.group_info:
is_group_chat = True is_group_chat = True

View File

@@ -13,7 +13,6 @@ from rich.traceback import install
from src.chat.emoji_system.emoji_manager import get_emoji_manager from src.chat.emoji_system.emoji_manager import get_emoji_manager
from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
from src.common.core_sink_manager import ( from src.common.core_sink_manager import (
CoreSinkManager, CoreSinkManager,
@@ -469,6 +468,7 @@ MoFox_Bot(第三方修改版)
logger.info("情绪管理器初始化成功") logger.info("情绪管理器初始化成功")
# 启动聊天管理器的自动保存任务 # 启动聊天管理器的自动保存任务
from src.chat.message_receive.chat_stream import get_chat_manager
task = asyncio.create_task(get_chat_manager()._auto_save_task()) task = asyncio.create_task(get_chat_manager()._auto_save_task())
_background_tasks.add(task) _background_tasks.add(task)
task.add_done_callback(_background_tasks.discard) task.add_done_callback(_background_tasks.discard)

View File

@@ -13,11 +13,13 @@
""" """
from enum import Enum from enum import Enum
from typing import Any from typing import Any, TYPE_CHECKING
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.common.logger import get_logger from src.common.logger import get_logger
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("chat_api") logger = get_logger("chat_api")
@@ -31,7 +33,7 @@ class ChatManager:
"""聊天管理器 - 专门负责聊天信息的查询和管理""" """聊天管理器 - 专门负责聊天信息的查询和管理"""
@staticmethod @staticmethod
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有聊天流 """获取所有聊天流
@@ -48,6 +50,7 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
streams.extend( streams.extend(
stream for stream in get_chat_manager().streams.values() stream for stream in get_chat_manager().streams.values()
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
@@ -58,7 +61,7 @@ class ChatManager:
return streams return streams
@staticmethod @staticmethod
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有群聊聊天流 """获取所有群聊聊天流
@@ -72,6 +75,7 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
streams.extend( streams.extend(
stream for stream in get_chat_manager().streams.values() stream for stream in get_chat_manager().streams.values()
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
@@ -82,7 +86,7 @@ class ChatManager:
return streams return streams
@staticmethod @staticmethod
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
# sourcery skip: for-append-to-extend # sourcery skip: for-append-to-extend
"""获取所有私聊聊天流 """获取所有私聊聊天流
@@ -99,6 +103,7 @@ class ChatManager:
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举") raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
streams = [] streams = []
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
streams.extend( streams.extend(
stream for stream in get_chat_manager().streams.values() stream for stream in get_chat_manager().streams.values()
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
@@ -111,7 +116,7 @@ class ChatManager:
@staticmethod @staticmethod
def get_group_stream_by_group_id( def get_group_stream_by_group_id(
group_id: str, platform: str | None | SpecialTypes = "qq" group_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast ) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
"""根据群ID获取聊天流 """根据群ID获取聊天流
Args: Args:
@@ -132,6 +137,7 @@ class ChatManager:
if not group_id: if not group_id:
raise ValueError("group_id 不能为空") raise ValueError("group_id 不能为空")
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
for stream in get_chat_manager().streams.values(): for stream in get_chat_manager().streams.values():
if ( if (
stream.group_info stream.group_info
@@ -148,7 +154,7 @@ class ChatManager:
@staticmethod @staticmethod
def get_private_stream_by_user_id( def get_private_stream_by_user_id(
user_id: str, platform: str | None | SpecialTypes = "qq" user_id: str, platform: str | None | SpecialTypes = "qq"
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast ) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
"""根据用户ID获取私聊流 """根据用户ID获取私聊流
Args: Args:
@@ -169,6 +175,7 @@ class ChatManager:
if not user_id: if not user_id:
raise ValueError("user_id 不能为空") raise ValueError("user_id 不能为空")
try: try:
from src.chat.message_receive.chat_stream import get_chat_manager
for stream in get_chat_manager().streams.values(): for stream in get_chat_manager().streams.values():
if ( if (
not stream.group_info not stream.group_info
@@ -184,7 +191,7 @@ class ChatManager:
return None return None
@staticmethod @staticmethod
def get_stream_type(chat_stream: ChatStream) -> str: def get_stream_type(chat_stream: "ChatStream") -> str:
"""获取聊天流类型 """获取聊天流类型
Args: Args:
@@ -197,6 +204,7 @@ class ChatManager:
TypeError: 如果 chat_stream 不是 ChatStream 类型 TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空 ValueError: 如果 chat_stream 为空
""" """
from src.chat.message_receive.chat_stream import ChatStream
if not isinstance(chat_stream, ChatStream): if not isinstance(chat_stream, ChatStream):
raise TypeError("chat_stream 必须是 ChatStream 类型") raise TypeError("chat_stream 必须是 ChatStream 类型")
if not chat_stream: if not chat_stream:
@@ -207,7 +215,7 @@ class ChatManager:
return "unknown" return "unknown"
@staticmethod @staticmethod
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
"""获取聊天流详细信息 """获取聊天流详细信息
Args: Args:
@@ -220,6 +228,7 @@ class ChatManager:
TypeError: 如果 chat_stream 不是 ChatStream 类型 TypeError: 如果 chat_stream 不是 ChatStream 类型
ValueError: 如果 chat_stream 为空 ValueError: 如果 chat_stream 为空
""" """
from src.chat.message_receive.chat_stream import ChatStream
if not chat_stream: if not chat_stream:
raise ValueError("chat_stream 不能为 None") raise ValueError("chat_stream 不能为 None")
if not isinstance(chat_stream, ChatStream): if not isinstance(chat_stream, ChatStream):
@@ -289,37 +298,37 @@ class ChatManager:
# ============================================================================= # =============================================================================
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
"""获取所有聊天流的便捷函数""" """获取所有聊天流的便捷函数"""
return ChatManager.get_all_streams(platform) return ChatManager.get_all_streams(platform)
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
"""获取群聊聊天流的便捷函数""" """获取群聊聊天流的便捷函数"""
return ChatManager.get_group_streams(platform) return ChatManager.get_group_streams(platform)
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]: def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
"""获取私聊聊天流的便捷函数""" """获取私聊聊天流的便捷函数"""
return ChatManager.get_private_streams(platform) return ChatManager.get_private_streams(platform)
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
"""根据群ID获取聊天流的便捷函数""" """根据群ID获取聊天流的便捷函数"""
return ChatManager.get_group_stream_by_group_id(group_id, platform) return ChatManager.get_group_stream_by_group_id(group_id, platform)
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None: def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
"""根据用户ID获取私聊流的便捷函数""" """根据用户ID获取私聊流的便捷函数"""
return ChatManager.get_private_stream_by_user_id(user_id, platform) return ChatManager.get_private_stream_by_user_id(user_id, platform)
def get_stream_type(chat_stream: ChatStream) -> str: def get_stream_type(chat_stream: "ChatStream") -> str:
"""获取聊天流类型的便捷函数""" """获取聊天流类型的便捷函数"""
return ChatManager.get_stream_type(chat_stream) return ChatManager.get_stream_type(chat_stream)
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]: def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
"""获取聊天流信息的便捷函数""" """获取聊天流信息的便捷函数"""
return ChatManager.get_stream_info(chat_stream) return ChatManager.get_stream_info(chat_stream)

View File

@@ -3,9 +3,9 @@
""" """
import time import time
from typing import Any from typing import Any, TYPE_CHECKING
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.chat_message_builder import ( from src.chat.utils.chat_message_builder import (
build_readable_messages_with_id, build_readable_messages_with_id,
get_raw_msg_before_timestamp_with_chat, get_raw_msg_before_timestamp_with_chat,
@@ -15,6 +15,9 @@ from src.common.message_repository import get_user_messages_from_streams
from src.config.config import global_config from src.config.config import global_config
from src.config.official_configs import ContextGroup from src.config.official_configs import ContextGroup
if TYPE_CHECKING:
from src.chat.message_receive.chat_stream import ChatStream
logger = get_logger("cross_context_api") logger = get_logger("cross_context_api")
@@ -51,7 +54,7 @@ async def get_context_group(chat_id: str) -> ContextGroup | None:
return None return None
async def build_cross_context_normal(chat_stream: ChatStream, context_group: ContextGroup) -> str: async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
""" """
构建跨群聊/私聊上下文 (Normal模式)。 构建跨群聊/私聊上下文 (Normal模式)。
@@ -124,7 +127,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
async def build_cross_context_s4u( async def build_cross_context_s4u(
chat_stream: ChatStream, chat_stream: "ChatStream",
target_user_info: dict[str, Any] | None, target_user_info: dict[str, Any] | None,
) -> str: ) -> str:
""" """

View File

@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any
from rich.traceback import install from rich.traceback import install
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.utils.utils import process_llm_response from src.chat.utils.utils import process_llm_response
from src.common.data_models.database_data_model import DatabaseMessages from src.common.data_models.database_data_model import DatabaseMessages
from src.common.logger import get_logger from src.common.logger import get_logger
@@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo
if TYPE_CHECKING: if TYPE_CHECKING:
from chat.replyer.default_generator import DefaultReplyer from chat.replyer.default_generator import DefaultReplyer
from src.chat.message_receive.chat_stream import ChatStream
install(extra_lines=3) install(extra_lines=3)
@@ -34,7 +34,7 @@ logger = get_logger("generator_api")
async def get_replyer( async def get_replyer(
chat_stream: ChatStream | None = None, chat_stream: "ChatStream | None" = None,
chat_id: str | None = None, chat_id: str | None = None,
request_type: str = "replyer", request_type: str = "replyer",
) -> "DefaultReplyer | None": ) -> "DefaultReplyer | None":
@@ -78,7 +78,7 @@ async def get_replyer(
async def generate_reply( async def generate_reply(
chat_stream: ChatStream | None = None, chat_stream: "ChatStream | None" = None,
chat_id: str | None = None, chat_id: str | None = None,
action_data: dict[str, Any] | None = None, action_data: dict[str, Any] | None = None,
reply_to: str = "", reply_to: str = "",
@@ -189,7 +189,7 @@ async def generate_reply(
async def rewrite_reply( async def rewrite_reply(
chat_stream: ChatStream | None = None, chat_stream: "ChatStream | None" = None,
reply_data: dict[str, Any] | None = None, reply_data: dict[str, Any] | None = None,
chat_id: str | None = None, chat_id: str | None = None,
enable_splitter: bool = True, enable_splitter: bool = True,
@@ -287,7 +287,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
async def generate_response_custom( async def generate_response_custom(
chat_stream: ChatStream | None = None, chat_stream: "ChatStream | None" = None,
chat_id: str | None = None, chat_id: str | None = None,
request_type: str = "generator_api", request_type: str = "generator_api",
prompt: str = "", prompt: str = "",