修复ChatStream循环导入问题
This commit is contained in:
@@ -3,7 +3,6 @@ from typing import Literal
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.security import get_api_key
|
from src.common.security import get_api_key
|
||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
@@ -123,6 +122,7 @@ async def get_message_stats_by_chat(
|
|||||||
return stats
|
return stats
|
||||||
|
|
||||||
# 获取聊天管理器以查询会话信息
|
# 获取聊天管理器以查询会话信息
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
chat_manager = get_chat_manager()
|
chat_manager = get_chat_manager()
|
||||||
formatted_stats = {}
|
formatted_stats = {}
|
||||||
# 遍历统计结果进行格式化
|
# 遍历统计结果进行格式化
|
||||||
|
|||||||
@@ -8,10 +8,7 @@ from sqlalchemy.dialects.sqlite import insert as sqlite_insert
|
|||||||
|
|
||||||
from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo
|
from src.common.data_models.database_data_model import DatabaseGroupInfo,DatabaseUserInfo
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
|
||||||
from src.plugin_system.base.component_types import ChatMode, ChatType
|
|
||||||
from src.common.database.api.crud import CRUDBase
|
from src.common.database.api.crud import CRUDBase
|
||||||
from src.common.database.api.specialized import get_or_create_chat_stream
|
|
||||||
from src.common.database.compatibility import get_db_session
|
from src.common.database.compatibility import get_db_session
|
||||||
from src.common.database.core.models import ChatStreams # 新增导入
|
from src.common.database.core.models import ChatStreams # 新增导入
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -43,6 +40,8 @@ class ChatStream:
|
|||||||
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
self.sleep_pressure = data.get("sleep_pressure", 0.0) if data else 0.0
|
||||||
self.saved = False
|
self.saved = False
|
||||||
|
|
||||||
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.plugin_system.base.component_types import ChatMode, ChatType
|
||||||
self.context: StreamContext = StreamContext(
|
self.context: StreamContext = StreamContext(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
chat_type=ChatType.GROUP if group_info else ChatType.PRIVATE,
|
||||||
@@ -407,6 +406,7 @@ class ChatManager:
|
|||||||
stream.group_info = group_info
|
stream.group_info = group_info
|
||||||
else:
|
else:
|
||||||
current_time = time.time()
|
current_time = time.time()
|
||||||
|
from src.common.database.api.specialized import get_or_create_chat_stream
|
||||||
model_instance, _ = await get_or_create_chat_stream(
|
model_instance, _ = await get_or_create_chat_stream(
|
||||||
stream_id=stream_id,
|
stream_id=stream_id,
|
||||||
platform=platform,
|
platform=platform,
|
||||||
|
|||||||
@@ -38,7 +38,6 @@ from typing import TYPE_CHECKING, Any
|
|||||||
from mofox_bus import MessageEnvelope, MessageRuntime
|
from mofox_bus import MessageEnvelope, MessageRuntime
|
||||||
|
|
||||||
from src.chat.message_manager import message_manager
|
from src.chat.message_manager import message_manager
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.message_receive.storage import MessageStorage
|
from src.chat.message_receive.storage import MessageStorage
|
||||||
from src.chat.utils.prompt import global_prompt_manager
|
from src.chat.utils.prompt import global_prompt_manager
|
||||||
from src.chat.utils.utils import is_mentioned_bot_in_message
|
from src.chat.utils.utils import is_mentioned_bot_in_message
|
||||||
@@ -261,7 +260,8 @@ class MessageHandler:
|
|||||||
|
|
||||||
# 获取或创建聊天流
|
# 获取或创建聊天流
|
||||||
platform = message_info.get("platform", "unknown")
|
platform = message_info.get("platform", "unknown")
|
||||||
|
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
chat = await get_chat_manager().get_or_create_stream(
|
chat = await get_chat_manager().get_or_create_stream(
|
||||||
platform=platform,
|
platform=platform,
|
||||||
user_info=user_info, # type: ignore
|
user_info=user_info, # type: ignore
|
||||||
@@ -281,6 +281,7 @@ class MessageHandler:
|
|||||||
message.chat_info.last_active_time = chat.last_active_time
|
message.chat_info.last_active_time = chat.last_active_time
|
||||||
|
|
||||||
# 注册消息到聊天管理器
|
# 注册消息到聊天管理器
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
get_chat_manager().register_message(message)
|
get_chat_manager().register_message(message)
|
||||||
|
|
||||||
# 检测是否提及机器人
|
# 检测是否提及机器人
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
|
|
||||||
import orjson
|
import orjson
|
||||||
from sqlalchemy import desc, select, update
|
from sqlalchemy import desc, select, update
|
||||||
@@ -13,9 +13,11 @@ from src.common.database.core import get_db_session
|
|||||||
from src.common.database.core.models import Images, Messages
|
from src.common.database.core.models import Images, Messages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
from .chat_stream import ChatStream
|
|
||||||
from .message import MessageSending
|
from .message import MessageSending
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("message_storage")
|
logger = get_logger("message_storage")
|
||||||
|
|
||||||
|
|
||||||
@@ -479,7 +481,7 @@ class MessageStorage:
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: ChatStream, use_batch: bool = True) -> None:
|
async def store_message(message: DatabaseMessages | MessageSending, chat_stream: "ChatStream", use_batch: bool = True) -> None:
|
||||||
"""
|
"""
|
||||||
存储消息到数据库
|
存储消息到数据库
|
||||||
|
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import random
|
|||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, cast
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.planner_actions.action_manager import ChatterActionManager
|
from src.chat.planner_actions.action_manager import ChatterActionManager
|
||||||
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -15,6 +15,7 @@ from src.plugin_system.core.global_announcement_manager import global_announceme
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from src.common.data_models.message_manager_data_model import StreamContext
|
from src.common.data_models.message_manager_data_model import StreamContext
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("action_manager")
|
logger = get_logger("action_manager")
|
||||||
|
|
||||||
@@ -31,7 +32,7 @@ class ActionModifier:
|
|||||||
"""初始化动作处理器"""
|
"""初始化动作处理器"""
|
||||||
self.chat_id = chat_id
|
self.chat_id = chat_id
|
||||||
# chat_stream 和 log_prefix 将在异步方法中初始化
|
# chat_stream 和 log_prefix 将在异步方法中初始化
|
||||||
self.chat_stream: ChatStream | None = None
|
self.chat_stream: "ChatStream | None" = None
|
||||||
self.log_prefix = f"[{chat_id}]"
|
self.log_prefix = f"[{chat_id}]"
|
||||||
|
|
||||||
self.action_manager = action_manager
|
self.action_manager = action_manager
|
||||||
|
|||||||
@@ -9,10 +9,9 @@ import re
|
|||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal, TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.express.expression_selector import expression_selector
|
from src.chat.express.expression_selector import expression_selector
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
|
from src.chat.message_receive.message import MessageSending, Seg, UserInfo
|
||||||
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
from src.chat.message_receive.uni_message_sender import HeartFCSender
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
@@ -38,6 +37,9 @@ from src.plugin_system.apis import llm_api
|
|||||||
from src.plugin_system.apis.permission_api import permission_api
|
from src.plugin_system.apis.permission_api import permission_api
|
||||||
from src.plugin_system.base.component_types import ActionInfo, EventType
|
from src.plugin_system.base.component_types import ActionInfo, EventType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("replyer")
|
logger = get_logger("replyer")
|
||||||
|
|
||||||
# 用于存储后台任务的集合,防止被垃圾回收
|
# 用于存储后台任务的集合,防止被垃圾回收
|
||||||
@@ -236,7 +238,7 @@ If you need to use the search tool, please directly call the function "lpmm_sear
|
|||||||
class DefaultReplyer:
|
class DefaultReplyer:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
):
|
):
|
||||||
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
self.express_model = LLMRequest(model_set=model_config.model_task_config.replyer, request_type=request_type)
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.replyer.default_generator import DefaultReplyer
|
from src.chat.replyer.default_generator import DefaultReplyer
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
logger = get_logger("ReplyerManager")
|
logger = get_logger("ReplyerManager")
|
||||||
|
|
||||||
|
|
||||||
@@ -11,7 +15,7 @@ class ReplyerManager:
|
|||||||
|
|
||||||
async def get_replyer(
|
async def get_replyer(
|
||||||
self,
|
self,
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: "ChatStream | None" = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
) -> DefaultReplyer | None:
|
) -> DefaultReplyer | None:
|
||||||
|
|||||||
@@ -10,8 +10,6 @@ import numpy as np
|
|||||||
import rjieba
|
import rjieba
|
||||||
from mofox_bus import UserInfo
|
from mofox_bus import UserInfo
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
|
|
||||||
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
# MessageRecv 已被移除,现在使用 DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
from src.common.message_repository import count_messages, find_messages
|
from src.common.message_repository import count_messages, find_messages
|
||||||
@@ -780,6 +778,7 @@ async def get_chat_type_and_target_info(chat_id: str) -> tuple[bool, dict | None
|
|||||||
chat_target_info = None
|
chat_target_info = None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
if chat_stream := await get_chat_manager().get_stream(chat_id):
|
if chat_stream := await get_chat_manager().get_stream(chat_id):
|
||||||
if chat_stream.group_info:
|
if chat_stream.group_info:
|
||||||
is_group_chat = True
|
is_group_chat = True
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from rich.traceback import install
|
|||||||
|
|
||||||
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
from src.chat.emoji_system.emoji_manager import get_emoji_manager
|
||||||
from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler
|
from chat.message_receive.message_handler import get_message_handler, shutdown_message_handler
|
||||||
from src.chat.message_receive.chat_stream import get_chat_manager
|
|
||||||
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
from src.chat.utils.statistic import OnlineTimeRecordTask, StatisticOutputTask
|
||||||
from src.common.core_sink_manager import (
|
from src.common.core_sink_manager import (
|
||||||
CoreSinkManager,
|
CoreSinkManager,
|
||||||
@@ -469,6 +468,7 @@ MoFox_Bot(第三方修改版)
|
|||||||
logger.info("情绪管理器初始化成功")
|
logger.info("情绪管理器初始化成功")
|
||||||
|
|
||||||
# 启动聊天管理器的自动保存任务
|
# 启动聊天管理器的自动保存任务
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
task = asyncio.create_task(get_chat_manager()._auto_save_task())
|
task = asyncio.create_task(get_chat_manager()._auto_save_task())
|
||||||
_background_tasks.add(task)
|
_background_tasks.add(task)
|
||||||
task.add_done_callback(_background_tasks.discard)
|
task.add_done_callback(_background_tasks.discard)
|
||||||
|
|||||||
@@ -13,11 +13,13 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Any
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("chat_api")
|
logger = get_logger("chat_api")
|
||||||
|
|
||||||
|
|
||||||
@@ -31,7 +33,7 @@ class ChatManager:
|
|||||||
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
"""聊天管理器 - 专门负责聊天信息的查询和管理"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""获取所有聊天流
|
"""获取所有聊天流
|
||||||
|
|
||||||
@@ -48,6 +50,7 @@ class ChatManager:
|
|||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
streams.extend(
|
streams.extend(
|
||||||
stream for stream in get_chat_manager().streams.values()
|
stream for stream in get_chat_manager().streams.values()
|
||||||
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
if platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform
|
||||||
@@ -58,7 +61,7 @@ class ChatManager:
|
|||||||
return streams
|
return streams
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""获取所有群聊聊天流
|
"""获取所有群聊聊天流
|
||||||
|
|
||||||
@@ -72,6 +75,7 @@ class ChatManager:
|
|||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
streams.extend(
|
streams.extend(
|
||||||
stream for stream in get_chat_manager().streams.values()
|
stream for stream in get_chat_manager().streams.values()
|
||||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
|
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and stream.group_info
|
||||||
@@ -82,7 +86,7 @@ class ChatManager:
|
|||||||
return streams
|
return streams
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
# sourcery skip: for-append-to-extend
|
# sourcery skip: for-append-to-extend
|
||||||
"""获取所有私聊聊天流
|
"""获取所有私聊聊天流
|
||||||
|
|
||||||
@@ -99,6 +103,7 @@ class ChatManager:
|
|||||||
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
raise TypeError("platform 必须是字符串或是 SpecialTypes 枚举")
|
||||||
streams = []
|
streams = []
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
streams.extend(
|
streams.extend(
|
||||||
stream for stream in get_chat_manager().streams.values()
|
stream for stream in get_chat_manager().streams.values()
|
||||||
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
|
if (platform == SpecialTypes.ALL_PLATFORMS or stream.platform == platform) and not stream.group_info
|
||||||
@@ -111,7 +116,7 @@ class ChatManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_group_stream_by_group_id(
|
def get_group_stream_by_group_id(
|
||||||
group_id: str, platform: str | None | SpecialTypes = "qq"
|
group_id: str, platform: str | None | SpecialTypes = "qq"
|
||||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
|
||||||
"""根据群ID获取聊天流
|
"""根据群ID获取聊天流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -132,6 +137,7 @@ class ChatManager:
|
|||||||
if not group_id:
|
if not group_id:
|
||||||
raise ValueError("group_id 不能为空")
|
raise ValueError("group_id 不能为空")
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
for stream in get_chat_manager().streams.values():
|
for stream in get_chat_manager().streams.values():
|
||||||
if (
|
if (
|
||||||
stream.group_info
|
stream.group_info
|
||||||
@@ -148,7 +154,7 @@ class ChatManager:
|
|||||||
@staticmethod
|
@staticmethod
|
||||||
def get_private_stream_by_user_id(
|
def get_private_stream_by_user_id(
|
||||||
user_id: str, platform: str | None | SpecialTypes = "qq"
|
user_id: str, platform: str | None | SpecialTypes = "qq"
|
||||||
) -> ChatStream | None: # sourcery skip: remove-unnecessary-cast
|
) -> "ChatStream | None": # sourcery skip: remove-unnecessary-cast
|
||||||
"""根据用户ID获取私聊流
|
"""根据用户ID获取私聊流
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -169,6 +175,7 @@ class ChatManager:
|
|||||||
if not user_id:
|
if not user_id:
|
||||||
raise ValueError("user_id 不能为空")
|
raise ValueError("user_id 不能为空")
|
||||||
try:
|
try:
|
||||||
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
for stream in get_chat_manager().streams.values():
|
for stream in get_chat_manager().streams.values():
|
||||||
if (
|
if (
|
||||||
not stream.group_info
|
not stream.group_info
|
||||||
@@ -184,7 +191,7 @@ class ChatManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
def get_stream_type(chat_stream: "ChatStream") -> str:
|
||||||
"""获取聊天流类型
|
"""获取聊天流类型
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -197,6 +204,7 @@ class ChatManager:
|
|||||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||||
ValueError: 如果 chat_stream 为空
|
ValueError: 如果 chat_stream 为空
|
||||||
"""
|
"""
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
if not isinstance(chat_stream, ChatStream):
|
if not isinstance(chat_stream, ChatStream):
|
||||||
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
raise TypeError("chat_stream 必须是 ChatStream 类型")
|
||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
@@ -207,7 +215,7 @@ class ChatManager:
|
|||||||
return "unknown"
|
return "unknown"
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
|
||||||
"""获取聊天流详细信息
|
"""获取聊天流详细信息
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -220,6 +228,7 @@ class ChatManager:
|
|||||||
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
TypeError: 如果 chat_stream 不是 ChatStream 类型
|
||||||
ValueError: 如果 chat_stream 为空
|
ValueError: 如果 chat_stream 为空
|
||||||
"""
|
"""
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
if not chat_stream:
|
if not chat_stream:
|
||||||
raise ValueError("chat_stream 不能为 None")
|
raise ValueError("chat_stream 不能为 None")
|
||||||
if not isinstance(chat_stream, ChatStream):
|
if not isinstance(chat_stream, ChatStream):
|
||||||
@@ -289,37 +298,37 @@ class ChatManager:
|
|||||||
# =============================================================================
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_all_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
"""获取所有聊天流的便捷函数"""
|
"""获取所有聊天流的便捷函数"""
|
||||||
return ChatManager.get_all_streams(platform)
|
return ChatManager.get_all_streams(platform)
|
||||||
|
|
||||||
|
|
||||||
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_group_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
"""获取群聊聊天流的便捷函数"""
|
"""获取群聊聊天流的便捷函数"""
|
||||||
return ChatManager.get_group_streams(platform)
|
return ChatManager.get_group_streams(platform)
|
||||||
|
|
||||||
|
|
||||||
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list[ChatStream]:
|
def get_private_streams(platform: str | None | SpecialTypes = "qq") -> list["ChatStream"]:
|
||||||
"""获取私聊聊天流的便捷函数"""
|
"""获取私聊聊天流的便捷函数"""
|
||||||
return ChatManager.get_private_streams(platform)
|
return ChatManager.get_private_streams(platform)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
def get_stream_by_group_id(group_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
|
||||||
"""根据群ID获取聊天流的便捷函数"""
|
"""根据群ID获取聊天流的便捷函数"""
|
||||||
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
return ChatManager.get_group_stream_by_group_id(group_id, platform)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> ChatStream | None:
|
def get_stream_by_user_id(user_id: str, platform: str | None | SpecialTypes = "qq") -> "ChatStream | None":
|
||||||
"""根据用户ID获取私聊流的便捷函数"""
|
"""根据用户ID获取私聊流的便捷函数"""
|
||||||
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
return ChatManager.get_private_stream_by_user_id(user_id, platform)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_type(chat_stream: ChatStream) -> str:
|
def get_stream_type(chat_stream: "ChatStream") -> str:
|
||||||
"""获取聊天流类型的便捷函数"""
|
"""获取聊天流类型的便捷函数"""
|
||||||
return ChatManager.get_stream_type(chat_stream)
|
return ChatManager.get_stream_type(chat_stream)
|
||||||
|
|
||||||
|
|
||||||
def get_stream_info(chat_stream: ChatStream) -> dict[str, Any]:
|
def get_stream_info(chat_stream: "ChatStream") -> dict[str, Any]:
|
||||||
"""获取聊天流信息的便捷函数"""
|
"""获取聊天流信息的便捷函数"""
|
||||||
return ChatManager.get_stream_info(chat_stream)
|
return ChatManager.get_stream_info(chat_stream)
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Any
|
from typing import Any, TYPE_CHECKING
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
|
from src.chat.message_receive.chat_stream import get_chat_manager
|
||||||
from src.chat.utils.chat_message_builder import (
|
from src.chat.utils.chat_message_builder import (
|
||||||
build_readable_messages_with_id,
|
build_readable_messages_with_id,
|
||||||
get_raw_msg_before_timestamp_with_chat,
|
get_raw_msg_before_timestamp_with_chat,
|
||||||
@@ -15,6 +15,9 @@ from src.common.message_repository import get_user_messages_from_streams
|
|||||||
from src.config.config import global_config
|
from src.config.config import global_config
|
||||||
from src.config.official_configs import ContextGroup
|
from src.config.official_configs import ContextGroup
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
logger = get_logger("cross_context_api")
|
logger = get_logger("cross_context_api")
|
||||||
|
|
||||||
|
|
||||||
@@ -51,7 +54,7 @@ async def get_context_group(chat_id: str) -> ContextGroup | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
async def build_cross_context_normal(chat_stream: ChatStream, context_group: ContextGroup) -> str:
|
async def build_cross_context_normal(chat_stream: "ChatStream", context_group: ContextGroup) -> str:
|
||||||
"""
|
"""
|
||||||
构建跨群聊/私聊上下文 (Normal模式)。
|
构建跨群聊/私聊上下文 (Normal模式)。
|
||||||
|
|
||||||
@@ -124,7 +127,7 @@ async def build_cross_context_normal(chat_stream: ChatStream, context_group: Con
|
|||||||
|
|
||||||
|
|
||||||
async def build_cross_context_s4u(
|
async def build_cross_context_s4u(
|
||||||
chat_stream: ChatStream,
|
chat_stream: "ChatStream",
|
||||||
target_user_info: dict[str, Any] | None,
|
target_user_info: dict[str, Any] | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING, Any
|
|||||||
|
|
||||||
from rich.traceback import install
|
from rich.traceback import install
|
||||||
|
|
||||||
from src.chat.message_receive.chat_stream import ChatStream
|
|
||||||
from src.chat.utils.utils import process_llm_response
|
from src.chat.utils.utils import process_llm_response
|
||||||
from src.common.data_models.database_data_model import DatabaseMessages
|
from src.common.data_models.database_data_model import DatabaseMessages
|
||||||
from src.common.logger import get_logger
|
from src.common.logger import get_logger
|
||||||
@@ -21,6 +20,7 @@ from src.plugin_system.base.component_types import ActionInfo
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from chat.replyer.default_generator import DefaultReplyer
|
from chat.replyer.default_generator import DefaultReplyer
|
||||||
|
from src.chat.message_receive.chat_stream import ChatStream
|
||||||
|
|
||||||
install(extra_lines=3)
|
install(extra_lines=3)
|
||||||
|
|
||||||
@@ -34,7 +34,7 @@ logger = get_logger("generator_api")
|
|||||||
|
|
||||||
|
|
||||||
async def get_replyer(
|
async def get_replyer(
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: "ChatStream | None" = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
request_type: str = "replyer",
|
request_type: str = "replyer",
|
||||||
) -> "DefaultReplyer | None":
|
) -> "DefaultReplyer | None":
|
||||||
@@ -78,7 +78,7 @@ async def get_replyer(
|
|||||||
|
|
||||||
|
|
||||||
async def generate_reply(
|
async def generate_reply(
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: "ChatStream | None" = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
action_data: dict[str, Any] | None = None,
|
action_data: dict[str, Any] | None = None,
|
||||||
reply_to: str = "",
|
reply_to: str = "",
|
||||||
@@ -189,7 +189,7 @@ async def generate_reply(
|
|||||||
|
|
||||||
|
|
||||||
async def rewrite_reply(
|
async def rewrite_reply(
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: "ChatStream | None" = None,
|
||||||
reply_data: dict[str, Any] | None = None,
|
reply_data: dict[str, Any] | None = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
enable_splitter: bool = True,
|
enable_splitter: bool = True,
|
||||||
@@ -287,7 +287,7 @@ def process_human_text(content: str, enable_splitter: bool, enable_chinese_typo:
|
|||||||
|
|
||||||
|
|
||||||
async def generate_response_custom(
|
async def generate_response_custom(
|
||||||
chat_stream: ChatStream | None = None,
|
chat_stream: "ChatStream | None" = None,
|
||||||
chat_id: str | None = None,
|
chat_id: str | None = None,
|
||||||
request_type: str = "generator_api",
|
request_type: str = "generator_api",
|
||||||
prompt: str = "",
|
prompt: str = "",
|
||||||
|
|||||||
Reference in New Issue
Block a user