re-style: format code
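The restyle below is mechanical: imports are regrouped and sorted, typing.Dict / List / Tuple / Optional / Union annotations become built-in generics and X | None unions, Callable and AsyncGenerator move to collections.abc, str() literals become "", and the redundant "r" mode is dropped from open() calls. A minimal before/after sketch of the annotation change (illustrative only; the function and names below are hypothetical, not code from this repository):

    # before
    from typing import Dict, List, Optional

    def top_user(scores: Dict[str, float], names: List[str]) -> Optional[str]:
        # return the highest-scoring name, or None if the list is empty
        return max(names, key=lambda n: scores.get(n, 0.0)) if names else None

    # after, using built-in generics and union syntax as applied in this commit
    def top_user(scores: dict[str, float], names: list[str]) -> str | None:
        return max(names, key=lambda n: scores.get(n, 0.0)) if names else None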
@@ -1,12 +1,13 @@
from src.chat.message_receive.chat_stream import get_chat_manager
import time
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.llm_models.utils_model import LLMRequest
from src.config.config import model_config

from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.mais4u.mais4u_chat.s4u_msg_processor import S4UMessageProcessor

logger = get_logger(__name__)

@@ -1,18 +1,18 @@
import orjson
import time

import orjson
from json_repair import repair_json

from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.s4u_config import s4u_config
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api

from src.mais4u.s4u_config import s4u_config

logger = get_logger("action")

HEAD_CODE = {

@@ -1,10 +1,10 @@
import asyncio
import orjson
from collections import deque
from datetime import datetime
from typing import Dict, List, Optional
from aiohttp import web, WSMsgType

import aiohttp_cors
import orjson
from aiohttp import WSMsgType, web

from src.chat.message_receive.message import MessageRecv
from src.common.logger import get_logger
@@ -57,8 +57,8 @@ class ContextWebManager:
    def __init__(self, max_messages: int = 10, port: int = 8765):
        self.max_messages = max_messages
        self.port = port
        self.contexts: Dict[str, deque] = {}  # chat_id -> deque of ContextMessage
        self.websockets: List[web.WebSocketResponse] = []
        self.contexts: dict[str, deque] = {}  # chat_id -> deque of ContextMessage
        self.websockets: list[web.WebSocketResponse] = []
        self.app = None
        self.runner = None
        self.site = None
@@ -674,7 +674,7 @@ class ContextWebManager:


# Global instance
_context_web_manager: Optional[ContextWebManager] = None
_context_web_manager: ContextWebManager | None = None


def get_context_web_manager() -> ContextWebManager:

@@ -1,5 +1,5 @@
import asyncio
from typing import Dict, Tuple, Callable, Optional
from collections.abc import Callable
from dataclasses import dataclass

from src.chat.message_receive.message import MessageRecvS4U
@@ -23,11 +23,11 @@ class GiftManager:

    def __init__(self):
        """Initialize the gift manager"""
        self.pending_gifts: Dict[Tuple[str, str], PendingGift] = {}
        self.pending_gifts: dict[tuple[str, str], PendingGift] = {}
        self.debounce_timeout = 5.0  # 3-second debounce window

    async def handle_gift(
        self, message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]] = None
        self, message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None = None
    ) -> bool:
        """Handle a gift message and return whether it should be processed immediately

@@ -53,7 +53,7 @@ class GiftManager:
        await self._create_pending_gift(gift_key, message, callback)
        return False

    async def _merge_gift(self, gift_key: Tuple[str, str], new_message: MessageRecvS4U) -> None:
    async def _merge_gift(self, gift_key: tuple[str, str], new_message: MessageRecvS4U) -> None:
        """Merge gift messages"""
        pending_gift = self.pending_gifts[gift_key]

@@ -81,7 +81,7 @@ class GiftManager:
        logger.debug(f"Merged gift: {gift_key}, total count: {pending_gift.total_count}")

    async def _create_pending_gift(
        self, gift_key: Tuple[str, str], message: MessageRecvS4U, callback: Optional[Callable[[MessageRecvS4U], None]]
        self, gift_key: tuple[str, str], message: MessageRecvS4U, callback: Callable[[MessageRecvS4U], None] | None
    ) -> None:
        """Create a new pending gift"""
        try:
@@ -100,7 +100,7 @@ class GiftManager:

        logger.debug(f"Created pending gift: {gift_key}, initial count: {initial_count}")

    async def _gift_timeout(self, gift_key: Tuple[str, str]) -> None:
    async def _gift_timeout(self, gift_key: tuple[str, str]) -> None:
        """Handle the gift debounce timeout"""
        try:
            # Wait out the debounce interval

@@ -1,6 +1,6 @@
class InternalManager:
    def __init__(self):
        self.now_internal_state = str()
        self.now_internal_state = ""

    def set_internal_state(self, internal_state: str):
        self.now_internal_state = internal_state

@@ -1,25 +1,27 @@
import asyncio
import traceback
import time
import random
from typing import Optional, Dict, Tuple, List  # import type hints
from maim_message import UserInfo, Seg
from src.common.logger import get_logger
from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from .s4u_stream_generator import S4UStreamGenerator
from src.chat.message_receive.message import MessageSending, MessageRecv, MessageRecvS4U
from src.config.config import global_config
from src.common.message.api import get_global_api
from src.chat.message_receive.storage import MessageStorage
from .s4u_watching_manager import watching_manager
import time
import traceback

import orjson
from .s4u_mood_manager import mood_manager
from src.person_info.relationship_builder_manager import relationship_builder_manager
from maim_message import Seg, UserInfo

from src.chat.message_receive.chat_stream import ChatStream, get_chat_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U, MessageSending
from src.chat.message_receive.storage import MessageStorage
from src.common.logger import get_logger
from src.common.message.api import get_global_api
from src.config.config import global_config
from src.mais4u.constant_s4u import ENABLE_S4U
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager
from src.person_info.relationship_builder_manager import relationship_builder_manager

from .s4u_mood_manager import mood_manager
from .s4u_stream_generator import S4UStreamGenerator
from .s4u_watching_manager import watching_manager
from .super_chat_manager import get_super_chat_manager
from .yes_or_no import yes_or_no_head
from src.mais4u.constant_s4u import ENABLE_S4U

logger = get_logger("S4U_chat")

@@ -32,7 +34,7 @@ class MessageSenderContainer:
        self.original_message = original_message
        self.queue = asyncio.Queue()
        self.storage = MessageStorage()
        self._task: Optional[asyncio.Task] = None
        self._task: asyncio.Task | None = None
        self._paused_event = asyncio.Event()
        self._paused_event.set()  # set to the not-paused state by default

@@ -158,7 +160,7 @@ class MessageSenderContainer:

class S4UChatManager:
    def __init__(self):
        self.s4u_chats: Dict[str, "S4UChat"] = {}
        self.s4u_chats: dict[str, "S4UChat"] = {}

    def get_or_create_chat(self, chat_stream: ChatStream) -> "S4UChat":
        if chat_stream.stream_id not in self.s4u_chats:
@@ -196,16 +198,16 @@ class S4UChat:
        self._new_message_event = asyncio.Event()  # used to wake up the processor

        self._processing_task = asyncio.create_task(self._message_processor())
        self._current_generation_task: Optional[asyncio.Task] = None
        self._current_generation_task: asyncio.Task | None = None
        # Metadata of the current message: (queue type, priority score, counter, message object)
        self._current_message_being_replied: Optional[Tuple[str, float, int, MessageRecv]] = None
        self._current_message_being_replied: tuple[str, float, int, MessageRecv] | None = None

        self._is_replying = False
        self.gpt = S4UStreamGenerator()
        self.gpt.chat_stream = self.chat_stream
        self.interest_dict: Dict[str, float] = {}  # per-user interest score
        self.interest_dict: dict[str, float] = {}  # per-user interest score

        self.internal_message: List[MessageRecvS4U] = []
        self.internal_message: list[MessageRecvS4U] = []

        self.msg_id = ""
        self.voice_done = ""

@@ -1,16 +1,17 @@
import asyncio
import orjson
import time

import orjson

from src.chat.message_receive.message import MessageRecv
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_by_timestamp_with_chat_inclusive
from src.config.config import global_config, model_config
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.common.logger import get_logger
from src.config.config import global_config, model_config
from src.llm_models.utils_model import LLMRequest
from src.mais4u.constant_s4u import ENABLE_S4U
from src.manager.async_task_manager import AsyncTask, async_task_manager
from src.plugin_system.apis import send_api
from src.mais4u.constant_s4u import ENABLE_S4U

"""
Usage notes for the mood management system:

@@ -1,33 +1,33 @@
import asyncio
import math
from typing import Tuple

from maim_message.message_base import GroupInfo

from src.chat.message_receive.chat_stream import get_chat_manager

# The old Hippocampus system has been removed; the enhanced memory system is used now
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
from src.chat.message_receive.message import MessageRecv, MessageRecvS4U
from maim_message.message_base import GroupInfo
from src.chat.message_receive.storage import MessageStorage
from src.chat.message_receive.chat_stream import get_chat_manager
from src.chat.utils.timer_calculator import Timer
from src.chat.utils.utils import is_mentioned_bot_in_message
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.body_emotion_action_manager import action_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.context_web_manager import get_context_web_manager
from src.mais4u.mais4u_chat.gift_manager import gift_manager
from src.mais4u.mais4u_chat.s4u_mood_manager import mood_manager
from src.mais4u.mais4u_chat.s4u_watching_manager import watching_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager

from .s4u_chat import get_s4u_chat_manager


# from ..message_receive.message_buffer import message_buffer

logger = get_logger("chat")


async def _calculate_interest(message: MessageRecv) -> Tuple[float, bool]:
async def _calculate_interest(message: MessageRecv) -> tuple[float, bool]:
    """Calculate how interesting a message is

    Args:

@@ -1,25 +1,27 @@
from src.config.config import global_config
from src.common.logger import get_logger
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
import time
from src.chat.utils.utils import get_recent_group_speaker
import asyncio

# The old Hippocampus system has been removed; the enhanced memory system is used now
# from src.chat.memory_system.enhanced_memory_manager import enhanced_memory_manager
import random
import time
from datetime import datetime
import asyncio
from src.mais4u.s4u_config import s4u_config
from src.chat.message_receive.message import MessageRecvS4U
from src.person_info.relationship_fetcher import relationship_fetcher_manager
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.chat.message_receive.chat_stream import ChatStream
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager

from src.chat.express.expression_selector import expression_selector
from .s4u_mood_manager import mood_manager
from src.chat.message_receive.chat_stream import ChatStream
from src.chat.message_receive.message import MessageRecvS4U
from src.chat.utils.chat_message_builder import build_readable_messages, get_raw_msg_before_timestamp_with_chat
from src.chat.utils.prompt import Prompt, global_prompt_manager
from src.chat.utils.utils import get_recent_group_speaker
from src.common.logger import get_logger
from src.config.config import global_config
from src.mais4u.mais4u_chat.internal_manager import internal_manager
from src.mais4u.mais4u_chat.screen_manager import screen_manager
from src.mais4u.mais4u_chat.super_chat_manager import get_super_chat_manager
from src.mais4u.s4u_config import s4u_config
from src.person_info.person_info import PersonInfoManager, get_person_info_manager
from src.person_info.relationship_fetcher import relationship_fetcher_manager

from .s4u_mood_manager import mood_manager

logger = get_logger("prompt")

@@ -206,7 +208,7 @@ class PromptBuilder:
            limit=300,
        )

        talk_type = f"{message.message_info.platform}:{str(message.chat_stream.user_info.user_id)}"
        talk_type = f"{message.message_info.platform}:{message.chat_stream.user_info.user_id!s}"

        core_dialogue_list = []
        background_dialogue_list = []

@@ -1,12 +1,12 @@
from typing import AsyncGenerator
from src.mais4u.openai_client import AsyncOpenAIClient
from src.config.config import model_config
from src.chat.message_receive.message import MessageRecvS4U
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.common.logger import get_logger
import asyncio
import re
from collections.abc import AsyncGenerator

from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger
from src.config.config import model_config
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
from src.mais4u.openai_client import AsyncOpenAIClient

logger = get_logger("s4u_stream_generator")

@@ -99,7 +99,7 @@ class S4UStreamGenerator:

        logger.info(
            f"{self.current_model_name} thinking: {message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
        )  # noqa: E501
        )

        current_client = self.client_1
        self.current_model_name = self.model_1_name

@@ -1,6 +1,6 @@
class ScreenManager:
    def __init__(self):
        self.now_screen = str()
        self.now_screen = ""

    def set_screen(self, screen_str: str):
        self.now_screen = screen_str

@@ -1,9 +1,9 @@
import asyncio
import time
from dataclasses import dataclass
from typing import Dict, List, Optional
from src.common.logger import get_logger

from src.chat.message_receive.message import MessageRecvS4U
from src.common.logger import get_logger

# Global SuperChat manager instance
from src.mais4u.constant_s4u import ENABLE_S4U
@@ -23,7 +23,7 @@ class SuperChatRecord:
    message_text: str
    timestamp: float
    expire_time: float
    group_name: Optional[str] = None
    group_name: str | None = None

    def is_expired(self) -> bool:
        """Check whether this SuperChat has expired"""
@@ -53,8 +53,8 @@ class SuperChatManager:
    """SuperChat manager, responsible for managing and tracking SuperChat messages"""

    def __init__(self):
        self.super_chats: Dict[str, List[SuperChatRecord]] = {}  # chat_id -> list of SuperChats
        self._cleanup_task: Optional[asyncio.Task] = None
        self.super_chats: dict[str, list[SuperChatRecord]] = {}  # chat_id -> list of SuperChats
        self._cleanup_task: asyncio.Task | None = None
        self._is_initialized = False
        logger.info("SuperChat manager initialized")

@@ -186,7 +186,7 @@ class SuperChatManager:

        logger.info(f"Added SuperChat record: {user_info.user_nickname} - {price} yuan - {message.superchat_message_text}")

    def get_superchats_by_chat(self, chat_id: str) -> List[SuperChatRecord]:
    def get_superchats_by_chat(self, chat_id: str) -> list[SuperChatRecord]:
        """Get all valid SuperChats for the given chat"""
        # Make sure the cleanup task has been started
        self._ensure_cleanup_task_started()
@@ -198,7 +198,7 @@ class SuperChatManager:
        valid_superchats = [sc for sc in self.super_chats[chat_id] if not sc.is_expired()]
        return valid_superchats

    def get_all_valid_superchats(self) -> Dict[str, List[SuperChatRecord]]:
    def get_all_valid_superchats(self) -> dict[str, list[SuperChatRecord]]:
        """Get all valid SuperChats"""
        # Make sure the cleanup task has been started
        self._ensure_cleanup_task_started()

@@ -1,6 +1,6 @@
from src.llm_models.utils_model import LLMRequest
from src.common.logger import get_logger
from src.config.config import model_config
from src.llm_models.utils_model import LLMRequest
from src.plugin_system.apis import send_api

logger = get_logger(__name__)

@@ -1,5 +1,6 @@
from typing import AsyncGenerator, Dict, List, Optional, Union
from collections.abc import AsyncGenerator
from dataclasses import dataclass

from openai import AsyncOpenAI
from openai.types.chat import ChatCompletion, ChatCompletionChunk

@@ -11,14 +12,14 @@ class ChatMessage:
    role: str
    content: str

    def to_dict(self) -> Dict[str, str]:
    def to_dict(self) -> dict[str, str]:
        return {"role": self.role, "content": self.content}


class AsyncOpenAIClient:
    """Async OpenAI client with streaming support"""

    def __init__(self, api_key: str, base_url: Optional[str] = None):
    def __init__(self, api_key: str, base_url: str | None = None):
        """
        Initialize the client

@@ -34,10 +35,10 @@ class AsyncOpenAIClient:

    async def chat_completion(
        self,
        messages: List[Union[ChatMessage, Dict[str, str]]],
        messages: list[ChatMessage | dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> ChatCompletion:
        """
@@ -81,10 +82,10 @@ class AsyncOpenAIClient:

    async def chat_completion_stream(
        self,
        messages: List[Union[ChatMessage, Dict[str, str]]],
        messages: list[ChatMessage | dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[ChatCompletionChunk, None]:
        """
@@ -129,10 +130,10 @@ class AsyncOpenAIClient:

    async def get_stream_content(
        self,
        messages: List[Union[ChatMessage, Dict[str, str]]],
        messages: list[ChatMessage | dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> AsyncGenerator[str, None]:
        """
@@ -156,10 +157,10 @@ class AsyncOpenAIClient:

    async def collect_stream_response(
        self,
        messages: List[Union[ChatMessage, Dict[str, str]]],
        messages: list[ChatMessage | dict[str, str]],
        model: str = "gpt-3.5-turbo",
        temperature: float = 0.7,
        max_tokens: Optional[int] = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> str:
        """
@@ -199,7 +200,7 @@ class AsyncOpenAIClient:
class ConversationManager:
    """Conversation manager that keeps track of the conversation history"""

    def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
    def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None):
        """
        Initialize the conversation manager

@@ -208,7 +209,7 @@ class ConversationManager:
            system_prompt: the system prompt
        """
        self.client = client
        self.messages: List[ChatMessage] = []
        self.messages: list[ChatMessage] = []

        if system_prompt:
            self.messages.append(ChatMessage(role="system", content=system_prompt))
@@ -281,6 +282,6 @@ class ConversationManager:
        """Get the number of messages"""
        return len(self.messages)

    def get_conversation_history(self) -> List[Dict[str, str]]:
    def get_conversation_history(self) -> list[dict[str, str]]:
        """Get the conversation history"""
        return [msg.to_dict() for msg in self.messages]

@@ -1,13 +1,16 @@
import os
import tomlkit
import shutil
from dataclasses import MISSING, dataclass, field, fields
from datetime import datetime
from typing import Any, Literal, TypeVar, get_args, get_origin

import tomlkit
from tomlkit import TOMLDocument
from tomlkit.items import Table
from dataclasses import dataclass, fields, MISSING, field
from typing import TypeVar, Type, Any, get_origin, get_args, Literal
from src.mais4u.constant_s4u import ENABLE_S4U
from typing_extensions import Self

from src.common.logger import get_logger
from src.mais4u.constant_s4u import ENABLE_S4U

logger = get_logger("s4u_config")

@@ -46,7 +49,7 @@ class S4UConfigBase:
    """Base class for S4U config classes"""

    @classmethod
    def from_dict(cls: Type[T], data: dict[str, Any]) -> T:
    def from_dict(cls, data: dict[str, Any]) -> Self:
        """Load config fields from a dict"""
        data = table_to_dict(data)  # recursively convert to dict, for compatibility with tomlkit Table
        if not is_dict_like(data):
@@ -81,7 +84,7 @@ class S4UConfigBase:
            return cls()

    @classmethod
    def _convert_field(cls, value: Any, field_type: Type[Any]) -> Any:
    def _convert_field(cls, value: Any, field_type: type[Any]) -> Any:
        """Convert a field value to the specified type"""
        # If it is a nested dataclass, recursively call its from_dict method
        if isinstance(field_type, type) and issubclass(field_type, S4UConfigBase):
@@ -271,9 +274,9 @@ def update_s4u_config():
        return

    # Read the old config file and the template file
    with open(CONFIG_PATH, "r", encoding="utf-8") as f:
    with open(CONFIG_PATH, encoding="utf-8") as f:
        old_config = tomlkit.load(f)
    with open(TEMPLATE_PATH, "r", encoding="utf-8") as f:
    with open(TEMPLATE_PATH, encoding="utf-8") as f:
        new_config = tomlkit.load(f)

    # Check whether the versions match
@@ -344,7 +347,7 @@ def load_s4u_config(config_path: str) -> S4UGlobalConfig:
    :return: an S4UGlobalConfig object
    """
    # Read the config file
    with open(config_path, "r", encoding="utf-8") as f:
    with open(config_path, encoding="utf-8") as f:
        config_data = tomlkit.load(f)

    # Create the S4UGlobalConfig object
