re-style: 格式化代码
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import AsyncGenerator, Dict, List, Optional, Union
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
|
||||
@@ -11,14 +12,14 @@ class ChatMessage:
|
||||
role: str
|
||||
content: str
|
||||
|
||||
def to_dict(self) -> Dict[str, str]:
|
||||
def to_dict(self) -> dict[str, str]:
|
||||
return {"role": self.role, "content": self.content}
|
||||
|
||||
|
||||
class AsyncOpenAIClient:
|
||||
"""异步OpenAI客户端,支持流式传输"""
|
||||
|
||||
def __init__(self, api_key: str, base_url: Optional[str] = None):
|
||||
def __init__(self, api_key: str, base_url: str | None = None):
|
||||
"""
|
||||
初始化客户端
|
||||
|
||||
@@ -34,10 +35,10 @@ class AsyncOpenAIClient:
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> ChatCompletion:
|
||||
"""
|
||||
@@ -81,10 +82,10 @@ class AsyncOpenAIClient:
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[ChatCompletionChunk, None]:
|
||||
"""
|
||||
@@ -129,10 +130,10 @@ class AsyncOpenAIClient:
|
||||
|
||||
async def get_stream_content(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
@@ -156,10 +157,10 @@ class AsyncOpenAIClient:
|
||||
|
||||
async def collect_stream_response(
|
||||
self,
|
||||
messages: List[Union[ChatMessage, Dict[str, str]]],
|
||||
messages: list[ChatMessage | dict[str, str]],
|
||||
model: str = "gpt-3.5-turbo",
|
||||
temperature: float = 0.7,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_tokens: int | None = None,
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
@@ -199,7 +200,7 @@ class AsyncOpenAIClient:
|
||||
class ConversationManager:
|
||||
"""对话管理器,用于管理对话历史"""
|
||||
|
||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: Optional[str] = None):
|
||||
def __init__(self, client: AsyncOpenAIClient, system_prompt: str | None = None):
|
||||
"""
|
||||
初始化对话管理器
|
||||
|
||||
@@ -208,7 +209,7 @@ class ConversationManager:
|
||||
system_prompt: 系统提示词
|
||||
"""
|
||||
self.client = client
|
||||
self.messages: List[ChatMessage] = []
|
||||
self.messages: list[ChatMessage] = []
|
||||
|
||||
if system_prompt:
|
||||
self.messages.append(ChatMessage(role="system", content=system_prompt))
|
||||
@@ -281,6 +282,6 @@ class ConversationManager:
|
||||
"""获取消息数量"""
|
||||
return len(self.messages)
|
||||
|
||||
def get_conversation_history(self) -> List[Dict[str, str]]:
|
||||
def get_conversation_history(self) -> list[dict[str, str]]:
|
||||
"""获取对话历史"""
|
||||
return [msg.to_dict() for msg in self.messages]
|
||||
|
||||
Reference in New Issue
Block a user