fix: handle S4U with the new LLMRequest
@@ -1,5 +1,6 @@
 from typing import AsyncGenerator
-from src.mais4u.openai_client import AsyncOpenAIClient
+from src.llm_models.utils_model import LLMRequest, RequestType
+from src.llm_models.payload_content.message import MessageBuilder
 from src.config.config import model_config
 from src.chat.message_receive.message import MessageRecvS4U
 from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
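The two new imports are exercised later in this commit. For orientation, a minimal sketch of how the new code combines them is shown below; the import path and the MessageBuilder calls are copied from the hunks that follow, and the sample prompt string is made up for illustration.

# Sketch only, mirroring the new _generate_response_with_llm_request further down:
# the prompt becomes a single text message, and the resulting list is what
# LLMRequest._execute_request later receives as message_list.
from src.llm_models.payload_content.message import MessageBuilder

message_builder = MessageBuilder()
message_builder.add_text_content("自我介绍一下吧")  # sample prompt, illustration only
messages = [message_builder.build()]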
@@ -13,29 +14,12 @@ logger = get_logger("s4u_stream_generator")

 class S4UStreamGenerator:
     def __init__(self):
-        replyer_config = model_config.model_task_config.replyer
-        model_to_use = replyer_config.model_list[0]
-        model_info = model_config.get_model_info(model_to_use)
-        if not model_info:
-            logger.error(f"模型 {model_to_use} 在配置中未找到")
-            raise ValueError(f"模型 {model_to_use} 在配置中未找到")
-        provider_name = model_info.api_provider
-        provider_info = model_config.get_provider(provider_name)
-        if not provider_info:
-            logger.error("`replyer` 找不到对应的Provider")
-            raise ValueError("`replyer` 找不到对应的Provider")
-
-        api_key = provider_info.api_key
-        base_url = provider_info.base_url
-
-        if not api_key:
-            logger.error(f"{provider_name}没有配置API KEY")
-            raise ValueError(f"{provider_name}没有配置API KEY")
-
-        self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
-        self.model_1_name = model_to_use
-        self.replyer_config = replyer_config
-
+        # 使用LLMRequest替代AsyncOpenAIClient
+        self.llm_request = LLMRequest(
+            model_set=model_config.model_task_config.replyer,
+            request_type="s4u_replyer"
+        )
+
         self.current_model_name = "unknown model"
         self.partial_response = ""

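For orientation, a minimal usage sketch of the reworked constructor. The class and attribute names come from this diff; the module path is an assumption based on the logger name, and the note about where validation now lives is inferred from the removed lines rather than shown in this commit.

# Sketch only, not part of this commit. Module path assumed from get_logger("s4u_stream_generator").
from src.mais4u.mais4u_chat.s4u_stream_generator import S4UStreamGenerator

generator = S4UStreamGenerator()
# __init__ no longer resolves the model, provider, or API key itself, and no longer
# raises ValueError on missing config at this point; that resolution is assumed to
# happen inside LLMRequest when a request is actually made.
print(generator.current_model_name)  # "unknown model" until a request selects a model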
@@ -100,68 +84,124 @@ class S4UStreamGenerator:
             f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
         )  # noqa: E501

-        current_client = self.client_1
-        self.current_model_name = self.model_1_name
-
-        extra_kwargs = {}
-        if self.replyer_config.get("enable_thinking") is not None:
-            extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
-        if self.replyer_config.get("thinking_budget") is not None:
-            extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
-
-        async for chunk in self._generate_response_with_model(
-            prompt, current_client, self.current_model_name, **extra_kwargs
-        ):
+        # 使用LLMRequest进行流式生成
+        async for chunk in self._generate_response_with_llm_request(prompt):
             yield chunk

-    async def _generate_response_with_model(
-        self,
-        prompt: str,
-        client: AsyncOpenAIClient,
-        model_name: str,
-        **kwargs,
-    ) -> AsyncGenerator[str, None]:
-        buffer = ""
-        delimiters = ",。!?,.!?\n\r"  # For final trimming
+    async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
+        """使用LLMRequest进行流式响应生成"""
+
+        # 构建消息
+        message_builder = MessageBuilder()
+        message_builder.add_text_content(prompt)
+        messages = [message_builder.build()]
+
+        # 选择模型
+        model_info, api_provider, client = self.llm_request._select_model()
+        self.current_model_name = model_info.name
+
+        # 如果模型支持强制流式模式,使用真正的流式处理
+        if model_info.force_stream_mode:
+            # 简化流式处理:直接使用LLMRequest的流式功能
+            try:
+                # 直接调用LLMRequest的流式处理
+                response = await self.llm_request._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                )
+
+                # 处理响应内容
+                content = response.content or ""
+                if content:
+                    # 将内容按句子分割并输出
+                    async for chunk in self._process_content_streaming(content):
+                        yield chunk
+
+            except Exception as e:
+                logger.error(f"流式请求执行失败: {e}")
+                # 如果流式请求失败,回退到普通模式
+                response = await self.llm_request._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                )
+                content = response.content or ""
+                async for chunk in self._process_content_streaming(content):
+                    yield chunk
+
+        else:
+            # 如果不支持流式,使用普通方式然后模拟流式输出
+            response = await self.llm_request._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                model_info=model_info,
+                message_list=messages,
+            )
+
+            content = response.content or ""
+            async for chunk in self._process_content_streaming(content):
+                yield chunk
+
+    async def _process_buffer_streaming(self, buffer: str) -> AsyncGenerator[str, None]:
+        """实时处理缓冲区内容,输出完整句子"""
+        # 使用正则表达式匹配完整句子
+        for match in self.sentence_split_pattern.finditer(buffer):
+            sentence = match.group(0).strip()
+            if sentence and match.end(0) <= len(buffer):
+                # 检查句子是否完整(以标点符号结尾)
+                if sentence.endswith(("。", "!", "?", ".", "!", "?")):
+                    if sentence not in [",", ",", ".", "。", "!", "!", "?", "?"]:
+                        self.partial_response += sentence
+                        yield sentence
+
+    async def _process_content_streaming(self, content: str) -> AsyncGenerator[str, None]:
+        """处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
+        buffer = content
         punctuation_buffer = ""

-        async for content in client.get_stream_content(
-            messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
-        ):
-            buffer += content
-
-            # 使用正则表达式匹配句子
-            last_match_end = 0
-            for match in self.sentence_split_pattern.finditer(buffer):
-                sentence = match.group(0).strip()
-                if sentence:
-                    # 如果句子看起来完整(即不只是等待更多内容),则发送
-                    if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
-                        # 检查是否只是一个标点符号
-                        if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
-                            punctuation_buffer += sentence
-                        else:
-                            # 发送之前累积的标点和当前句子
-                            to_yield = punctuation_buffer + sentence
-                            if to_yield.endswith((",", ",")):
-                                to_yield = to_yield.rstrip(",,")
-
-                            self.partial_response += to_yield
-                            yield to_yield
-                            punctuation_buffer = ""  # 清空标点符号缓冲区
-                await asyncio.sleep(0)  # 允许其他任务运行
-
-                last_match_end = match.end(0)
-
-            # 从缓冲区移除已发送的部分
-            if last_match_end > 0:
-                buffer = buffer[last_match_end:]
-
+        # 使用正则表达式匹配句子
+        last_match_end = 0
+        for match in self.sentence_split_pattern.finditer(buffer):
+            sentence = match.group(0).strip()
+            if sentence:
+                # 检查是否只是一个标点符号
+                if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
+                    punctuation_buffer += sentence
+                else:
+                    # 发送之前累积的标点和当前句子
+                    to_yield = punctuation_buffer + sentence
+                    if to_yield.endswith((",", ",")):
+                        to_yield = to_yield.rstrip(",,")
+
+                    self.partial_response += to_yield
+                    yield to_yield
+                    punctuation_buffer = ""  # 清空标点符号缓冲区
+
+            last_match_end = match.end(0)
+
         # 发送缓冲区中剩余的任何内容
-        to_yield = (punctuation_buffer + buffer).strip()
+        remaining = buffer[last_match_end:].strip()
+        to_yield = (punctuation_buffer + remaining).strip()
         if to_yield:
             if to_yield.endswith((",", ",")):
                 to_yield = to_yield.rstrip(",,")
             if to_yield:
                 self.partial_response += to_yield
                 yield to_yield
+
+    async def _generate_response_with_model(
+        self,
+        prompt: str,
+        client,
+        model_name: str,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        """保留原有方法签名以保持兼容性,但重定向到新的实现"""
+        async for chunk in self._generate_response_with_llm_request(prompt):
+            yield chunk