@@ -21,7 +21,7 @@ config = driver.config

 class ResponseGenerator:
     def __init__(self):
-        self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000)
+        self.model_r1 = LLM_request(model=global_config.llm_reasoning, temperature=0.7,max_tokens=1000,stream=True)
         self.model_v3 = LLM_request(model=global_config.llm_normal, temperature=0.7,max_tokens=1000)
         self.model_r1_distill = LLM_request(model=global_config.llm_reasoning_minor, temperature=0.7,max_tokens=1000)
         self.model_v25 = LLM_request(model=global_config.llm_normal_minor, temperature=0.7,max_tokens=1000)
@@ -194,6 +194,6 @@ class InitiativeMessageGenerate:
         prompt = prompt_builder._build_initiative_prompt(
             select_dot, prompt_template, memory
         )
-        content, reasoning = self.model_r1.generate_response(prompt)
+        content, reasoning = self.model_r1.generate_response_async(prompt)
         print(f"[DEBUG] {content} {reasoning}")
         return content
@@ -1,5 +1,6 @@
 import aiohttp
 import asyncio
+import json
 import requests
 import time
 import re
@@ -138,7 +139,12 @@ class LLM_request:
         }

         api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
-        logger.info(f"Sending request to URL: {api_url}")
+        # Determine whether this request should use streaming output
+        stream_mode = self.params.get("stream", False)
+        if stream_mode:
+            logger.info(f"Entering streaming mode, sending request to URL: {api_url}")
+        else:
+            logger.info(f"Sending request to URL: {api_url}")
         logger.info(f"Using model: {self.model_name}")

         # Build the request payload
@@ -151,6 +157,9 @@ class LLM_request:
             try:
                 # Use a context manager for the session
                 headers = await self._build_headers()
+                # Apparently required for OpenAI-style streaming; adding it had no effect on Aliyun's qwq-plus
+                if stream_mode:
+                    headers["Accept"] = "text/event-stream"

                 async with aiohttp.ClientSession() as session:
                     async with session.post(api_url, headers=headers, json=payload) as response:
@@ -175,10 +184,39 @@ class LLM_request:
                         raise RuntimeError(f"Request rejected: {error_code_mapping.get(response.status)}")

                         response.raise_for_status()
-                        result = await response.json()

-                        # Use the custom handler or fall back to the default one
-                        return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
+                        if stream_mode:
+                            accumulated_content = ""
+                            async for line_bytes in response.content:
+                                line = line_bytes.decode("utf-8").strip()
+                                if not line:
+                                    continue
+                                if line.startswith("data:"):
+                                    data_str = line[5:].strip()
+                                    if data_str == "[DONE]":
+                                        break
+                                    try:
+                                        chunk = json.loads(data_str)
+                                        delta = chunk["choices"][0]["delta"]
+                                        delta_content = delta.get("content")
+                                        if delta_content is None:
+                                            delta_content = ""
+                                        accumulated_content += delta_content
+                                    except Exception as e:
+                                        logger.error(f"Error parsing streaming output: {e}")
+                            content = accumulated_content
+                            reasoning_content = ""
+                            think_match = re.search(r'<think>(.*?)</think>', content, re.DOTALL)
+                            if think_match:
+                                reasoning_content = think_match.group(1).strip()
+                                content = re.sub(r'<think>.*?</think>', '', content, flags=re.DOTALL).strip()
+                            # Build a pseudo result so the custom or default response handler can be reused
+                            result = {"choices": [{"message": {"content": content, "reasoning_content": reasoning_content}}]}
+                            return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)
+                        else:
+                            result = await response.json()
+                            # Use the custom handler or fall back to the default one
+                            return response_handler(result) if response_handler else self._default_response_handler(result, user_id, request_type, endpoint)

             except Exception as e:
                 if retry < policy["max_retries"] - 1:
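For reference, the streaming branch above reduces to the following standalone sketch: it accumulates the `content` deltas from OpenAI-style SSE `data:` lines and then splits any inline `<think>...</think>` reasoning out of the accumulated text. The function name and the sample chunks below are illustrative only and are not part of this patch.

import json
import re

def collect_stream(sse_lines):
    """Illustrative only: accumulate OpenAI-style SSE deltas and split out <think> reasoning."""
    accumulated = ""
    for line in sse_lines:
        line = line.strip()
        if not line.startswith("data:"):
            continue
        data_str = line[5:].strip()
        if data_str == "[DONE]":
            break
        chunk = json.loads(data_str)
        # Missing/None deltas are treated as empty strings, as in the patch
        accumulated += chunk["choices"][0]["delta"].get("content") or ""
    think = re.search(r'<think>(.*?)</think>', accumulated, re.DOTALL)
    reasoning = think.group(1).strip() if think else ""
    content = re.sub(r'<think>.*?</think>', '', accumulated, flags=re.DOTALL).strip()
    return content, reasoning

# Example with hypothetical chunks:
lines = [
    'data: {"choices": [{"delta": {"content": "<think>check the date</think>"}}]}',
    'data: {"choices": [{"delta": {"content": "Hello!"}}]}',
    'data: [DONE]',
]
print(collect_stream(lines))  # -> ('Hello!', 'check the date')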
@@ -195,8 +233,18 @@ class LLM_request:

     async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
         """Build the request payload"""
+        # Copy the parameters so self.params is not modified directly
+        params_copy = dict(self.params)
+        if self.model_name.lower() in ("o3-mini", "o1-mini", "o1", "o1-2024-12-17", "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12"):
+            # Drop the 'temperature' parameter if present
+            params_copy.pop("temperature", None)
+            # If a 'max_tokens' parameter exists, replace it with 'max_completion_tokens'
+            if "max_tokens" in params_copy:
+                params_copy["max_completion_tokens"] = params_copy.pop("max_tokens")
+        # Build the base payload; note that 'max_tokens' is still filled from global_config.max_response_length here.
+        # If it should uniformly become 'max_completion_tokens', the same adjustment can be made below.
         if image_base64:
-            return {
+            payload = {
                 "model": self.model_name,
                 "messages": [
                     {
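As a quick illustration of the parameter rewrite above: for the reasoning-series models, `temperature` is dropped and `max_tokens` becomes `max_completion_tokens`, while other models keep their parameters unchanged. The helper name `adapt_params` and the constant `REASONING_MODELS` are made up for this sketch; only the rewrite rule itself comes from the patch.

# Illustrative only: mirrors the parameter rewrite in _build_payload
REASONING_MODELS = ("o3-mini", "o1-mini", "o1", "o1-2024-12-17",
                    "o1-preview-2024-09-12", "o3-mini-2025-01-31", "o1-mini-2024-09-12")

def adapt_params(model_name: str, params: dict) -> dict:
    params_copy = dict(params)
    if model_name.lower() in REASONING_MODELS:
        params_copy.pop("temperature", None)
        if "max_tokens" in params_copy:
            params_copy["max_completion_tokens"] = params_copy.pop("max_tokens")
    return params_copy

print(adapt_params("o3-mini", {"temperature": 0.7, "max_tokens": 1000, "stream": True}))
# -> {'stream': True, 'max_completion_tokens': 1000}
print(adapt_params("gpt-4o", {"temperature": 0.7, "max_tokens": 1000}))
# -> {'temperature': 0.7, 'max_tokens': 1000}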
@@ -208,15 +256,20 @@ class LLM_request:
                     }
                 ],
                 "max_tokens": global_config.max_response_length,
-                **self.params
+                **params_copy
             }
         else:
-            return {
+            payload = {
                 "model": self.model_name,
                 "messages": [{"role": "user", "content": prompt}],
                 "max_tokens": global_config.max_response_length,
-                **self.params
+                **params_copy
             }
+        # For the o3-mini model, also rename max_tokens in the base payload to max_completion_tokens
+        if self.model_name.lower() == "o3-mini" and "max_tokens" in payload:
+            payload["max_completion_tokens"] = payload.pop("max_tokens")
+        return payload
+

     def _default_response_handler(self, result: dict, user_id: str = "system",
                                   request_type: str = "chat", endpoint: str = "/chat/completions") -> Tuple:
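Taken together, the intended call path is roughly the sketch below. It assumes `generate_response_async` is a coroutine on `LLM_request` (the InitiativeMessageGenerate hunk calls it without `await`, so the exact calling convention is not settled by this diff), that the `stream=True` keyword argument lands in `self.params` and flips the SSE branch in `_execute_request`, and that the import lines are placeholders because the module paths are not shown here.

import asyncio

# from ... import LLM_request, global_config   # module paths not shown in this diff

async def demo():
    # stream=True presumably ends up in self.params and enables the streaming branch
    model = LLM_request(model=global_config.llm_reasoning,
                        temperature=0.7, max_tokens=1000, stream=True)
    content, reasoning = await model.generate_response_async("Introduce yourself in one sentence.")
    print(content)
    print(reasoning)

asyncio.run(demo())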