fix:用新LLMREQ处理S4u
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from typing import AsyncGenerator
|
||||
from src.mais4u.openai_client import AsyncOpenAIClient
|
||||
from src.llm_models.utils_model import LLMRequest, RequestType
|
||||
from src.llm_models.payload_content.message import MessageBuilder
|
||||
from src.config.config import model_config
|
||||
from src.chat.message_receive.message import MessageRecvS4U
|
||||
from src.mais4u.mais4u_chat.s4u_prompt import prompt_builder
|
||||
@@ -13,29 +14,12 @@ logger = get_logger("s4u_stream_generator")
|
||||
|
||||
class S4UStreamGenerator:
|
||||
def __init__(self):
|
||||
replyer_config = model_config.model_task_config.replyer
|
||||
model_to_use = replyer_config.model_list[0]
|
||||
model_info = model_config.get_model_info(model_to_use)
|
||||
if not model_info:
|
||||
logger.error(f"模型 {model_to_use} 在配置中未找到")
|
||||
raise ValueError(f"模型 {model_to_use} 在配置中未找到")
|
||||
provider_name = model_info.api_provider
|
||||
provider_info = model_config.get_provider(provider_name)
|
||||
if not provider_info:
|
||||
logger.error("`replyer` 找不到对应的Provider")
|
||||
raise ValueError("`replyer` 找不到对应的Provider")
|
||||
|
||||
api_key = provider_info.api_key
|
||||
base_url = provider_info.base_url
|
||||
|
||||
if not api_key:
|
||||
logger.error(f"{provider_name}没有配置API KEY")
|
||||
raise ValueError(f"{provider_name}没有配置API KEY")
|
||||
|
||||
self.client_1 = AsyncOpenAIClient(api_key=api_key, base_url=base_url)
|
||||
self.model_1_name = model_to_use
|
||||
self.replyer_config = replyer_config
|
||||
|
||||
# 使用LLMRequest替代AsyncOpenAIClient
|
||||
self.llm_request = LLMRequest(
|
||||
model_set=model_config.model_task_config.replyer,
|
||||
request_type="s4u_replyer"
|
||||
)
|
||||
|
||||
self.current_model_name = "unknown model"
|
||||
self.partial_response = ""
|
||||
|
||||
@@ -100,68 +84,124 @@ class S4UStreamGenerator:
|
||||
f"{self.current_model_name}思考:{message_txt[:30] + '...' if len(message_txt) > 30 else message_txt}"
|
||||
) # noqa: E501
|
||||
|
||||
current_client = self.client_1
|
||||
self.current_model_name = self.model_1_name
|
||||
|
||||
extra_kwargs = {}
|
||||
if self.replyer_config.get("enable_thinking") is not None:
|
||||
extra_kwargs["enable_thinking"] = self.replyer_config.get("enable_thinking")
|
||||
if self.replyer_config.get("thinking_budget") is not None:
|
||||
extra_kwargs["thinking_budget"] = self.replyer_config.get("thinking_budget")
|
||||
|
||||
async for chunk in self._generate_response_with_model(
|
||||
prompt, current_client, self.current_model_name, **extra_kwargs
|
||||
):
|
||||
# 使用LLMRequest进行流式生成
|
||||
async for chunk in self._generate_response_with_llm_request(prompt):
|
||||
yield chunk
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client: AsyncOpenAIClient,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
buffer = ""
|
||||
delimiters = ",。!?,.!?\n\r" # For final trimming
|
||||
async def _generate_response_with_llm_request(self, prompt: str) -> AsyncGenerator[str, None]:
|
||||
"""使用LLMRequest进行流式响应生成"""
|
||||
|
||||
# 构建消息
|
||||
message_builder = MessageBuilder()
|
||||
message_builder.add_text_content(prompt)
|
||||
messages = [message_builder.build()]
|
||||
|
||||
# 选择模型
|
||||
model_info, api_provider, client = self.llm_request._select_model()
|
||||
self.current_model_name = model_info.name
|
||||
|
||||
# 如果模型支持强制流式模式,使用真正的流式处理
|
||||
if model_info.force_stream_mode:
|
||||
# 简化流式处理:直接使用LLMRequest的流式功能
|
||||
try:
|
||||
# 直接调用LLMRequest的流式处理
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
# 处理响应内容
|
||||
content = response.content or ""
|
||||
if content:
|
||||
# 将内容按句子分割并输出
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"流式请求执行失败: {e}")
|
||||
# 如果流式请求失败,回退到普通模式
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
else:
|
||||
# 如果不支持流式,使用普通方式然后模拟流式输出
|
||||
response = await self.llm_request._execute_request(
|
||||
api_provider=api_provider,
|
||||
client=client,
|
||||
request_type=RequestType.RESPONSE,
|
||||
model_info=model_info,
|
||||
message_list=messages,
|
||||
)
|
||||
|
||||
content = response.content or ""
|
||||
async for chunk in self._process_content_streaming(content):
|
||||
yield chunk
|
||||
|
||||
async def _process_buffer_streaming(self, buffer: str) -> AsyncGenerator[str, None]:
|
||||
"""实时处理缓冲区内容,输出完整句子"""
|
||||
# 使用正则表达式匹配完整句子
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence and match.end(0) <= len(buffer):
|
||||
# 检查句子是否完整(以标点符号结尾)
|
||||
if sentence.endswith(("。", "!", "?", ".", "!", "?")):
|
||||
if sentence not in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
self.partial_response += sentence
|
||||
yield sentence
|
||||
|
||||
async def _process_content_streaming(self, content: str) -> AsyncGenerator[str, None]:
|
||||
"""处理内容进行流式输出(用于非流式模型的模拟流式输出)"""
|
||||
buffer = content
|
||||
punctuation_buffer = ""
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 检查是否只是一个标点符号
|
||||
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
punctuation_buffer += sentence
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
async for content in client.get_stream_content(
|
||||
messages=[{"role": "user", "content": prompt}], model=model_name, **kwargs
|
||||
):
|
||||
buffer += content
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
|
||||
# 使用正则表达式匹配句子
|
||||
last_match_end = 0
|
||||
for match in self.sentence_split_pattern.finditer(buffer):
|
||||
sentence = match.group(0).strip()
|
||||
if sentence:
|
||||
# 如果句子看起来完整(即不只是等待更多内容),则发送
|
||||
if match.end(0) < len(buffer) or sentence.endswith(tuple(delimiters)):
|
||||
# 检查是否只是一个标点符号
|
||||
if sentence in [",", ",", ".", "。", "!", "!", "?", "?"]:
|
||||
punctuation_buffer += sentence
|
||||
else:
|
||||
# 发送之前累积的标点和当前句子
|
||||
to_yield = punctuation_buffer + sentence
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
punctuation_buffer = "" # 清空标点符号缓冲区
|
||||
await asyncio.sleep(0) # 允许其他任务运行
|
||||
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 从缓冲区移除已发送的部分
|
||||
if last_match_end > 0:
|
||||
buffer = buffer[last_match_end:]
|
||||
last_match_end = match.end(0)
|
||||
|
||||
# 发送缓冲区中剩余的任何内容
|
||||
to_yield = (punctuation_buffer + buffer).strip()
|
||||
remaining = buffer[last_match_end:].strip()
|
||||
to_yield = (punctuation_buffer + remaining).strip()
|
||||
if to_yield:
|
||||
if to_yield.endswith((",", ",")):
|
||||
to_yield = to_yield.rstrip(",,")
|
||||
if to_yield:
|
||||
self.partial_response += to_yield
|
||||
yield to_yield
|
||||
|
||||
async def _generate_response_with_model(
|
||||
self,
|
||||
prompt: str,
|
||||
client,
|
||||
model_name: str,
|
||||
**kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""保留原有方法签名以保持兼容性,但重定向到新的实现"""
|
||||
async for chunk in self._generate_response_with_llm_request(prompt):
|
||||
yield chunk
|
||||
|
||||
Reference in New Issue
Block a user