Fix code formatting and filename casing issues
@@ -23,23 +23,23 @@ logger = get_logger("AioHTTP-Gemini Client")
 def _format_to_mime_type(image_format: str) -> str:
     """
     Convert an image format name to the correct MIME type

     Args:
         image_format (str): image format (e.g. 'jpg', 'png')

     Returns:
         str: the corresponding MIME type
     """
     format_mapping = {
         "jpg": "image/jpeg",
         "jpeg": "image/jpeg",
         "png": "image/png",
         "webp": "image/webp",
         "gif": "image/gif",
         "heic": "image/heic",
-        "heif": "image/heif"
+        "heif": "image/heif",
     }

     return format_mapping.get(image_format.lower(), f"image/{image_format.lower()}")
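A quick sketch of how the mapping behaves (a standalone copy of the function above, not imported from the project): casing is normalized first, and unknown formats fall back to `image/<format>`.

```python
def _format_to_mime_type(image_format: str) -> str:
    format_mapping = {
        "jpg": "image/jpeg",
        "jpeg": "image/jpeg",
        "png": "image/png",
        "webp": "image/webp",
        "gif": "image/gif",
        "heic": "image/heic",
        "heif": "image/heif",
    }
    return format_mapping.get(image_format.lower(), f"image/{image_format.lower()}")

assert _format_to_mime_type("JPG") == "image/jpeg"  # casing is normalized first
assert _format_to_mime_type("bmp") == "image/bmp"   # unknown formats fall through to image/<format>
```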
@@ -49,7 +49,7 @@ def _convert_messages(messages: list[Message]) -> tuple[list[dict], list[str] | None]:
     :param messages: list of messages
     :return: (contents, system_instructions)
     """

     def _convert_message_item(message: Message) -> dict:
         """Convert a single message"""
         # Convert the role name
@@ -59,7 +59,7 @@ def _convert_messages(messages: list[Message]) -> tuple[list[dict], list[str] | None]:
             role = "user"
         else:
             raise ValueError(f"Unsupported message role: {message.role}")

         # Convert the content
         parts = []
         if isinstance(message.content, str):
@@ -67,25 +67,17 @@ def _convert_messages(messages: list[Message]) -> tuple[list[dict], list[str] | None]:
         elif isinstance(message.content, list):
             for item in message.content:
                 if isinstance(item, tuple):  # (format, base64_data)
-                    parts.append({
-                        "inline_data": {
-                            "mime_type": _format_to_mime_type(item[0]),
-                            "data": item[1]
-                        }
-                    })
+                    parts.append({"inline_data": {"mime_type": _format_to_mime_type(item[0]), "data": item[1]}})
                 elif isinstance(item, str):
                     parts.append({"text": item})
                 else:
                     raise RuntimeError("Unreachable code: please use the MessageBuilder class to build message objects")

-        return {
-            "role": role,
-            "parts": parts
-        }
+        return {"role": role, "parts": parts}

     contents = []
     system_instructions = []

     for message in messages:
         if message.role == RoleType.System:
             if isinstance(message.content, str):
@@ -96,13 +88,10 @@ def _convert_messages(messages: list[Message]) -> tuple[list[dict], list[str] | None]:
             # Handle tool-call results
             if not message.tool_call_id:
                 raise ValueError("Tool call message is missing tool_call_id")
-            contents.append({
-                "role": "function",
-                "parts": [{"text": str(message.content)}]
-            })
+            contents.append({"role": "function", "parts": [{"text": str(message.content)}]})
         else:
             contents.append(_convert_message_item(message))

     return contents, system_instructions if system_instructions else None
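For reference, a hypothetical example of the payload shape `_convert_messages` produces for one multimodal user message plus one tool result (field names follow the Gemini REST API as used above; values are illustrative):

```python
contents = [
    {
        "role": "user",
        "parts": [
            {"text": "Describe this image"},
            {"inline_data": {"mime_type": "image/png", "data": "<base64 data>"}},
        ],
    },
    {"role": "function", "parts": [{"text": "tool output as text"}]},
]
system_instructions = ["You are a helpful assistant."]  # collected separately from system messages
```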
@@ -110,7 +99,7 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
     """
     Convert tool options into the format required by the Gemini REST API
     """

     def _convert_tool_param(param: ToolParam) -> dict:
         """Convert a tool parameter"""
         result = {
@@ -120,40 +109,28 @@ def _convert_tool_options(tool_options: list[ToolOption]) -> list[dict]:
         if param.enum_values:
             result["enum"] = param.enum_values
         return result

     def _convert_tool_option_item(tool_option: ToolOption) -> dict:
         """Convert a single tool option"""
         function_declaration = {
             "name": tool_option.name,
             "description": tool_option.description,
         }

         if tool_option.params:
             function_declaration["parameters"] = {
                 "type": "object",
-                "properties": {
-                    param.name: _convert_tool_param(param)
-                    for param in tool_option.params
-                },
-                "required": [
-                    param.name
-                    for param in tool_option.params
-                    if param.required
-                ],
+                "properties": {param.name: _convert_tool_param(param) for param in tool_option.params},
+                "required": [param.name for param in tool_option.params if param.required],
             }

-        return {
-            "function_declarations": [function_declaration]
-        }
+        return {"function_declarations": [function_declaration]}

     return [_convert_tool_option_item(tool_option) for tool_option in tool_options]


 def _build_generation_config(
-    max_tokens: int,
-    temperature: float,
-    response_format: RespFormat | None = None,
-    extra_params: dict | None = None
+    max_tokens: int, temperature: float, response_format: RespFormat | None = None, extra_params: dict | None = None
 ) -> dict:
     """Build the generation config"""
     config = {
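A hypothetical example of what `_convert_tool_options` emits for a single tool with one required enum parameter (the tool name and parameter fields are illustrative; the nested `type`/`description` keys come from `_convert_tool_param`):

```python
tools = [
    {
        "function_declarations": [
            {
                "name": "get_weather",
                "description": "Look up the current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "city": {"type": "string", "description": "City name", "enum": ["Beijing", "Shanghai"]},
                    },
                    "required": ["city"],
                },
            }
        ]
    }
]
```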
@@ -162,7 +139,7 @@ def _build_generation_config(
         "topK": 1,
         "topP": 1,
     }

     # Handle the response format
     if response_format:
         if response_format.format_type == RespFormatType.JSON_OBJ:
@@ -170,95 +147,89 @@ def _build_generation_config(
         elif response_format.format_type == RespFormatType.JSON_SCHEMA:
             config["responseMimeType"] = "application/json"
             config["responseSchema"] = response_format.to_dict()

     # Merge extra parameters
     if extra_params:
         config.update(extra_params)

     return config


 class AiohttpGeminiStreamParser:
     """Streaming response parser"""

     def __init__(self):
         self.content_buffer = io.StringIO()
         self.reasoning_buffer = io.StringIO()
         self.tool_calls_buffer = []
         self.usage_record = None

     def parse_chunk(self, chunk_text: str):
         """Parse a single streaming chunk"""
         try:
             if not chunk_text.strip():
                 return

             # Strip the "data: " prefix
             if chunk_text.startswith("data: "):
                 chunk_text = chunk_text[6:].strip()

             if chunk_text == "[DONE]":
                 return

             chunk_data = orjson.loads(chunk_text)

             # Parse candidates
             if "candidates" in chunk_data and chunk_data["candidates"]:
                 candidate = chunk_data["candidates"][0]

                 # Parse content
                 if "content" in candidate and "parts" in candidate["content"]:
                     for part in candidate["content"]["parts"]:
                         if "text" in part:
                             self.content_buffer.write(part["text"])

                 # Parse tool calls
                 if "functionCall" in candidate:
                     func_call = candidate["functionCall"]
                     call_id = f"gemini_call_{len(self.tool_calls_buffer)}"
-                    self.tool_calls_buffer.append({
-                        "id": call_id,
-                        "name": func_call.get("name", ""),
-                        "args": func_call.get("args", {})
-                    })
+                    self.tool_calls_buffer.append(
+                        {"id": call_id, "name": func_call.get("name", ""), "args": func_call.get("args", {})}
+                    )

             # Parse usage statistics
             if "usageMetadata" in chunk_data:
                 usage = chunk_data["usageMetadata"]
                 self.usage_record = (
                     usage.get("promptTokenCount", 0),
                     usage.get("candidatesTokenCount", 0),
-                    usage.get("totalTokenCount", 0)
+                    usage.get("totalTokenCount", 0),
                 )

         except orjson.JSONDecodeError as e:
             logger.warning(f"Failed to parse streaming chunk: {e}, data: {chunk_text}")
         except Exception as e:
             logger.error(f"Error while handling streaming chunk: {e}")

     def get_response(self) -> APIResponse:
         """Get the final response"""
         response = APIResponse()

         if self.content_buffer.tell() > 0:
             response.content = self.content_buffer.getvalue()

         if self.reasoning_buffer.tell() > 0:
             response.reasoning_content = self.reasoning_buffer.getvalue()

         if self.tool_calls_buffer:
             response.tool_calls = []
             for call_data in self.tool_calls_buffer:
-                response.tool_calls.append(ToolCall(
-                    call_data["id"],
-                    call_data["name"],
-                    call_data["args"]
-                ))
+                response.tool_calls.append(ToolCall(call_data["id"], call_data["name"], call_data["args"]))

         # Clean up the buffers
         self.content_buffer.close()
         self.reasoning_buffer.close()

         return response
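The parser consumes Server-Sent-Events lines of the form below; a small sketch of the `data:` handling (the standard library `json` stands in for `orjson` here):

```python
import json

chunk = 'data: {"candidates": [{"content": {"parts": [{"text": "Hel"}]}}]}'

payload = chunk[6:].strip()  # strip the "data: " prefix, as parse_chunk does
data = json.loads(payload)
text = data["candidates"][0]["content"]["parts"][0]["text"]
assert text == "Hel"  # parse_chunk would write this into content_buffer
```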
@@ -268,19 +239,19 @@ async def _default_stream_response_handler(
 ) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """Default streaming response handler"""
     parser = AiohttpGeminiStreamParser()

     try:
         async for line in response.content:
             if interrupt_flag and interrupt_flag.is_set():
                 raise ReqAbortException("Request interrupted by external signal")

-            line_text = line.decode('utf-8').strip()
+            line_text = line.decode("utf-8").strip()
             if line_text:
                 parser.parse_chunk(line_text)

         api_response = parser.get_response()
         return api_response, parser.usage_record

     except Exception as e:
         if not isinstance(e, ReqAbortException):
             raise RespParseException(None, f"Failed to parse streaming response: {e}") from e
@@ -292,31 +263,29 @@ def _default_normal_response_parser(
 ) -> tuple[APIResponse, Optional[tuple[int, int, int]]]:
     """Default non-streaming response parser"""
     api_response = APIResponse()

     try:
         # Parse candidates
         if "candidates" in response_data and response_data["candidates"]:
             candidate = response_data["candidates"][0]

             # Parse text content
             if "content" in candidate and "parts" in candidate["content"]:
                 content_parts = []
                 for part in candidate["content"]["parts"]:
                     if "text" in part:
                         content_parts.append(part["text"])

                 if content_parts:
                     api_response.content = "".join(content_parts)

             # Parse tool calls
             if "functionCall" in candidate:
                 func_call = candidate["functionCall"]
-                api_response.tool_calls = [ToolCall(
-                    "gemini_call_0",
-                    func_call.get("name", ""),
-                    func_call.get("args", {})
-                )]
+                api_response.tool_calls = [
+                    ToolCall("gemini_call_0", func_call.get("name", ""), func_call.get("args", {}))
+                ]

         # Parse usage statistics
         usage_record = None
         if "usageMetadata" in response_data:
@@ -324,12 +293,12 @@ def _default_normal_response_parser(
             usage_record = (
                 usage.get("promptTokenCount", 0),
                 usage.get("candidatesTokenCount", 0),
-                usage.get("totalTokenCount", 0)
+                usage.get("totalTokenCount", 0),
             )

         api_response.raw_data = response_data
         return api_response, usage_record

     except Exception as e:
         raise RespParseException(response_data, f"Failed to parse response: {e}") from e
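A minimal sketch of a non-streaming `generateContent` body that this parser understands (field names per the code above; values illustrative):

```python
response_data = {
    "candidates": [{"content": {"parts": [{"text": "Hello"}, {"text": " world"}]}}],
    "usageMetadata": {"promptTokenCount": 10, "candidatesTokenCount": 2, "totalTokenCount": 12},
}

content = "".join(p["text"] for p in response_data["candidates"][0]["content"]["parts"])
usage = response_data["usageMetadata"]
usage_record = (
    usage.get("promptTokenCount", 0),
    usage.get("candidatesTokenCount", 0),
    usage.get("totalTokenCount", 0),
)
assert content == "Hello world" and usage_record == (10, 2, 12)
```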
@@ -337,26 +306,21 @@ def _default_normal_response_parser(
 @client_registry.register_client_class("aiohttp_gemini")
 class AiohttpGeminiClient(BaseClient):
     """Gemini client based on aiohttp"""

     def __init__(self, api_provider: APIProvider):
         super().__init__(api_provider)
         self.base_url = "https://generativelanguage.googleapis.com/v1beta"
-        self.session: aiohttp.ClientSession | None = None
         self.api_key = api_provider.api_key

         # If a custom base_url is provided, use it
         if api_provider.base_url:
-            self.base_url = api_provider.base_url.rstrip('/')
+            self.base_url = api_provider.base_url.rstrip("/")

     # The global session was removed; every request now uses "with aiohttp.ClientSession() as session:"

     async def _make_request(
-        self,
-        method: str,
-        endpoint: str,
-        data: dict | None = None,
-        stream: bool = False
+        self, method: str, endpoint: str, data: dict | None = None, stream: bool = False
     ) -> aiohttp.ClientResponse:
         """Issue an HTTP request (a fresh aiohttp.ClientSession per request)"""
         url = f"{self.base_url}/{endpoint}?key={self.api_key}"
@@ -364,16 +328,11 @@ class AiohttpGeminiClient(BaseClient):
         try:
             async with aiohttp.ClientSession(
                 timeout=timeout,
-                headers={
-                    "Content-Type": "application/json",
-                    "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"
-                }
+                headers={"Content-Type": "application/json", "User-Agent": "MMC-AioHTTP-Gemini-Client/1.0"},
             ) as session:
                 if method.upper() == "POST":
                     response = await session.post(
-                        url,
-                        json=data,
-                        headers={"Accept": "text/event-stream" if stream else "application/json"}
+                        url, json=data, headers={"Accept": "text/event-stream" if stream else "application/json"}
                     )
                 else:
                     response = await session.get(url)
@@ -386,7 +345,7 @@ class AiohttpGeminiClient(BaseClient):
             return response
         except aiohttp.ClientError as e:
             raise NetworkConnectionError() from e

     async def get_response(
         self,
         model_info: ModelInfo,
@@ -401,9 +360,7 @@ class AiohttpGeminiClient(BaseClient):
                 Coroutine[Any, Any, tuple[APIResponse, Optional[tuple[int, int, int]]]],
             ]
         ] = None,
-        async_response_parser: Optional[
-            Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]
-        ] = None,
+        async_response_parser: Optional[Callable[[dict], tuple[APIResponse, Optional[tuple[int, int, int]]]]] = None,
         interrupt_flag: asyncio.Event | None = None,
         extra_params: dict[str, Any] | None = None,
     ) -> APIResponse:
@@ -412,65 +369,57 @@ class AiohttpGeminiClient(BaseClient):
         """
         if stream_response_handler is None:
             stream_response_handler = _default_stream_response_handler

         if async_response_parser is None:
             async_response_parser = _default_normal_response_parser

         # Convert the message format
         contents, system_instructions = _convert_messages(message_list)

         # Build the request body
         request_data = {
             "contents": contents,
-            "generationConfig": _build_generation_config(
-                max_tokens, temperature, response_format, extra_params
-            )
+            "generationConfig": _build_generation_config(max_tokens, temperature, response_format, extra_params),
         }

         # Add system instructions
         if system_instructions:
-            request_data["systemInstruction"] = {
-                "parts": [{"text": instr} for instr in system_instructions]
-            }
+            request_data["systemInstruction"] = {"parts": [{"text": instr} for instr in system_instructions]}

         # Add tool definitions
         if tool_options:
             request_data["tools"] = _convert_tool_options(tool_options)

         try:
             if model_info.force_stream_mode:
                 # Streaming request
                 endpoint = f"models/{model_info.model_identifier}:streamGenerateContent"
-                req_task = asyncio.create_task(
-                    self._make_request("POST", endpoint, request_data, stream=True)
-                )
+                req_task = asyncio.create_task(self._make_request("POST", endpoint, request_data, stream=True))

                 while not req_task.done():
                     if interrupt_flag and interrupt_flag.is_set():
                         req_task.cancel()
                         raise ReqAbortException("Request interrupted by external signal")
                     await asyncio.sleep(0.1)

                 response = req_task.result()
                 api_response, usage_record = await stream_response_handler(response, interrupt_flag)

             else:
                 # Non-streaming request
                 endpoint = f"models/{model_info.model_identifier}:generateContent"
-                req_task = asyncio.create_task(
-                    self._make_request("POST", endpoint, request_data)
-                )
+                req_task = asyncio.create_task(self._make_request("POST", endpoint, request_data))

                 while not req_task.done():
                     if interrupt_flag and interrupt_flag.is_set():
                         req_task.cancel()
                         raise ReqAbortException("Request interrupted by external signal")
                     await asyncio.sleep(0.1)

                 response = req_task.result()
                 response_data = await response.json()
                 api_response, usage_record = async_response_parser(response_data)

         except (ReqAbortException, NetworkConnectionError, RespNotOkException, RespParseException):
             # Re-raise project-defined exceptions directly
             raise
@@ -478,7 +427,7 @@ class AiohttpGeminiClient(BaseClient):
             logger.debug(e)
             # Convert other exceptions into a network connection error
             raise NetworkConnectionError() from e

         # Record usage statistics
         if usage_record:
             api_response.usage = UsageRecord(
@@ -488,9 +437,9 @@ class AiohttpGeminiClient(BaseClient):
                 completion_tokens=usage_record[1],
                 total_tokens=usage_record[2],
             )

         return api_response

     async def get_embedding(
         self,
         model_info: ModelInfo,
@@ -501,7 +450,7 @@ class AiohttpGeminiClient(BaseClient):
         Get text embeddings - this client does not support embeddings
         """
         raise NotImplementedError("The AioHTTP Gemini client does not support text embeddings")

     async def get_audio_transcriptions(
         self,
         model_info: ModelInfo,
@@ -512,31 +461,30 @@ class AiohttpGeminiClient(BaseClient):
         Get audio transcriptions
         """
         # Build the content containing the audio
-        contents = [{
-            "role": "user",
-            "parts": [
-                {"text": "Generate a transcript of the speech. The language of the transcript should match the language of the speech."},
-                {
-                    "inline_data": {
-                        "mime_type": "audio/wav",
-                        "data": audio_base64
-                    }
-                }
-            ]
-        }]
+        contents = [
+            {
+                "role": "user",
+                "parts": [
+                    {
+                        "text": "Generate a transcript of the speech. The language of the transcript should match the language of the speech."
+                    },
+                    {"inline_data": {"mime_type": "audio/wav", "data": audio_base64}},
+                ],
+            }
+        ]

         request_data = {
             "contents": contents,
-            "generationConfig": _build_generation_config(2048, 0.1, None, extra_params)
+            "generationConfig": _build_generation_config(2048, 0.1, None, extra_params),
         }

         try:
             endpoint = f"models/{model_info.model_identifier}:generateContent"
             response = await self._make_request("POST", endpoint, request_data)
             response_data = await response.json()

             api_response, usage_record = _default_normal_response_parser(response_data)

             if usage_record:
                 api_response.usage = UsageRecord(
                     model_name=model_info.name,
@@ -545,18 +493,18 @@ class AiohttpGeminiClient(BaseClient):
                     completion_tokens=usage_record[1],
                     total_tokens=usage_record[2],
                 )

             return api_response

         except (NetworkConnectionError, RespNotOkException, RespParseException):
             raise
         except Exception as e:
             raise NetworkConnectionError() from e

     def get_support_image_formats(self) -> list[str]:
         """
         Get the supported image formats
         """
         return ["png", "jpg", "jpeg", "webp", "heic", "heif"]

     # Removed __aenter__, __aexit__ and __del__; a global session is no longer held
@@ -472,7 +472,7 @@ class OpenaiClient(BaseClient):
                     req_task.cancel()
                     raise ReqAbortException("Request interrupted by external signal")
                 await asyncio.sleep(0.1)  # wait 0.1s, then check the task & interrupt flag again

             # logger.info(f"OpenAI request time: {model_info.model_identifier} {time.time() - start_time} \n{messages}")

             resp, usage_record = async_response_parser(req_task.result())
@@ -516,7 +516,7 @@ class OpenaiClient(BaseClient):
             # Add detailed error information for debugging
             logger.error(f"OpenAI API connection error (embedding model): {str(e)}")
             logger.error(f"Error type: {type(e)}")
-            if hasattr(e, '__cause__') and e.__cause__:
+            if hasattr(e, "__cause__") and e.__cause__:
                 logger.error(f"Underlying error: {str(e.__cause__)}")
             raise NetworkConnectionError() from e
         except APIStatusError as e:
@@ -1,3 +1,3 @@
 from .tool_option import ToolCall

-__all__ = ["ToolCall"]
\ No newline at end of file
+__all__ = ["ToolCall"]
@@ -48,8 +48,7 @@ def _json_schema_type_check(instance) -> str | None:
     elif not isinstance(instance["name"], str) or instance["name"].strip() == "":
         return "The schema 'name' field must be a non-empty string"
     if "description" in instance and (
-        not isinstance(instance["description"], str)
-        or instance["description"].strip() == ""
+        not isinstance(instance["description"], str) or instance["description"].strip() == ""
     ):
         return "The schema 'description' field must be a non-empty string"
     if "schema" not in instance:
@@ -101,9 +100,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
         # If the current schema is a list, iterate over each element
         for i in range(len(sub_schema)):
             if isinstance(sub_schema[i], dict):
-                sub_schema[i] = link_definitions_recursive(
-                    f"{path}/{str(i)}", sub_schema[i], defs
-                )
+                sub_schema[i] = link_definitions_recursive(f"{path}/{str(i)}", sub_schema[i], defs)
         else:
             # Otherwise it is a dict
             if "$defs" in sub_schema:
@@ -125,9 +122,7 @@ def _link_definitions(schema: dict[str, Any]) -> dict[str, Any]:
         for key, value in sub_schema.items():
             if isinstance(value, (dict, list)):
                 # If the current value is a dict or list, recurse
-                sub_schema[key] = link_definitions_recursive(
-                    f"{path}/{key}", value, defs
-                )
+                sub_schema[key] = link_definitions_recursive(f"{path}/{key}", value, defs)

         return sub_schema
@@ -163,9 +158,7 @@ class RespFormat:
     def _generate_schema_from_model(schema):
         json_schema = {
             "name": schema.__name__,
-            "schema": _remove_defs(
-                _link_definitions(_remove_title(schema.model_json_schema()))
-            ),
+            "schema": _remove_defs(_link_definitions(_remove_title(schema.model_json_schema()))),
             "strict": False,
         }
         if schema.__doc__:
@@ -145,37 +145,42 @@ class LLMUsageRecorder:
     LLM usage recorder (SQLAlchemy version)
     """

     def record_usage_to_database(
-        self, model_info: ModelInfo, model_usage: UsageRecord, user_id: str, request_type: str, endpoint: str, time_cost: float = 0.0
+        self,
+        model_info: ModelInfo,
+        model_usage: UsageRecord,
+        user_id: str,
+        request_type: str,
+        endpoint: str,
+        time_cost: float = 0.0,
     ):
         input_cost = (model_usage.prompt_tokens / 1000000) * model_info.price_in
         output_cost = (model_usage.completion_tokens / 1000000) * model_info.price_out
         total_cost = round(input_cost + output_cost, 6)

         session = None
         try:
             # Create the record within a SQLAlchemy session
             with get_db_session() as session:
                 usage_record = LLMUsage(
                     model_name=model_info.model_identifier,
                     model_assign_name=model_info.name,
                     model_api_provider=model_info.api_provider,
                     user_id=user_id,
                     request_type=request_type,
                     endpoint=endpoint,
                     prompt_tokens=model_usage.prompt_tokens or 0,
                     completion_tokens=model_usage.completion_tokens or 0,
                     total_tokens=model_usage.total_tokens or 0,
                     cost=total_cost or 0.0,
-                    time_cost = round(time_cost or 0.0, 3),
+                    time_cost=round(time_cost or 0.0, 3),
                     status="success",
                     timestamp=datetime.now(),  # SQLAlchemy handles the DateTime field
                 )

                 session.add(usage_record)
                 session.commit()

             logger.debug(
                 f"Token usage - model: {model_usage.model_name}, "
                 f"user: {user_id}, type: {request_type}, "
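The cost arithmetic above, worked through with hypothetical per-million-token prices:

```python
prompt_tokens, completion_tokens = 1500, 500
price_in, price_out = 0.8, 2.4  # hypothetical prices per 1M tokens

input_cost = (prompt_tokens / 1000000) * price_in        # 0.0012
output_cost = (completion_tokens / 1000000) * price_out  # 0.0012
total_cost = round(input_cost + output_cost, 6)
assert total_cost == 0.0024
```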
@@ -186,4 +191,4 @@ class LLMUsageRecorder:
             logger.error(f"Failed to record token usage: {str(e)}")


-llm_usage_recorder = LLMUsageRecorder()
\ No newline at end of file
+llm_usage_recorder = LLMUsageRecorder()
@@ -37,16 +37,16 @@ error_code_mapping = {
 def _normalize_image_format(image_format: str) -> str:
     """
     Normalize image format names to ensure compatibility across APIs

     Args:
         image_format (str): original image format

     Returns:
         str: the normalized image format
     """
     format_mapping = {
         "jpg": "jpeg",
         "JPG": "jpeg",
         "JPEG": "jpeg",
         "jpeg": "jpeg",
         "png": "png",
@@ -58,9 +58,9 @@ def _normalize_image_format(image_format: str) -> str:
         "heic": "heic",
         "HEIC": "heic",
         "heif": "heif",
-        "HEIF": "heif"
+        "HEIF": "heif",
     }

     normalized = format_mapping.get(image_format, image_format.lower())
     logger.debug(f"Image format normalized: {image_format} -> {normalized}")
     return normalized
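A quick check of the normalization behavior (standalone excerpt of the mapping above, not imported from the project):

```python
format_mapping = {"jpg": "jpeg", "JPG": "jpeg", "JPEG": "jpeg", "HEIF": "heif"}

def normalize(image_format: str) -> str:
    # exact-match lookup first, otherwise fall back to lowercasing
    return format_mapping.get(image_format, image_format.lower())

assert normalize("JPG") == "jpeg"
assert normalize("BMP") == "bmp"  # unknown formats are simply lowercased
```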
@@ -109,8 +109,8 @@ async def execute_concurrently(
     # If all requests failed, log every exception and raise the first one
     for i, res in enumerate(results):
         if isinstance(res, Exception):
-            logger.error(f"Concurrent task {i+1}/{concurrency_count} failed: {res}")
+            logger.error(f"Concurrent task {i + 1}/{concurrency_count} failed: {res}")

     first_exception = next((res for res in results if isinstance(res, Exception)), None)
     if first_exception:
         raise first_exception
@@ -129,7 +129,7 @@ class LLMRequest:
             model: (0, 0, 0) for model in self.model_for_task.model_list
         }
         """Per-model usage record used for load balancing, stored as (total_tokens, penalty, usage_penalty); the penalty values allow adjustment when a model is underperforming or currently in use"""

         # Content-obfuscation filtering instruction
         self.noise_instruction = """**[Core filtering rule]**
 During the interaction, you will see meaningless random strings of letters and digits inside the user's input. These strings are technical camouflage added during API transmission and are [not] part of the user's real intent.
@@ -137,7 +137,7 @@ class LLMRequest:
 Your task is to [completely and thoroughly ignore] these random strings.
 It is [strictly forbidden] to write these strings into your reply, interpret them in any way, or weave them into the story. When you see them, treat them as blank or simply skip them, and continue your narration based on the user's other meaningful input (if any) or the current plot. This rule has the highest priority.
 """

         # Anti-truncation instruction
         self.end_marker = "###MAI_RESPONSE_END###"
         self.anti_truncation_instruction = f"""
@@ -169,7 +169,7 @@ class LLMRequest:
         """
         # Normalize the image format to ensure API compatibility
         normalized_format = _normalize_image_format(image_format)

         # Model selection
         start_time = time.time()
         model_info, api_provider, client = self._select_model()
@@ -178,7 +178,9 @@ class LLMRequest:
         message_builder = MessageBuilder()
         message_builder.add_text_content(prompt)
         message_builder.add_image_content(
-            image_base64=image_base64, image_format=normalized_format, support_formats=client.get_support_image_formats()
+            image_base64=image_base64,
+            image_format=normalized_format,
+            support_formats=client.get_support_image_formats(),
         )
         messages = [message_builder.build()]
@@ -296,7 +298,7 @@ class LLMRequest:
         for model_info, api_provider, client in model_scheduler:
             start_time = time.time()
             model_name = model_info.name
-            logger.debug(f"Trying model: {model_name}") # no log spam, please
+            logger.debug(f"Trying model: {model_name}")  # no log spam, please

             try:
                 # Check whether anti-truncation is enabled
@@ -306,7 +308,7 @@ class LLMRequest:
                 if use_anti_truncation:
                     processed_prompt += self.anti_truncation_instruction
                     logger.info(f"Model '{model_name}' (task: '{self.task_name}') has anti-truncation enabled.")

                 processed_prompt = self._apply_content_obfuscation(processed_prompt, api_provider)

                 message_builder = MessageBuilder()
@@ -351,7 +353,9 @@ class LLMRequest:
                     empty_retry_count += 1
                     if empty_retry_count <= max_empty_retry:
                         reason = "empty reply" if is_empty_reply else "truncation"
-                        logger.warning(f"Model '{model_name}' detected {reason}; regenerating, attempt {empty_retry_count}/{max_empty_retry}...")
+                        logger.warning(
+                            f"Model '{model_name}' detected {reason}; regenerating, attempt {empty_retry_count}/{max_empty_retry}..."
+                        )
                         if empty_retry_interval > 0:
                             await asyncio.sleep(empty_retry_interval)
                         continue  # keep retrying with the current model
@@ -364,16 +368,20 @@ class LLMRequest:
                 # Response received successfully
                 if usage := response.usage:
                     llm_usage_recorder.record_usage_to_database(
-                        model_info=model_info, model_usage=usage, time_cost=time.time() - start_time,
-                        user_id="system", request_type=self.request_type, endpoint="/chat/completions",
+                        model_info=model_info,
+                        model_usage=usage,
+                        time_cost=time.time() - start_time,
+                        user_id="system",
+                        request_type=self.request_type,
+                        endpoint="/chat/completions",
                     )

                 if not content and not tool_calls:
                     if raise_when_empty:
                         raise RuntimeError("Generated an empty reply")
                     content = "The generated response is empty"

-                logger.debug(f"Model '{model_name}' generated a reply successfully.") # no log spam here either
+                logger.debug(f"Model '{model_name}' generated a reply successfully.")  # no log spam here either
                 return content, (reasoning_content, model_name, tool_calls)

             except RespNotOkException as e:
@@ -381,7 +389,7 @@ class LLMRequest:
                     logger.error(f"Model '{model_name}' hit an authentication/permission error (Code: {e.status_code}); trying the next model.")
                     failed_models.add(model_name)
                     last_exception = e
-                    continue # switch to the next model
+                    continue  # switch to the next model
                 else:
                     logger.error(f"Model '{model_name}' request failed, HTTP status code: {e.status_code}")
                     if raise_when_empty:
@@ -394,13 +402,13 @@ class LLMRequest:
                     logger.error(f"Model '{model_name}' still failed after all retries: {e}; trying the next model.")
                     failed_models.add(model_name)
                     last_exception = e
-                    continue # switch to the next model
+                    continue  # switch to the next model

             except Exception as e:
                 logger.error(f"Unknown exception while using model '{model_name}': {e}")
                 failed_models.add(model_name)
                 last_exception = e
-                continue # switch to the next model
+                continue  # switch to the next model

         # All models have been tried and failed
         logger.error("All available models have been tried and failed.")
@@ -408,7 +416,7 @@ class LLMRequest:
         if last_exception:
             raise RuntimeError("All model requests failed") from last_exception
         raise RuntimeError("All model requests failed, with no specific exception information")

         return "All model requests failed", ("", "unknown", None)

     async def get_embedding(self, embedding_input: str) -> Tuple[List[float], str]:
@@ -455,12 +463,12 @@ class LLMRequest:
         for model_name in self.model_for_task.model_list:
             if model_name in failed_models:
                 continue

             model_info = model_config.get_model_info(model_name)
             api_provider = model_config.get_provider(model_info.api_provider)
-            force_new_client = (self.request_type == "embedding")
+            force_new_client = self.request_type == "embedding"
             client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)

             yield model_info, api_provider, client

     def _select_model(self) -> Tuple[ModelInfo, APIProvider, BaseClient]:
@@ -475,7 +483,7 @@ class LLMRequest:
         api_provider = model_config.get_provider(model_info.api_provider)

         # For embedding tasks, force a new client instance to avoid event-loop issues
-        force_new_client = (self.request_type == "embedding")
+        force_new_client = self.request_type == "embedding"
         client = client_registry.get_client_class_instance(api_provider, force_new=force_new_client)
         logger.debug(f"Selected model for request: {model_info.name}")
         total_tokens, penalty, usage_penalty = self.model_usage[model_info.name]
@@ -690,9 +698,11 @@ class LLMRequest:
         for i, m_name in enumerate(self.model_for_task.model_list):
             if m_name == old_model_name:
                 self.model_for_task.model_list[i] = new_model_name
-                logger.warning(f"Temporarily downgrading {old_model_name} to {new_model_name} in the model list for task {self.task_name}")
+                logger.warning(
+                    f"Temporarily downgrading {old_model_name} to {new_model_name} in the model list for task {self.task_name}"
+                )
                 break
-        return 0, None # retry immediately
+        return 0, None  # retry immediately
         # Client-side error
         logger.warning(
             f"Task-'{task_name}' model-'{model_name}': request failed, error code-{e.status_code}, error message-{e.message}"
@@ -782,55 +792,55 @@ class LLMRequest:

     def _apply_content_obfuscation(self, text: str, api_provider) -> str:
         """Apply obfuscation to the text according to the API provider's configuration"""
-        if not hasattr(api_provider, 'enable_content_obfuscation') or not api_provider.enable_content_obfuscation:
+        if not hasattr(api_provider, "enable_content_obfuscation") or not api_provider.enable_content_obfuscation:
             logger.debug(f"API provider '{api_provider.name}' has content obfuscation disabled")
             return text

-        intensity = getattr(api_provider, 'obfuscation_intensity', 1)
+        intensity = getattr(api_provider, "obfuscation_intensity", 1)
         logger.info(f"Enabling content obfuscation for API provider '{api_provider.name}', intensity level: {intensity}")

         # Prepend the filtering-rule instruction
         processed_text = self.noise_instruction + "\n\n" + text
         logger.debug(f"Filtering-rule instruction added, text length: {len(text)} -> {len(processed_text)}")

         # Inject random noise
         final_text = self._inject_random_noise(processed_text, intensity)
         logger.debug(f"Noise injection finished, final text length: {len(final_text)}")

         return final_text

     def _inject_random_noise(self, text: str, intensity: int) -> str:
         """Inject random noise strings into the text"""
         import random
         import string

         def generate_noise(length: int) -> str:
             """Generate a random noise string of the given length"""
             chars = (
-                string.ascii_letters +            # a-z, A-Z
-                string.digits +                   # 0-9
-                '!@#$%^&*()_+-=[]{}|;:,.<>?' +    # special symbols
-                '一二三四五六七八九零壹贰叁' +     # Chinese characters
-                'αβγδεζηθικλμνξοπρστυφχψω' +      # Greek letters
-                '∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵'                 # math symbols
+                string.ascii_letters  # a-z, A-Z
+                + string.digits  # 0-9
+                + "!@#$%^&*()_+-=[]{}|;:,.<>?"  # special symbols
+                + "一二三四五六七八九零壹贰叁"  # Chinese characters
+                + "αβγδεζηθικλμνξοπρστυφχψω"  # Greek letters
+                + "∀∃∈∉∪∩⊂⊃∧∨¬→↔∴∵"  # math symbols
             )
-            return ''.join(random.choice(chars) for _ in range(length))
+            return "".join(random.choice(chars) for _ in range(length))

         # Intensity parameter mapping
         params = {
-            1: {"probability": 15, "length": (3, 6)},   # low: 15% chance, 3-6 chars
-            2: {"probability": 25, "length": (5, 10)},  # medium: 25% chance, 5-10 chars
-            3: {"probability": 35, "length": (8, 15)}   # high: 35% chance, 8-15 chars
+            1: {"probability": 15, "length": (3, 6)},  # low: 15% chance, 3-6 chars
+            2: {"probability": 25, "length": (5, 10)},  # medium: 25% chance, 5-10 chars
+            3: {"probability": 35, "length": (8, 15)},  # high: 35% chance, 8-15 chars
         }

         config = params.get(intensity, params[1])
         logger.debug(f"Noise injection params: probability={config['probability']}%, length range={config['length']}")

         # Process word by word
         words = text.split()
         result = []
         noise_count = 0

         for word in words:
             result.append(word)
             # Insert noise with the configured probability
@@ -839,6 +849,6 @@ class LLMRequest:
                 noise = generate_noise(noise_length)
                 result.append(noise)
                 noise_count += 1

         logger.debug(f"Injected {noise_count} noise fragments, original word count: {len(words)}")
-        return ' '.join(result)
+        return " ".join(result)
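A self-contained sketch of the injection loop with a fixed seed; the per-word probability check is not visible in these hunks, so `random.randint(1, 100) <= probability` is an assumption, and the alphabet is shortened:

```python
import random
import string

random.seed(0)  # deterministic for the example

def generate_noise(length: int) -> str:
    chars = string.ascii_letters + string.digits  # shortened alphabet
    return "".join(random.choice(chars) for _ in range(length))

config = {"probability": 25, "length": (5, 10)}  # intensity level 2
result = []
for word in "hello world this is a test".split():
    result.append(word)
    if random.randint(1, 100) <= config["probability"]:  # assumed check
        result.append(generate_noise(random.randint(*config["length"])))
print(" ".join(result))  # the words with occasional noise fragments interleaved
```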