refactor(models):统一请求处理并优化响应处理 (refactor/unified_request)
对 `utils_model.py` 中的请求处理逻辑进行重构,创建统一的请求执行方法 `_execute_request`。该方法集中处理请求构建、重试逻辑和响应处理,替代了 `generate_response`、`generate_response_for_image` 和 `generate_response_async` 中的冗余代码。 关键变更: - 引入 `_execute_request` 作为 API 请求的单一入口 - 新增支持自定义重试策略和响应处理器 - 通过 `_build_payload` 简化图像和文本载荷构建 - 改进错误处理和日志记录 - 移除已弃用的同步方法 - 加入了`max_response_length`以兼容koboldcpp硬编码的默认值500 此次重构在保持现有功能的同时提高了代码可维护性,减少了重复代码
This commit is contained in:
@@ -28,6 +28,7 @@ enable_pic_translate = false
|
|||||||
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
|
||||||
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
|
||||||
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
|
||||||
|
max_response_length = 1024 # 麦麦回答的最大token数
|
||||||
|
|
||||||
[memory]
|
[memory]
|
||||||
build_memory_interval = 300 # 记忆构建间隔 单位秒
|
build_memory_interval = 300 # 记忆构建间隔 单位秒
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ class BotConfig:
|
|||||||
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
|
||||||
|
|
||||||
ban_words = set()
|
ban_words = set()
|
||||||
|
|
||||||
|
max_response_length: int = 1024 # 最大回复长度
|
||||||
|
|
||||||
# 模型配置
|
# 模型配置
|
||||||
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
|
||||||
@@ -113,6 +115,7 @@ class BotConfig:
|
|||||||
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
|
||||||
config.API_USING = response_config.get("api_using", config.API_USING)
|
config.API_USING = response_config.get("api_using", config.API_USING)
|
||||||
config.API_PAID = response_config.get("api_paid", config.API_PAID)
|
config.API_PAID = response_config.get("api_paid", config.API_PAID)
|
||||||
|
config.max_response_length = response_config.get("max_response_length", config.max_response_length)
|
||||||
|
|
||||||
# 加载模型配置
|
# 加载模型配置
|
||||||
if "model" in toml_dict:
|
if "model" in toml_dict:
|
||||||
|
|||||||
@@ -64,15 +64,15 @@ class CQCode:
|
|||||||
"""初始化LLM实例"""
|
"""初始化LLM实例"""
|
||||||
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
|
||||||
|
|
||||||
def translate(self):
|
async def translate(self):
|
||||||
"""根据CQ码类型进行相应的翻译处理"""
|
"""根据CQ码类型进行相应的翻译处理"""
|
||||||
if self.type == 'text':
|
if self.type == 'text':
|
||||||
self.translated_plain_text = self.params.get('text', '')
|
self.translated_plain_text = self.params.get('text', '')
|
||||||
elif self.type == 'image':
|
elif self.type == 'image':
|
||||||
if self.params.get('sub_type') == '0':
|
if self.params.get('sub_type') == '0':
|
||||||
self.translated_plain_text = self.translate_image()
|
self.translated_plain_text = await self.translate_image()
|
||||||
else:
|
else:
|
||||||
self.translated_plain_text = self.translate_emoji()
|
self.translated_plain_text = await self.translate_emoji()
|
||||||
elif self.type == 'at':
|
elif self.type == 'at':
|
||||||
user_nickname = get_user_nickname(self.params.get('qq', ''))
|
user_nickname = get_user_nickname(self.params.get('qq', ''))
|
||||||
if user_nickname:
|
if user_nickname:
|
||||||
@@ -158,7 +158,7 @@ class CQCode:
|
|||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def translate_emoji(self) -> str:
|
async def translate_emoji(self) -> str:
|
||||||
"""处理表情包类型的CQ码"""
|
"""处理表情包类型的CQ码"""
|
||||||
if 'url' not in self.params:
|
if 'url' not in self.params:
|
||||||
return '[表情包]'
|
return '[表情包]'
|
||||||
@@ -167,12 +167,12 @@ class CQCode:
|
|||||||
# 将 base64 字符串转换为字节类型
|
# 将 base64 字符串转换为字节类型
|
||||||
image_bytes = base64.b64decode(base64_str)
|
image_bytes = base64.b64decode(base64_str)
|
||||||
storage_emoji(image_bytes)
|
storage_emoji(image_bytes)
|
||||||
return self.get_emoji_description(base64_str)
|
return await self.get_emoji_description(base64_str)
|
||||||
else:
|
else:
|
||||||
return '[表情包]'
|
return '[表情包]'
|
||||||
|
|
||||||
|
|
||||||
def translate_image(self) -> str:
|
async def translate_image(self) -> str:
|
||||||
"""处理图片类型的CQ码,区分普通图片和表情包"""
|
"""处理图片类型的CQ码,区分普通图片和表情包"""
|
||||||
#没有url,直接返回默认文本
|
#没有url,直接返回默认文本
|
||||||
if 'url' not in self.params:
|
if 'url' not in self.params:
|
||||||
@@ -181,25 +181,27 @@ class CQCode:
|
|||||||
if base64_str:
|
if base64_str:
|
||||||
image_bytes = base64.b64decode(base64_str)
|
image_bytes = base64.b64decode(base64_str)
|
||||||
storage_image(image_bytes)
|
storage_image(image_bytes)
|
||||||
return self.get_image_description(base64_str)
|
return await self.get_image_description(base64_str)
|
||||||
else:
|
else:
|
||||||
return '[图片]'
|
return '[图片]'
|
||||||
|
|
||||||
def get_emoji_description(self, image_base64: str) -> str:
|
async def get_emoji_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取表情包描述"""
|
"""调用AI接口获取表情包描述"""
|
||||||
try:
|
try:
|
||||||
prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
|
||||||
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||||
return f"[表情包:{description}]"
|
return f"[表情包:{description}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
return "[表情包]"
|
return "[表情包]"
|
||||||
|
|
||||||
def get_image_description(self, image_base64: str) -> str:
|
async def get_image_description(self, image_base64: str) -> str:
|
||||||
"""调用AI接口获取普通图片描述"""
|
"""调用AI接口获取普通图片描述"""
|
||||||
try:
|
try:
|
||||||
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
|
||||||
description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
# description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
|
||||||
|
description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
|
||||||
return f"[图片:{description}]"
|
return f"[图片:{description}]"
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import time
|
|||||||
import random
|
import random
|
||||||
from ..schedule.schedule_generator import bot_schedule
|
from ..schedule.schedule_generator import bot_schedule
|
||||||
import os
|
import os
|
||||||
from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text,find_similar_topics
|
from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text
|
||||||
from ...common.database import Database
|
from ...common.database import Database
|
||||||
from .config import global_config
|
from .config import global_config
|
||||||
from .topic_identifier import topic_identifier
|
from .topic_identifier import topic_identifier
|
||||||
@@ -60,7 +60,7 @@ class PromptBuilder:
|
|||||||
|
|
||||||
prompt_info = ''
|
prompt_info = ''
|
||||||
promt_info_prompt = ''
|
promt_info_prompt = ''
|
||||||
prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
|
prompt_info = await self.get_prompt_info(message_txt,threshold=0.5)
|
||||||
if prompt_info:
|
if prompt_info:
|
||||||
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
|
||||||
|
|
||||||
@@ -214,10 +214,10 @@ class PromptBuilder:
|
|||||||
return prompt_for_initiative
|
return prompt_for_initiative
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_info(self,message:str,threshold:float):
|
async def get_prompt_info(self,message:str,threshold:float):
|
||||||
related_info = ''
|
related_info = ''
|
||||||
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
|
||||||
embedding = get_embedding(message)
|
embedding = await get_embedding(message)
|
||||||
related_info += self.get_info_from_db(embedding,threshold=threshold)
|
related_info += self.get_info_from_db(embedding,threshold=threshold)
|
||||||
|
|
||||||
return related_info
|
return related_info
|
||||||
|
|||||||
@@ -32,16 +32,18 @@ def combine_messages(messages: List[Message]) -> str:
|
|||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
|
||||||
name = message.user_nickname or f"用户{message.user_id}"
|
name = message.user_nickname or f"用户{message.user_id}"
|
||||||
content = message.processed_plain_text or message.plain_text
|
content = message.processed_plain_text or message.plain_text
|
||||||
|
|
||||||
result += f"[{time_str}] {name}: {content}\n"
|
result += f"[{time_str}] {name}: {content}\n"
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def db_message_to_str (message_dict: Dict) -> str:
|
|
||||||
|
def db_message_to_str(message_dict: Dict) -> str:
|
||||||
print(f"message_dict: {message_dict}")
|
print(f"message_dict: {message_dict}")
|
||||||
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
|
||||||
try:
|
try:
|
||||||
name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", ""))
|
name = "[(%s)%s]%s" % (
|
||||||
|
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
|
||||||
except:
|
except:
|
||||||
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
|
||||||
content = message_dict.get("processed_plain_text", "")
|
content = message_dict.get("processed_plain_text", "")
|
||||||
@@ -58,6 +60,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def is_mentioned_bot_in_txt(message: str) -> bool:
|
def is_mentioned_bot_in_txt(message: str) -> bool:
|
||||||
"""检查消息是否提到了机器人"""
|
"""检查消息是否提到了机器人"""
|
||||||
keywords = [global_config.BOT_NICKNAME]
|
keywords = [global_config.BOT_NICKNAME]
|
||||||
@@ -66,10 +69,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def get_embedding(text):
|
|
||||||
|
async def get_embedding(text):
|
||||||
"""获取文本的embedding向量"""
|
"""获取文本的embedding向量"""
|
||||||
llm = LLM_request(model=global_config.embedding)
|
llm = LLM_request(model=global_config.embedding)
|
||||||
return llm.get_embedding_sync(text)
|
# return llm.get_embedding_sync(text)
|
||||||
|
return await llm.get_embedding(text)
|
||||||
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
@@ -77,51 +83,54 @@ def cosine_similarity(v1, v2):
|
|||||||
norm2 = np.linalg.norm(v2)
|
norm2 = np.linalg.norm(v2)
|
||||||
return dot_product / (norm1 * norm2)
|
return dot_product / (norm1 * norm2)
|
||||||
|
|
||||||
|
|
||||||
def calculate_information_content(text):
|
def calculate_information_content(text):
|
||||||
"""计算文本的信息量(熵)"""
|
"""计算文本的信息量(熵)"""
|
||||||
char_count = Counter(text)
|
char_count = Counter(text)
|
||||||
total_chars = len(text)
|
total_chars = len(text)
|
||||||
|
|
||||||
entropy = 0
|
entropy = 0
|
||||||
for count in char_count.values():
|
for count in char_count.values():
|
||||||
probability = count / total_chars
|
probability = count / total_chars
|
||||||
entropy -= probability * math.log2(probability)
|
entropy -= probability * math.log2(probability)
|
||||||
|
|
||||||
return entropy
|
return entropy
|
||||||
|
|
||||||
|
|
||||||
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
def get_cloest_chat_from_db(db, length: int, timestamp: str):
|
||||||
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
|
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
|
||||||
chat_text = ''
|
chat_text = ''
|
||||||
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
|
||||||
|
|
||||||
if closest_record and closest_record.get('memorized', 0) < 4:
|
if closest_record and closest_record.get('memorized', 0) < 4:
|
||||||
closest_time = closest_record['time']
|
closest_time = closest_record['time']
|
||||||
group_id = closest_record['group_id'] # 获取groupid
|
group_id = closest_record['group_id'] # 获取groupid
|
||||||
# 获取该时间戳之后的length条消息,且groupid相同
|
# 获取该时间戳之后的length条消息,且groupid相同
|
||||||
chat_records = list(db.db.messages.find(
|
chat_records = list(db.db.messages.find(
|
||||||
{"time": {"$gt": closest_time}, "group_id": group_id}
|
{"time": {"$gt": closest_time}, "group_id": group_id}
|
||||||
).sort('time', 1).limit(length))
|
).sort('time', 1).limit(length))
|
||||||
|
|
||||||
# 更新每条消息的memorized属性
|
# 更新每条消息的memorized属性
|
||||||
for record in chat_records:
|
for record in chat_records:
|
||||||
# 检查当前记录的memorized值
|
# 检查当前记录的memorized值
|
||||||
current_memorized = record.get('memorized', 0)
|
current_memorized = record.get('memorized', 0)
|
||||||
if current_memorized > 3:
|
if current_memorized > 3:
|
||||||
# print(f"消息已读取3次,跳过")
|
# print(f"消息已读取3次,跳过")
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
# 更新memorized值
|
# 更新memorized值
|
||||||
db.db.messages.update_one(
|
db.db.messages.update_one(
|
||||||
{"_id": record["_id"]},
|
{"_id": record["_id"]},
|
||||||
{"$set": {"memorized": current_memorized + 1}}
|
{"$set": {"memorized": current_memorized + 1}}
|
||||||
)
|
)
|
||||||
|
|
||||||
chat_text += record["detailed_plain_text"]
|
chat_text += record["detailed_plain_text"]
|
||||||
|
|
||||||
return chat_text
|
return chat_text
|
||||||
# print(f"消息已读取3次,跳过")
|
# print(f"消息已读取3次,跳过")
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
|
||||||
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
||||||
"""从数据库获取群组最近的消息记录
|
"""从数据库获取群组最近的消息记录
|
||||||
|
|
||||||
@@ -134,7 +143,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
list: Message对象列表,按时间正序排列
|
list: Message对象列表,按时间正序排列
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# 从数据库获取最近消息
|
# 从数据库获取最近消息
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"group_id": group_id},
|
||||||
# {
|
# {
|
||||||
@@ -149,7 +158,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
# 转换为 Message对象列表
|
# 转换为 Message对象列表
|
||||||
from .message import Message
|
from .message import Message
|
||||||
message_objects = []
|
message_objects = []
|
||||||
@@ -168,12 +177,13 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
print("[WARNING] 数据库中存在无效的消息")
|
print("[WARNING] 数据库中存在无效的消息")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# 按时间正序排列
|
# 按时间正序排列
|
||||||
message_objects.reverse()
|
message_objects.reverse()
|
||||||
return message_objects
|
return message_objects
|
||||||
|
|
||||||
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,combine = False):
|
|
||||||
|
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
|
||||||
recent_messages = list(db.db.messages.find(
|
recent_messages = list(db.db.messages.find(
|
||||||
{"group_id": group_id},
|
{"group_id": group_id},
|
||||||
{
|
{
|
||||||
@@ -187,16 +197,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
|
|||||||
|
|
||||||
if not recent_messages:
|
if not recent_messages:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
message_detailed_plain_text = ''
|
message_detailed_plain_text = ''
|
||||||
message_detailed_plain_text_list = []
|
message_detailed_plain_text_list = []
|
||||||
|
|
||||||
# 反转消息列表,使最新的消息在最后
|
# 反转消息列表,使最新的消息在最后
|
||||||
recent_messages.reverse()
|
recent_messages.reverse()
|
||||||
|
|
||||||
if combine:
|
if combine:
|
||||||
for msg_db_data in recent_messages:
|
for msg_db_data in recent_messages:
|
||||||
message_detailed_plain_text+=str(msg_db_data["detailed_plain_text"])
|
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
|
||||||
return message_detailed_plain_text
|
return message_detailed_plain_text
|
||||||
else:
|
else:
|
||||||
for msg_db_data in recent_messages:
|
for msg_db_data in recent_messages:
|
||||||
@@ -204,7 +214,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
|
|||||||
return message_detailed_plain_text_list
|
return message_detailed_plain_text_list
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
||||||
"""将文本分割成句子,但保持书名号中的内容完整
|
"""将文本分割成句子,但保持书名号中的内容完整
|
||||||
Args:
|
Args:
|
||||||
@@ -224,30 +233,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
|||||||
split_strength = 0.7
|
split_strength = 0.7
|
||||||
else:
|
else:
|
||||||
split_strength = 0.9
|
split_strength = 0.9
|
||||||
#先移除换行符
|
# 先移除换行符
|
||||||
# print(f"split_strength: {split_strength}")
|
# print(f"split_strength: {split_strength}")
|
||||||
|
|
||||||
# print(f"处理前的文本: {text}")
|
# print(f"处理前的文本: {text}")
|
||||||
|
|
||||||
# 统一将英文逗号转换为中文逗号
|
# 统一将英文逗号转换为中文逗号
|
||||||
text = text.replace(',', ',')
|
text = text.replace(',', ',')
|
||||||
text = text.replace('\n', ' ')
|
text = text.replace('\n', ' ')
|
||||||
|
|
||||||
# print(f"处理前的文本: {text}")
|
# print(f"处理前的文本: {text}")
|
||||||
|
|
||||||
text_no_1 = ''
|
text_no_1 = ''
|
||||||
for letter in text:
|
for letter in text:
|
||||||
# print(f"当前字符: {letter}")
|
# print(f"当前字符: {letter}")
|
||||||
if letter in ['!','!','?','?']:
|
if letter in ['!', '!', '?', '?']:
|
||||||
# print(f"当前字符: {letter}, 随机数: {random.random()}")
|
# print(f"当前字符: {letter}, 随机数: {random.random()}")
|
||||||
if random.random() < split_strength:
|
if random.random() < split_strength:
|
||||||
letter = ''
|
letter = ''
|
||||||
if letter in ['。','…']:
|
if letter in ['。', '…']:
|
||||||
# print(f"当前字符: {letter}, 随机数: {random.random()}")
|
# print(f"当前字符: {letter}, 随机数: {random.random()}")
|
||||||
if random.random() < 1 - split_strength:
|
if random.random() < 1 - split_strength:
|
||||||
letter = ''
|
letter = ''
|
||||||
text_no_1 += letter
|
text_no_1 += letter
|
||||||
|
|
||||||
# 对每个逗号单独判断是否分割
|
# 对每个逗号单独判断是否分割
|
||||||
sentences = [text_no_1]
|
sentences = [text_no_1]
|
||||||
new_sentences = []
|
new_sentences = []
|
||||||
@@ -276,15 +285,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
|
|||||||
sentences_done = []
|
sentences_done = []
|
||||||
for sentence in sentences:
|
for sentence in sentences:
|
||||||
sentence = sentence.rstrip(',,')
|
sentence = sentence.rstrip(',,')
|
||||||
if random.random() < split_strength*0.5:
|
if random.random() < split_strength * 0.5:
|
||||||
sentence = sentence.replace(',', '').replace(',', '')
|
sentence = sentence.replace(',', '').replace(',', '')
|
||||||
elif random.random() < split_strength:
|
elif random.random() < split_strength:
|
||||||
sentence = sentence.replace(',', ' ').replace(',', ' ')
|
sentence = sentence.replace(',', ' ').replace(',', ' ')
|
||||||
sentences_done.append(sentence)
|
sentences_done.append(sentence)
|
||||||
|
|
||||||
print(f"处理后的句子: {sentences_done}")
|
print(f"处理后的句子: {sentences_done}")
|
||||||
return sentences_done
|
return sentences_done
|
||||||
|
|
||||||
|
|
||||||
# 常见的错别字映射
|
# 常见的错别字映射
|
||||||
TYPO_DICT = {
|
TYPO_DICT = {
|
||||||
'的': '地得',
|
'的': '地得',
|
||||||
@@ -355,6 +365,7 @@ TYPO_DICT = {
|
|||||||
'嘻': '嘻西希'
|
'嘻': '嘻西希'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def random_remove_punctuation(text: str) -> str:
|
def random_remove_punctuation(text: str) -> str:
|
||||||
"""随机处理标点符号,模拟人类打字习惯
|
"""随机处理标点符号,模拟人类打字习惯
|
||||||
|
|
||||||
@@ -366,7 +377,7 @@ def random_remove_punctuation(text: str) -> str:
|
|||||||
"""
|
"""
|
||||||
result = ''
|
result = ''
|
||||||
text_len = len(text)
|
text_len = len(text)
|
||||||
|
|
||||||
for i, char in enumerate(text):
|
for i, char in enumerate(text):
|
||||||
if char == '。' and i == text_len - 1: # 结尾的句号
|
if char == '。' and i == text_len - 1: # 结尾的句号
|
||||||
if random.random() > 0.4: # 80%概率删除结尾句号
|
if random.random() > 0.4: # 80%概率删除结尾句号
|
||||||
@@ -381,6 +392,7 @@ def random_remove_punctuation(text: str) -> str:
|
|||||||
result += char
|
result += char
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def add_typos(text: str) -> str:
|
def add_typos(text: str) -> str:
|
||||||
TYPO_RATE = 0.02 # 控制错别字出现的概率(2%)
|
TYPO_RATE = 0.02 # 控制错别字出现的概率(2%)
|
||||||
result = ""
|
result = ""
|
||||||
@@ -393,20 +405,22 @@ def add_typos(text: str) -> str:
|
|||||||
result += char
|
result += char
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def process_llm_response(text: str) -> List[str]:
|
def process_llm_response(text: str) -> List[str]:
|
||||||
# processed_response = process_text_with_typos(content)
|
# processed_response = process_text_with_typos(content)
|
||||||
if len(text) > 300:
|
if len(text) > 300:
|
||||||
print(f"回复过长 ({len(text)} 字符),返回默认回复")
|
print(f"回复过长 ({len(text)} 字符),返回默认回复")
|
||||||
return ['懒得说']
|
return ['懒得说']
|
||||||
# 处理长消息
|
# 处理长消息
|
||||||
sentences = split_into_sentences_w_remove_punctuation(add_typos(text))
|
sentences = split_into_sentences_w_remove_punctuation(add_typos(text))
|
||||||
# 检查分割后的消息数量是否过多(超过3条)
|
# 检查分割后的消息数量是否过多(超过3条)
|
||||||
if len(sentences) > 4:
|
if len(sentences) > 4:
|
||||||
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
|
||||||
return [f'{global_config.BOT_NICKNAME}不知道哦']
|
return [f'{global_config.BOT_NICKNAME}不知道哦']
|
||||||
|
|
||||||
return sentences
|
return sentences
|
||||||
|
|
||||||
|
|
||||||
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
|
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
|
||||||
"""
|
"""
|
||||||
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
计算输入字符串所需的时间,中文和英文字符有不同的输入时间
|
||||||
@@ -419,32 +433,10 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_
|
|||||||
if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符
|
if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符
|
||||||
total_time += chinese_time
|
total_time += chinese_time
|
||||||
else: # 其他字符(如英文)
|
else: # 其他字符(如英文)
|
||||||
total_time += english_time
|
total_time += english_time
|
||||||
return total_time
|
return total_time
|
||||||
|
|
||||||
|
|
||||||
def find_similar_topics(message_txt: str, all_memory_topic: list, top_k: int = 5) -> list:
|
|
||||||
"""使用重排序API找出与输入文本最相似的话题
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_txt: 输入文本
|
|
||||||
all_memory_topic: 所有记忆主题列表
|
|
||||||
top_k: 返回最相似的话题数量
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: 最相似话题列表及其相似度分数
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not all_memory_topic:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm = LLM_request(model=global_config.rerank)
|
|
||||||
return llm.rerank_sync(message_txt, all_memory_topic, top_k)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"重排序API调用出错: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
def cosine_similarity(v1, v2):
|
def cosine_similarity(v1, v2):
|
||||||
"""计算余弦相似度"""
|
"""计算余弦相似度"""
|
||||||
dot_product = np.dot(v1, v2)
|
dot_product = np.dot(v1, v2)
|
||||||
@@ -454,6 +446,7 @@ def cosine_similarity(v1, v2):
|
|||||||
return 0
|
return 0
|
||||||
return dot_product / (norm1 * norm2)
|
return dot_product / (norm1 * norm2)
|
||||||
|
|
||||||
|
|
||||||
def text_to_vector(text):
|
def text_to_vector(text):
|
||||||
"""将文本转换为词频向量"""
|
"""将文本转换为词频向量"""
|
||||||
# 分词
|
# 分词
|
||||||
@@ -462,11 +455,12 @@ def text_to_vector(text):
|
|||||||
word_freq = Counter(words)
|
word_freq = Counter(words)
|
||||||
return word_freq
|
return word_freq
|
||||||
|
|
||||||
|
|
||||||
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
||||||
"""使用简单的余弦相似度计算文本相似度"""
|
"""使用简单的余弦相似度计算文本相似度"""
|
||||||
# 将输入文本转换为词频向量
|
# 将输入文本转换为词频向量
|
||||||
text_vector = text_to_vector(text)
|
text_vector = text_to_vector(text)
|
||||||
|
|
||||||
# 计算每个主题的相似度
|
# 计算每个主题的相似度
|
||||||
similarities = []
|
similarities = []
|
||||||
for topic in topics:
|
for topic in topics:
|
||||||
@@ -479,6 +473,6 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
|
|||||||
# 计算相似度
|
# 计算相似度
|
||||||
similarity = cosine_similarity(v1, v2)
|
similarity = cosine_similarity(v1, v2)
|
||||||
similarities.append((topic, similarity))
|
similarities.append((topic, similarity))
|
||||||
|
|
||||||
# 按相似度降序排序并返回前k个
|
# 按相似度降序排序并返回前k个
|
||||||
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
|
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ from ..chat.config import global_config
|
|||||||
from ...common.database import Database # 使用正确的导入语法
|
from ...common.database import Database # 使用正确的导入语法
|
||||||
from ..models.utils_model import LLM_request
|
from ..models.utils_model import LLM_request
|
||||||
import math
|
import math
|
||||||
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,find_similar_topics,text_to_vector,cosine_similarity
|
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,text_to_vector,cosine_similarity
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -25,354 +25,195 @@ class LLM_request:
|
|||||||
self.model_name = model["name"]
|
self.model_name = model["name"]
|
||||||
self.params = kwargs
|
self.params = kwargs
|
||||||
|
|
||||||
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
async def _execute_request(
|
||||||
"""根据输入的提示生成模型的异步响应"""
|
self,
|
||||||
headers = {
|
endpoint: str,
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
prompt: str = None,
|
||||||
"Content-Type": "application/json"
|
image_base64: str = None,
|
||||||
|
payload: dict = None,
|
||||||
|
retry_policy: dict = None,
|
||||||
|
response_handler: callable = None,
|
||||||
|
):
|
||||||
|
"""统一请求执行入口
|
||||||
|
Args:
|
||||||
|
endpoint: API端点路径 (如 "chat/completions")
|
||||||
|
prompt: prompt文本
|
||||||
|
image_base64: 图片的base64编码
|
||||||
|
payload: 请求体数据
|
||||||
|
is_async: 是否异步
|
||||||
|
retry_policy: 自定义重试策略
|
||||||
|
(示例: {"max_retries":3, "base_wait":15, "retry_codes":[429,500]})
|
||||||
|
response_handler: 自定义响应处理器
|
||||||
|
"""
|
||||||
|
# 合并重试策略
|
||||||
|
default_retry = {
|
||||||
|
"max_retries": 3, "base_wait": 15,
|
||||||
|
"retry_codes": [429, 413, 500, 503],
|
||||||
|
"abort_codes": [400, 401, 402, 403]}
|
||||||
|
policy = {**default_retry, **(retry_policy or {})}
|
||||||
|
|
||||||
|
# 常见Error Code Mapping
|
||||||
|
error_code_mapping = {
|
||||||
|
400: "参数不正确",
|
||||||
|
401: "API key 错误,认证失败",
|
||||||
|
402: "账号余额不足",
|
||||||
|
403: "需要实名,或余额不足",
|
||||||
|
404: "Not Found",
|
||||||
|
429: "请求过于频繁,请稍后再试",
|
||||||
|
500: "服务器内部故障",
|
||||||
|
503: "服务器负载过高"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
|
||||||
|
logger.info(f"发送请求到URL: {api_url}{self.model_name}")
|
||||||
|
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
data = {
|
if image_base64:
|
||||||
"model": self.model_name,
|
payload = await self._build_payload(prompt, image_base64)
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
elif payload is None:
|
||||||
**self.params
|
payload = await self._build_payload(prompt)
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
session_method = aiohttp.ClientSession()
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL
|
|
||||||
|
|
||||||
max_retries = 3
|
for retry in range(policy["max_retries"]):
|
||||||
base_wait_time = 15
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
try:
|
||||||
async with aiohttp.ClientSession() as session:
|
# 使用上下文管理器处理会话
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
headers = await self._build_headers()
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.status in [500, 503]:
|
async with session_method as session:
|
||||||
logger.error(f"服务器错误: {response.status}")
|
response = await session.post(api_url, headers=headers, json=payload)
|
||||||
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
# 处理需要重试的状态码
|
||||||
|
if response.status in policy["retry_codes"]:
|
||||||
|
wait_time = policy["base_wait"] * (2 ** retry)
|
||||||
|
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
|
||||||
|
if response.status == 413:
|
||||||
|
logger.warning("请求体过大,尝试压缩...")
|
||||||
|
image_base64 = compress_base64_image_by_scale(image_base64)
|
||||||
|
payload = await self._build_payload(prompt, image_base64)
|
||||||
|
elif response.status in [500, 503]:
|
||||||
|
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||||||
|
raise RuntimeError("服务器负载过高,模型恢复失败QAQ")
|
||||||
|
else:
|
||||||
|
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
|
||||||
|
|
||||||
result = await response.json()
|
await asyncio.sleep(wait_time)
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
continue
|
||||||
message = result["choices"][0]["message"]
|
elif response.status in policy["abort_codes"]:
|
||||||
content = message.get("content", "")
|
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
|
||||||
think_match = None
|
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
|
||||||
reasoning_content = message.get("reasoning_content", "")
|
|
||||||
if not reasoning_content:
|
response.raise_for_status()
|
||||||
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
result = await response.json()
|
||||||
if think_match:
|
|
||||||
reasoning_content = think_match.group(1).strip()
|
# 使用自定义处理器或默认处理
|
||||||
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
return response_handler(result) if response_handler else self._default_response_handler(result)
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
if retry < policy["max_retries"] - 1:
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
wait_time = policy["base_wait"] * (2 ** retry)
|
||||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
logger.error(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
||||||
await asyncio.sleep(wait_time)
|
await asyncio.sleep(wait_time)
|
||||||
else:
|
else:
|
||||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
logger.critical(f"请求失败: {str(e)}")
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}")
|
||||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
raise RuntimeError(f"API请求失败: {str(e)}")
|
||||||
|
|
||||||
logger.error("达到最大重试次数,请求仍然失败")
|
logger.error("达到最大重试次数,请求仍然失败")
|
||||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
||||||
|
|
||||||
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
|
||||||
"""根据输入的提示和图片生成模型的异步响应"""
|
"""构建请求体"""
|
||||||
headers = {
|
if image_base64:
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
def build_request_data(img_base64: str):
|
|
||||||
return {
|
return {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
{
|
{"type": "text", "text": prompt},
|
||||||
"type": "text",
|
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
|
||||||
"text": prompt
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{img_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
|
"max_tokens": global_config.max_response_length,
|
||||||
|
**self.params
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"model": self.model_name,
|
||||||
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
|
"max_tokens": global_config.max_response_length,
|
||||||
**self.params
|
**self.params
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _default_response_handler(self, result: dict) -> Tuple:
|
||||||
|
"""默认响应解析"""
|
||||||
|
if "choices" in result and result["choices"]:
|
||||||
|
message = result["choices"][0]["message"]
|
||||||
|
content = message.get("content", "")
|
||||||
|
content, reasoning = self._extract_reasoning(content)
|
||||||
|
reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
|
||||||
|
if not reasoning_content:
|
||||||
|
reasoning_content = reasoning
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
return content, reasoning_content
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL
|
|
||||||
|
|
||||||
max_retries = 3
|
return "没有返回结果", ""
|
||||||
base_wait_time = 15
|
|
||||||
|
|
||||||
current_image_base64 = image_base64
|
def _extract_reasoning(self, content: str) -> tuple[str, str]:
|
||||||
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
|
"""CoT思维链提取"""
|
||||||
|
match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
||||||
|
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
||||||
|
if match:
|
||||||
|
reasoning = match.group(1).strip()
|
||||||
|
else:
|
||||||
|
reasoning = ""
|
||||||
|
return content, reasoning
|
||||||
|
|
||||||
for retry in range(max_retries):
|
async def _build_headers(self) -> dict:
|
||||||
try:
|
"""构建请求头"""
|
||||||
data = build_request_data(current_image_base64)
|
return {
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
elif response.status == 413:
|
|
||||||
logger.warning("图片太大(413),尝试压缩...")
|
|
||||||
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
message = result["choices"][0]["message"]
|
|
||||||
content = message.get("content", "")
|
|
||||||
think_match = None
|
|
||||||
reasoning_content = message.get("reasoning_content", "")
|
|
||||||
if not reasoning_content:
|
|
||||||
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
|
||||||
if think_match:
|
|
||||||
reasoning_content = think_match.group(1).strip()
|
|
||||||
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
|
||||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,请求仍然失败")
|
|
||||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
|
||||||
|
|
||||||
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
|
||||||
"""异步方式根据输入的提示生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def generate_response(self, prompt: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示生成模型的异步响应"""
|
||||||
|
|
||||||
|
content, reasoning_content = await self._execute_request(
|
||||||
|
endpoint="/chat/completions",
|
||||||
|
prompt=prompt
|
||||||
|
)
|
||||||
|
return content, reasoning_content
|
||||||
|
|
||||||
|
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
||||||
|
"""根据输入的提示和图片生成模型的异步响应"""
|
||||||
|
|
||||||
|
content, reasoning_content = await self._execute_request(
|
||||||
|
endpoint="/chat/completions",
|
||||||
|
prompt=prompt,
|
||||||
|
image_base64=image_base64
|
||||||
|
)
|
||||||
|
return content, reasoning_content
|
||||||
|
|
||||||
|
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
|
||||||
|
"""异步方式根据输入的提示生成模型的响应"""
|
||||||
# 构建请求体
|
# 构建请求体
|
||||||
data = {
|
data = {
|
||||||
"model": self.model_name,
|
"model": self.model_name,
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
"messages": [{"role": "user", "content": prompt}],
|
||||||
"temperature": 0.5,
|
"temperature": 0.5,
|
||||||
|
"max_tokens": global_config.max_response_length,
|
||||||
**self.params
|
**self.params
|
||||||
}
|
}
|
||||||
|
|
||||||
# 发送请求到完整的 chat/completions 端点
|
content, reasoning_content = await self._execute_request(
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
endpoint="/chat/completions",
|
||||||
logger.info(f"Request URL: {api_url}") # 记录请求的 URL
|
payload=data,
|
||||||
|
prompt=prompt
|
||||||
max_retries = 3
|
)
|
||||||
base_wait_time = 15
|
return content, reasoning_content
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry) # 指数退避
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
message = result["choices"][0]["message"]
|
|
||||||
content = message.get("content", "")
|
|
||||||
think_match = None
|
|
||||||
reasoning_content = message.get("reasoning_content", "")
|
|
||||||
if not reasoning_content:
|
|
||||||
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
|
||||||
if think_match:
|
|
||||||
reasoning_content = think_match.group(1).strip()
|
|
||||||
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.error(f"请求失败: {str(e)}")
|
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
|
||||||
return f"请求失败: {str(e)}", ""
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,请求仍然失败")
|
|
||||||
return "达到最大重试次数,请求仍然失败", ""
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
|
|
||||||
"""同步方法:根据输入的提示和图片生成模型的响应"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
image_base64=compress_base64_image_by_scale(image_base64)
|
|
||||||
|
|
||||||
# 构建请求体
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": [
|
|
||||||
{
|
|
||||||
"type": "text",
|
|
||||||
"text": prompt
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "image_url",
|
|
||||||
"image_url": {
|
|
||||||
"url": f"data:image/jpeg;base64,{image_base64}"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
],
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
# 发送请求到完整的chat/completions端点
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL
|
|
||||||
|
|
||||||
max_retries = 2
|
|
||||||
base_wait_time = 6
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status() # 检查其他响应状态
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
message = result["choices"][0]["message"]
|
|
||||||
content = message.get("content", "")
|
|
||||||
think_match = None
|
|
||||||
reasoning_content = message.get("reasoning_content", "")
|
|
||||||
if not reasoning_content:
|
|
||||||
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
|
|
||||||
if think_match:
|
|
||||||
reasoning_content = think_match.group(1).strip()
|
|
||||||
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
|
|
||||||
return content, reasoning_content
|
|
||||||
return "没有返回结果", ""
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1: # 如果还有重试机会
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
time.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"请求失败: {str(e)}", exc_info=True)
|
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
|
||||||
raise RuntimeError(f"API请求失败: {str(e)}")
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,请求仍然失败")
|
|
||||||
raise RuntimeError("达到最大重试次数,API请求仍然失败")
|
|
||||||
|
|
||||||
def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
|
||||||
"""同步方法:获取文本的embedding向量
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: 需要获取embedding的文本
|
|
||||||
model: 使用的模型名称,默认为"BAAI/bge-m3"
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: embedding向量,如果失败则返回None
|
|
||||||
"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": model,
|
|
||||||
"input": text,
|
|
||||||
"encoding_format": "float"
|
|
||||||
}
|
|
||||||
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL
|
|
||||||
|
|
||||||
max_retries = 2
|
|
||||||
base_wait_time = 6
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if 'data' in result and len(result['data']) > 0:
|
|
||||||
return result['data'][0]['embedding']
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
time.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,embedding请求仍然失败")
|
|
||||||
return None
|
|
||||||
|
|
||||||
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
|
||||||
"""异步方法:获取文本的embedding向量
|
"""异步方法:获取文本的embedding向量
|
||||||
@@ -384,245 +225,24 @@ class LLM_request:
|
|||||||
Returns:
|
Returns:
|
||||||
list: embedding向量,如果失败则返回None
|
list: embedding向量,如果失败则返回None
|
||||||
"""
|
"""
|
||||||
headers = {
|
def embedding_handler(result):
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"""处理响应"""
|
||||||
"Content-Type": "application/json"
|
if "data" in result and len(result["data"]) > 0:
|
||||||
}
|
return result["data"][0].get("embedding", None)
|
||||||
|
return None
|
||||||
|
|
||||||
data = {
|
embedding = await self._execute_request(
|
||||||
"model": model,
|
endpoint="/embeddings",
|
||||||
"input": text,
|
prompt=text,
|
||||||
"encoding_format": "float"
|
payload={
|
||||||
}
|
"model": model,
|
||||||
|
"input": text,
|
||||||
api_url = f"{self.base_url.rstrip('/')}/embeddings"
|
"encoding_format": "float"
|
||||||
logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL
|
},
|
||||||
|
retry_policy={
|
||||||
max_retries = 3
|
"max_retries": 2,
|
||||||
base_wait_time = 15
|
"base_wait": 6
|
||||||
|
},
|
||||||
for retry in range(max_retries):
|
response_handler=embedding_handler
|
||||||
try:
|
)
|
||||||
async with aiohttp.ClientSession() as session:
|
return embedding
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if 'data' in result and len(result['data']) > 0:
|
|
||||||
return result['data'][0]['embedding']
|
|
||||||
return None
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
|
|
||||||
logger.critical(f"请求头: {headers} 请求体: {data}")
|
|
||||||
return None
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,embedding请求仍然失败")
|
|
||||||
return None
|
|
||||||
|
|
||||||
def rerank_sync(self, query: str, documents: list, top_k: int = 5) -> list:
|
|
||||||
"""同步方法:使用重排序API对文档进行排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 查询文本
|
|
||||||
documents: 待排序的文档列表
|
|
||||||
top_k: 返回前k个结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: [(document, score), ...] 格式的结果列表
|
|
||||||
"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"query": query,
|
|
||||||
"documents": documents,
|
|
||||||
"top_n": top_k,
|
|
||||||
"return_documents": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/rerank"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}")
|
|
||||||
|
|
||||||
max_retries = 2
|
|
||||||
base_wait_time = 6
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
response = requests.post(api_url, headers=headers, json=data, timeout=30)
|
|
||||||
|
|
||||||
if response.status_code == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.status_code in [500, 503]:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"服务器错误({response.status_code}),等待{wait_time}秒后重试...")
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
time.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# 如果是最后一次重试,尝试使用chat/completions作为备选方案
|
|
||||||
return self._fallback_rerank_with_chat(query, documents, top_k)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = response.json()
|
|
||||||
if 'results' in result:
|
|
||||||
return [(item["document"], item["score"]) for item in result["results"]]
|
|
||||||
return []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[rerank_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
time.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"重排序请求失败: {str(e)}", exc_info=True)
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,重排序请求仍然失败")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def rerank(self, query: str, documents: list, top_k: int = 5) -> list:
|
|
||||||
"""异步方法:使用重排序API对文档进行排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 查询文本
|
|
||||||
documents: 待排序的文档列表
|
|
||||||
top_k: 返回前k个结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: [(document, score), ...] 格式的结果列表
|
|
||||||
"""
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"query": query,
|
|
||||||
"documents": documents,
|
|
||||||
"top_n": top_k,
|
|
||||||
"return_documents": True,
|
|
||||||
}
|
|
||||||
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/v1/rerank"
|
|
||||||
logger.info(f"发送请求到URL: {api_url}")
|
|
||||||
|
|
||||||
max_retries = 3
|
|
||||||
base_wait_time = 15
|
|
||||||
|
|
||||||
for retry in range(max_retries):
|
|
||||||
try:
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
if response.status == 429:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if response.status in [500, 503]:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"服务器错误({response.status}),等待{wait_time}秒后重试...")
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
# 如果是最后一次重试,尝试使用chat/completions作为备选方案
|
|
||||||
return await self._fallback_rerank_with_chat_async(query, documents, top_k)
|
|
||||||
|
|
||||||
response.raise_for_status()
|
|
||||||
|
|
||||||
result = await response.json()
|
|
||||||
if 'results' in result:
|
|
||||||
return [(item["document"], item["score"]) for item in result["results"]]
|
|
||||||
return []
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
if retry < max_retries - 1:
|
|
||||||
wait_time = base_wait_time * (2 ** retry)
|
|
||||||
logger.error(f"[rerank]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
|
|
||||||
await asyncio.sleep(wait_time)
|
|
||||||
else:
|
|
||||||
logger.critical(f"重排序请求失败: {str(e)}", exc_info=True)
|
|
||||||
# 作为最后的备选方案,尝试使用chat/completions
|
|
||||||
return await self._fallback_rerank_with_chat_async(query, documents, top_k)
|
|
||||||
|
|
||||||
logger.error("达到最大重试次数,重排序请求仍然失败")
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def _fallback_rerank_with_chat_async(self, query: str, documents: list, top_k: int = 5) -> list:
|
|
||||||
"""当rerank API失败时的备选方案,使用chat/completions异步实现重排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 查询文本
|
|
||||||
documents: 待排序的文档列表
|
|
||||||
top_k: 返回前k个结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
list: [(document, score), ...] 格式的结果列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
logger.info("使用chat/completions作为重排序的备选方案")
|
|
||||||
|
|
||||||
# 构建提示词
|
|
||||||
prompt = f"""请对以下文档列表进行重排序,按照与查询的相关性从高到低排序。
|
|
||||||
查询: {query}
|
|
||||||
|
|
||||||
文档列表:
|
|
||||||
{documents}
|
|
||||||
|
|
||||||
请以JSON格式返回排序结果,格式为:
|
|
||||||
[{{"document": "文档内容", "score": 相关性分数}}, ...]
|
|
||||||
只返回JSON,不要其他任何文字。"""
|
|
||||||
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
data = {
|
|
||||||
"model": self.model_name,
|
|
||||||
"messages": [{"role": "user", "content": prompt}],
|
|
||||||
**self.params
|
|
||||||
}
|
|
||||||
|
|
||||||
api_url = f"{self.base_url.rstrip('/')}/v1/chat/completions"
|
|
||||||
|
|
||||||
async with aiohttp.ClientSession() as session:
|
|
||||||
async with session.post(api_url, headers=headers, json=data) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
result = await response.json()
|
|
||||||
|
|
||||||
if "choices" in result and len(result["choices"]) > 0:
|
|
||||||
message = result["choices"][0]["message"]
|
|
||||||
content = message.get("content", "")
|
|
||||||
try:
|
|
||||||
import json
|
|
||||||
parsed_content = json.loads(content)
|
|
||||||
if isinstance(parsed_content, list):
|
|
||||||
return [(item["document"], item["score"]) for item in parsed_content]
|
|
||||||
except:
|
|
||||||
pass
|
|
||||||
return []
|
|
||||||
except Exception as e:
|
|
||||||
logger.error(f"备选方案也失败了: {str(e)}")
|
|
||||||
return []
|
|
||||||
|
|||||||
Reference in New Issue
Block a user