diff --git a/src/mais4u/mais4u_chat/s4u_stream_generator.py b/src/mais4u/mais4u_chat/s4u_stream_generator.py
index da12d9f9d..04689f5e0 100644
--- a/src/mais4u/mais4u_chat/s4u_stream_generator.py
+++ b/src/mais4u/mais4u_chat/s4u_stream_generator.py
@@ -1,5 +1,6 @@
 from typing import AsyncGenerator
-from src.mais4u.openai_client import AsyncOpenAIClient
+from src.llm_models.utils_model import LLMRequest, RequestType
+from src.llm_models.payload_content.message import MessageBuilder
 from src.config.config import model_config
 from src.chat.message_receive.message import MessageRecvS4U
 from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
@@ -13,29 +14,12 @@ logger = get_logger("s4u_stream_generator")
 
 
 class S4UStreamGenerator:
     def __init__(self):
-        replyer_config = model_config.model_task_config.replyer
-        model_to_use = replyer_config.model_list[0]
-        model_info = model_config.get_model_info(model_to_use)
-        if not model_info:
-            logger.error(f"Model {model_to_use} not found in configuration")
-            raise ValueError(f"Model {model_to_use} not found in configuration")
-        provider_name = model_info.api_provider
-        provider_info = model_config.get_provider(provider_name)
-        if not provider_info:
-            logger.error("No matching provider found for `replyer`")
-            raise ValueError("No matching provider found for `replyer`")
-
-        api_key = provider_info.api_key
-        base_url = provider_info.base_url
-
-        if not api_key:
-            logger.error(f"No API key configured for {provider_name}")
-            raise ValueError(f"No API key configured for {provider_name}")
-
-        self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
-        self.model_1_name = model_to_use
-        self.replyer_config = replyer_config
-
+        # Use LLMRequest instead of AsyncOpenAIClient
+        self.llm_request = LLMRequest(
+            model_set=model_config.model_task_config.replyer,
+            request_type="s4u_replyer"
+        )
+        self.current_model_name = "unknown model"
 
         self.partial_response = ""
@@ -100,68 +84,124 @@ class S4UStreamGenerator:
             f"{self.current_model_name} thinking: {message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
         )  # noqa: E501
-        current_client = self.client_1
-        self.current_model_name = self.model_1_name
-
-        extra_kwargs = {}
-        if self.replyer_config.get("enable_thinking") is not None:
-            extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
-        if self.replyer_config.get("thinking_budget") is not None:
-            extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
-
-        async for chunk in self._generate_response_with_model(
-            prompt, current_client, self.current_model_name, **extra_kwargs
-        ):
+        # Stream the reply through LLMRequest
+        async for chunk in self._generate_response_with_llm_request(prompt):
             yield chunk
 
-    async def _generate_response_with_model(
-        self,
-        prompt: str,
-        client: AsyncOpenAIClient,
-        model_name: str,
-        **kwargs,
-    ) -> AsyncGenerator[str, None]:
-        buffer = ""
-        delimiters = ",。!?,.!?\n\r"  # For final trimming
+    async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
+        """Generate a streaming response using LLMRequest."""
+
+        # Build the message
+        message_builder = MessageBuilder()
+        message_builder.add_text_content(prompt)
+        messages = [message_builder.build()]
+
+        # Select a model
+        model_info, api_provider, client = self.llm_request._select_model()
+        self.current_model_name = model_info.name
+
+        # If the model supports forced stream mode, use real streaming
+        if model_info.force_stream_mode:
+            # Simplified streaming: use LLMRequest's streaming support directly
+            try:
+                # Call LLMRequest's streaming path directly
+                response = await self.llm_request._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                )
+
+                # Handle the response content
+                content = response.content or ""
+                if content:
+                    # Split the content into sentences and yield them
+                    async for chunk in self._process_content_streaming(content):
+                        yield chunk
+
+            except Exception as e:
+                logger.error(f"Streaming request failed: {e}")
+                # If the streaming request fails, fall back to the normal mode
+                response = await self.llm_request._execute_request(
+                    api_provider=api_provider,
+                    client=client,
+                    request_type=RequestType.RESPONSE,
+                    model_info=model_info,
+                    message_list=messages,
+                )
+                content = response.content or ""
+                async for chunk in self._process_content_streaming(content):
+                    yield chunk
+
+        else:
+            # Without streaming support, make a normal request and simulate streaming output
+            response = await self.llm_request._execute_request(
+                api_provider=api_provider,
+                client=client,
+                request_type=RequestType.RESPONSE,
+                model_info=model_info,
+                message_list=messages,
+            )
+
+            content = response.content or ""
+            async for chunk in self._process_content_streaming(content):
+                yield chunk
+
+    async def _process_buffer_streaming(self, buffer: str) -> AsyncGenerator[str, None]:
+        """Process buffered content in real time, yielding complete sentences."""
+        # Match complete sentences with the regex
+        for match in self.sentence_split_pattern.finditer(buffer):
+            sentence = match.group(0).strip()
+            if sentence and match.end(0) <= len(buffer):
+                # Check that the sentence is complete (ends with punctuation)
+                if sentence.endswith(("。", "!", "?", ".", "!", "?")):
+                    if sentence not in [",", ",", ".", "。", "!", "!", "?", "?"]:
+                        self.partial_response += sentence
+                        yield sentence
+
+    async def _process_content_streaming(self, content: str) -> AsyncGenerator[str, None]:
+        """Stream content out chunk by chunk (simulated streaming for non-streaming models)."""
+        buffer = content
         punctuation_buffer = ""
+
+        # Match sentences with the regex
+        last_match_end = 0
+        for match in self.sentence_split_pattern.finditer(buffer):
+            sentence = match.group(0).strip()
+            if sentence:
+                # Check whether this is just a lone punctuation mark
+                if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
+                    punctuation_buffer += sentence
+                else:
+                    # Send any accumulated punctuation plus the current sentence
+                    to_yield = punctuation_buffer + sentence
+                    if to_yield.endswith((",", ",")):
+                        to_yield = to_yield.rstrip(",,")
 
-        async for content in client.get_stream_content(
-            messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
-        ):
-            buffer += content
+                    self.partial_response += to_yield
+                    yield to_yield
+                    punctuation_buffer = ""  # Clear the punctuation buffer
 
-            # Match sentences with the regex
-            last_match_end = 0
-            for match in self.sentence_split_pattern.finditer(buffer):
-                sentence = match.group(0).strip()
-                if sentence:
-                    # Send the sentence if it looks complete (i.e. not just waiting for more content)
-                    if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
-                        # Check whether this is just a lone punctuation mark
-                        if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
-                            punctuation_buffer += sentence
-                        else:
-                            # Send any accumulated punctuation plus the current sentence
-                            to_yield = punctuation_buffer + sentence
-                            if to_yield.endswith((",", ",")):
-                                to_yield = to_yield.rstrip(",,")
-
-                            self.partial_response += to_yield
-                            yield to_yield
-                            punctuation_buffer = ""  # Clear the punctuation buffer
-                        await asyncio.sleep(0)  # Let other tasks run
-
-                    last_match_end = match.end(0)
-
-            # Remove the already-sent portion from the buffer
-            if last_match_end > 0:
-                buffer = buffer[last_match_end:]
+            last_match_end = match.end(0)
 
         # Send whatever remains in the buffer
-        to_yield = (punctuation_buffer + buffer).strip()
+        remaining = buffer[last_match_end:].strip()
+        to_yield = (punctuation_buffer + remaining).strip()
         if to_yield:
             if to_yield.endswith((",", ",")):
                 to_yield = to_yield.rstrip(",,")
             if to_yield:
                 self.partial_response += to_yield
                 yield to_yield
+
+    async def _generate_response_with_model(
+        self,
+        prompt: str,
+        client,
+        model_name: str,
+        **kwargs,
+    ) -> AsyncGenerator[str, None]:
+        """Keep the original method signature for compatibility, redirecting to the new implementation."""
+        async for chunk in self._generate_response_with_llm_request(prompt):
+            yield chunk
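
Reviewer note on the chunking helpers: both _process_buffer_streaming and _process_content_streaming depend on self.sentence_split_pattern, which is defined elsewhere in the class and never appears in this patch. The standalone sketch below replicates _process_content_streaming's sentence-chunking on a plain string so the yield behaviour can be checked in isolation; the regex is an assumed stand-in for the real pattern, and chunk_sentences is a hypothetical helper, not part of the patch.

import re

# Assumed stand-in for S4UStreamGenerator.sentence_split_pattern (defined
# outside this patch); the real pattern may differ.
sentence_split_pattern = re.compile(r"[^,。!?,.!?\n\r]*[,。!?,.!?\n\r]")


def chunk_sentences(content: str) -> list[str]:
    """Replicate _process_content_streaming's chunking on a plain string."""
    chunks: list[str] = []
    punctuation_buffer = ""
    last_match_end = 0
    for match in sentence_split_pattern.finditer(content):
        sentence = match.group(0).strip()
        if sentence:
            if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
                # A lone punctuation mark: hold it until the next real sentence
                punctuation_buffer += sentence
            else:
                to_yield = punctuation_buffer + sentence
                if to_yield.endswith((",", ",")):
                    to_yield = to_yield.rstrip(",,")
                chunks.append(to_yield)
                punctuation_buffer = ""
        last_match_end = match.end(0)
    # Flush whatever remains after the last complete sentence
    tail = (punctuation_buffer + content[last_match_end:].strip()).strip()
    if tail.endswith((",", ",")):
        tail = tail.rstrip(",,")
    if tail:
        chunks.append(tail)
    return chunks


print(chunk_sentences("你好!今天天气不错,我们出去走走吧。好的"))
# ['你好!', '今天天气不错', '我们出去走走吧。', '好的']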
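
Design note: both branches of _generate_response_with_llm_request await the complete response from _execute_request before chunking, so even under force_stream_mode the caller receives simulated streaming (sentence-sized chunks of a finished reply) rather than incremental tokens; _process_buffer_streaming, which looks intended for a true incremental buffer, is not called anywhere in this patch.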