Merge remote-tracking branch 'upstream/debug' into debug

Rikki
2025-03-07 06:42:24 +08:00
11 changed files with 377 additions and 817 deletions

View File

@@ -58,6 +58,7 @@ class ChatBot:
plain_text=event.get_plaintext(),
reply_message=event.reply,
)
+ await message.initialize()
# 过滤词
for word in global_config.ban_words:
@@ -163,12 +164,6 @@ class ChatBot:
message_manager.add_message(message_set)
bot_response_time = tinking_time_point
- emotion = await self.gpt._get_emotion_tags(raw_content)
- print(f"'{response}' 获取到的情感标签为:{emotion}")
- valuedict={
- 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
- }
- await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
if random() < global_config.emoji_chance:
emoji_path = await emoji_manager.get_emoji_for_text(response)
@@ -196,6 +191,12 @@ class ChatBot:
# reply_message_id=message.message_id
)
message_manager.add_message(bot_message)
+ emotion = await self.gpt._get_emotion_tags(raw_content)
+ print(f"'{response}' 获取到的情感标签为:{emotion}")
+ valuedict={
+ 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25
+ }
+ await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]])
# willing_manager.change_reply_willing_after_sent(event.group_id)
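For reference, the relocated block keys the relationship update off the first emotion tag returned by `_get_emotion_tags`. A minimal standalone sketch of that lookup; the list-shaped `emotion` value and the numbers mirror the diff, everything else is illustrative:

```python
# Sketch of the emotion -> relationship-value lookup used in the hunk above.
# `emotion` is assumed to be a list such as ["happy"], per emotion[0] in the diff.
valuedict = {
    'happy': 0.5, 'angry': -1, 'sad': -0.5, 'surprised': 0.5,
    'disgusted': -1.5, 'fearful': -0.25, 'neutral': 0.25,
}
emotion = ["happy"]
delta = valuedict[emotion[0]]  # 0.5, passed as relationship_value to the manager
```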

View File

@@ -30,9 +30,13 @@ class BotConfig:
forget_memory_interval: int = 300 # 记忆遗忘间隔(秒)
EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟)
EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟)
- EMOJI_CHECK_PROMPT: str = "不要包含违反公序良俗的内容" # 表情包过滤要求
+ EMOJI_SAVE: bool = True # 表情包
+ EMOJI_CHECK: bool = False #是否开启过滤
+ EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求
ban_words = set()
+ max_response_length: int = 1024 # 最大回复长度
# 模型配置
llm_reasoning: Dict[str, str] = field(default_factory=lambda: {})
@@ -96,6 +100,8 @@ class BotConfig:
config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL)
config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL)
config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT)
+ config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE)
+ config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK)
if "cq_code" in toml_dict:
cq_code_config = toml_dict["cq_code"]
@@ -115,6 +121,7 @@ class BotConfig:
config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY)
config.API_USING = response_config.get("api_using", config.API_USING)
config.API_PAID = response_config.get("api_paid", config.API_PAID)
+ config.max_response_length = response_config.get("max_response_length", config.max_response_length)
# 加载模型配置
if "model" in toml_dict:

View File

@@ -10,11 +10,11 @@ from nonebot.adapters.onebot.v11 import Bot
from .config import global_config from .config import global_config
import time import time
import asyncio import asyncio
from .utils_image import storage_image, storage_emoji
from .utils_user import get_user_nickname
from ..models.utils_model import LLM_request
# 解析各种CQ码
# 包含CQ码类
import urllib3 import urllib3
from urllib3.util import create_urllib3_context from urllib3.util import create_urllib3_context
from nonebot import get_driver from nonebot import get_driver
@@ -27,6 +27,7 @@ ctx = create_urllib3_context()
ctx.load_default_certs() ctx.load_default_certs()
ctx.set_ciphers("AES128-GCM-SHA256") ctx.set_ciphers("AES128-GCM-SHA256")
class TencentSSLAdapter(requests.adapters.HTTPAdapter): class TencentSSLAdapter(requests.adapters.HTTPAdapter):
def __init__(self, ssl_context=None, **kwargs): def __init__(self, ssl_context=None, **kwargs):
self.ssl_context = ssl_context self.ssl_context = ssl_context
@@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter):
num_pools=connections, maxsize=maxsize, num_pools=connections, maxsize=maxsize,
block=block, ssl_context=self.ssl_context) block=block, ssl_context=self.ssl_context)
@dataclass @dataclass
class CQCode: class CQCode:
""" """
@@ -64,15 +66,15 @@ class CQCode:
"""初始化LLM实例""" """初始化LLM实例"""
self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300)
def translate(self): async def translate(self):
"""根据CQ码类型进行相应的翻译处理""" """根据CQ码类型进行相应的翻译处理"""
if self.type == 'text': if self.type == 'text':
self.translated_plain_text = self.params.get('text', '') self.translated_plain_text = self.params.get('text', '')
elif self.type == 'image': elif self.type == 'image':
if self.params.get('sub_type') == '0': if self.params.get('sub_type') == '0':
self.translated_plain_text = self.translate_image() self.translated_plain_text = await self.translate_image()
else: else:
self.translated_plain_text = self.translate_emoji() self.translated_plain_text = await self.translate_emoji()
elif self.type == 'at': elif self.type == 'at':
user_nickname = get_user_nickname(self.params.get('qq', '')) user_nickname = get_user_nickname(self.params.get('qq', ''))
if user_nickname: if user_nickname:
@@ -80,13 +82,13 @@ class CQCode:
else: else:
self.translated_plain_text = f"@某人" self.translated_plain_text = f"@某人"
elif self.type == 'reply': elif self.type == 'reply':
self.translated_plain_text = self.translate_reply() self.translated_plain_text = await self.translate_reply()
elif self.type == 'face': elif self.type == 'face':
face_id = self.params.get('id', '') face_id = self.params.get('id', '')
# self.translated_plain_text = f"[表情{face_id}]" # self.translated_plain_text = f"[表情{face_id}]"
self.translated_plain_text = f"[表情]" self.translated_plain_text = f"[表情]"
elif self.type == 'forward': elif self.type == 'forward':
self.translated_plain_text = self.translate_forward() self.translated_plain_text = await self.translate_forward()
else: else:
self.translated_plain_text = f"[{self.type}]" self.translated_plain_text = f"[{self.type}]"
@@ -133,7 +135,7 @@ class CQCode:
# 腾讯服务器特殊状态码处理 # 腾讯服务器特殊状态码处理
if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url:
return None return None
if response.status_code != 200: if response.status_code != 200:
raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") raise requests.exceptions.HTTPError(f"HTTP {response.status_code}")
@@ -157,8 +159,8 @@ class CQCode:
return None
return None
- def translate_emoji(self) -> str:
+ async def translate_emoji(self) -> str:
"""处理表情包类型的CQ码"""
if 'url' not in self.params:
return '[表情包]'
@@ -167,50 +169,51 @@ class CQCode:
# 将 base64 字符串转换为字节类型
image_bytes = base64.b64decode(base64_str)
storage_emoji(image_bytes)
- return self.get_emoji_description(base64_str)
+ return await self.get_emoji_description(base64_str)
else:
return '[表情包]'
- def translate_image(self) -> str:
+ async def translate_image(self) -> str:
"""处理图片类型的CQ码区分普通图片和表情包"""
# 没有url直接返回默认文本
if 'url' not in self.params:
return '[图片]'
base64_str = self.get_img()
if base64_str:
image_bytes = base64.b64decode(base64_str)
storage_image(image_bytes)
- return self.get_image_description(base64_str)
+ return await self.get_image_description(base64_str)
else:
return '[图片]'
- def get_emoji_description(self, image_base64: str) -> str:
+ async def get_emoji_description(self, image_base64: str) -> str:
"""调用AI接口获取表情包描述"""
try:
prompt = "这是一个表情包请用简短的中文描述这个表情包传达的情感和含义。最多20个字。"
- description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
+ # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
+ description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[表情包:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[表情包]"
- def get_image_description(self, image_base64: str) -> str:
+ async def get_image_description(self, image_base64: str) -> str:
"""调用AI接口获取普通图片描述"""
try:
prompt = "请用中文描述这张图片的内容。如果有文字请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。"
- description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
+ # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64)
+ description, _ = await self._llm.generate_response_for_image(prompt, image_base64)
return f"[图片:{description}]"
except Exception as e:
print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}")
return "[图片]"
- def translate_forward(self) -> str:
+ async def translate_forward(self) -> str:
"""处理转发消息"""
try:
if 'content' not in self.params:
return '[转发消息]'
# 解析content内容需要先反转义
content = self.unescape(self.params['content'])
# print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}")
@@ -221,17 +224,17 @@ class CQCode:
except ValueError as e: except ValueError as e:
print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
# 处理每条消息 # 处理每条消息
formatted_messages = [] formatted_messages = []
for msg in messages: for msg in messages:
sender = msg.get('sender', {}) sender = msg.get('sender', {})
nickname = sender.get('card') or sender.get('nickname', '未知用户') nickname = sender.get('card') or sender.get('nickname', '未知用户')
# 获取消息内容并使用Message类处理 # 获取消息内容并使用Message类处理
raw_message = msg.get('raw_message', '') raw_message = msg.get('raw_message', '')
message_array = msg.get('message', []) message_array = msg.get('message', [])
if message_array and isinstance(message_array, list): if message_array and isinstance(message_array, list):
# 检查是否包含嵌套的转发消息 # 检查是否包含嵌套的转发消息
for message_part in message_array: for message_part in message_array:
@@ -249,6 +252,7 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize()
content = message_obj.processed_plain_text content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content = '[空消息]'
@@ -263,23 +267,24 @@ class CQCode:
plain_text=raw_message, plain_text=raw_message,
group_id=msg.get('group_id', 0) group_id=msg.get('group_id', 0)
) )
await message_obj.initialize()
content = message_obj.processed_plain_text content = message_obj.processed_plain_text
else: else:
content = '[空消息]' content = '[空消息]'
formatted_msg = f"{nickname}: {content}" formatted_msg = f"{nickname}: {content}"
formatted_messages.append(formatted_msg) formatted_messages.append(formatted_msg)
# 合并所有消息 # 合并所有消息
combined_messages = '\n'.join(formatted_messages) combined_messages = '\n'.join(formatted_messages)
print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}")
return f"[转发消息:\n{combined_messages}]" return f"[转发消息:\n{combined_messages}]"
except Exception as e: except Exception as e:
print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}")
return '[转发消息]' return '[转发消息]'
def translate_reply(self) -> str: async def translate_reply(self) -> str:
"""处理回复类型的CQ码""" """处理回复类型的CQ码"""
# 创建Message对象 # 创建Message对象
@@ -287,7 +292,7 @@ class CQCode:
if self.reply_message == None: if self.reply_message == None:
# print(f"\033[1;31m[错误]\033[0m 回复消息为空") # print(f"\033[1;31m[错误]\033[0m 回复消息为空")
return '[回复某人消息]' return '[回复某人消息]'
if self.reply_message.sender.user_id: if self.reply_message.sender.user_id:
message_obj = Message( message_obj = Message(
user_id=self.reply_message.sender.user_id, user_id=self.reply_message.sender.user_id,
@@ -295,6 +300,7 @@ class CQCode:
raw_message=str(self.reply_message.message), raw_message=str(self.reply_message.message),
group_id=self.group_id group_id=self.group_id
) )
await message_obj.initialize()
if message_obj.user_id == global_config.BOT_QQ: if message_obj.user_id == global_config.BOT_QQ:
return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]"
else: else:
@@ -308,9 +314,9 @@ class CQCode:
def unescape(text: str) -> str: def unescape(text: str) -> str:
"""反转义CQ码中的特殊字符""" """反转义CQ码中的特殊字符"""
return text.replace('&#44;', ',') \ return text.replace('&#44;', ',') \
.replace('&#91;', '[') \ .replace('&#91;', '[') \
.replace('&#93;', ']') \ .replace('&#93;', ']') \
.replace('&amp;', '&') .replace('&amp;', '&')
@staticmethod @staticmethod
def create_emoji_cq(file_path: str) -> str: def create_emoji_cq(file_path: str) -> str:
@@ -325,15 +331,16 @@ class CQCode:
abs_path = os.path.abspath(file_path) abs_path = os.path.abspath(file_path)
# 转义特殊字符 # 转义特殊字符
escaped_path = abs_path.replace('&', '&amp;') \ escaped_path = abs_path.replace('&', '&amp;') \
.replace('[', '&#91;') \ .replace('[', '&#91;') \
.replace(']', '&#93;') \ .replace(']', '&#93;') \
.replace(',', '&#44;') .replace(',', '&#44;')
# 生成CQ码设置sub_type=1表示这是表情包 # 生成CQ码设置sub_type=1表示这是表情包
return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" return f"[CQ:image,file=file:///{escaped_path},sub_type=1]"
class CQCode_tool: class CQCode_tool:
@staticmethod @staticmethod
- def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
+ async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode:
""" """
将CQ码字典转换为CQCode对象 将CQ码字典转换为CQCode对象
@@ -352,7 +359,7 @@ class CQCode_tool:
params['text'] = cq_code.get('data', {}).get('text', '') params['text'] = cq_code.get('data', {}).get('text', '')
else: else:
params = cq_code.get('data', {}) params = cq_code.get('data', {})
instance = CQCode( instance = CQCode(
type=cq_type, type=cq_type,
params=params, params=params,
@@ -360,11 +367,11 @@ class CQCode_tool:
user_id=0, user_id=0,
reply_message=reply reply_message=reply
) )
# 进行翻译处理 # 进行翻译处理
- instance.translate()
+ await instance.translate()
return instance return instance
@staticmethod @staticmethod
def create_reply_cq(message_id: int) -> str: def create_reply_cq(message_id: int) -> str:
""" """
@@ -375,6 +382,6 @@ class CQCode_tool:
回复CQ码字符串 回复CQ码字符串
""" """
return f"[CQ:reply,id={message_id}]" return f"[CQ:reply,id={message_id}]"
cq_code_tool = CQCode_tool() cq_code_tool = CQCode_tool()
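Because `translate()` now awaits the vision model, `cq_from_dict_to_class()` became a coroutine as well; a minimal usage sketch, assuming it runs inside an existing event loop (e.g. a NoneBot handler) and that the dict follows the OneBot `type`/`data` layout:

```python
# Hypothetical helper: turn one CQ-code dict into its readable text.
async def describe_segment(cq_dict: dict) -> str:
    cq = await cq_code_tool.cq_from_dict_to_class(cq_dict)  # runs translate() internally
    return cq.translated_plain_text

# e.g. awaiting describe_segment({"type": "face", "data": {"id": "21"}}) yields "[表情]"
```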

View File

@@ -20,6 +20,7 @@ import traceback
from nonebot import get_driver from nonebot import get_driver
from ..chat.config import global_config from ..chat.config import global_config
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
from ..chat.utils_image import image_path_to_base64
from ..chat.utils import get_embedding from ..chat.utils import get_embedding
driver = get_driver() driver = get_driver()
@@ -98,7 +99,7 @@ class EmojiManager:
# 获取文本的embedding # 获取文本的embedding
text_for_search= await self._get_kimoji_for_text(text) text_for_search= await self._get_kimoji_for_text(text)
- text_embedding = get_embedding(text_for_search)
+ text_embedding = await get_embedding(text_for_search)
if not text_embedding: if not text_embedding:
logger.error("无法获取文本的embedding") logger.error("无法获取文本的embedding")
return None return None
@@ -160,27 +161,6 @@ class EmojiManager:
logger.error(f"获取表情包失败: {str(e)}") logger.error(f"获取表情包失败: {str(e)}")
return None return None
async def _get_emoji_tag(self, image_base64: str) -> str:
"""获取表情包的标签"""
try:
prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签不要输出其他任何内容只输出情感标签就好'
content, _ = await self.llm.generate_response_for_image(prompt, image_base64)
tag_result = content.strip().lower()
valid_tags = ["happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"]
for tag_match in valid_tags:
if tag_match in tag_result or tag_match == tag_result:
return tag_match
print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过")
except Exception as e:
print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}")
return "neutral"
print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral")
return "neutral" # 默认标签
async def _get_emoji_discription(self, image_base64: str) -> str: async def _get_emoji_discription(self, image_base64: str) -> str:
"""获取表情包的标签""" """获取表情包的标签"""
try: try:
@@ -208,7 +188,7 @@ class EmojiManager:
async def _get_kimoji_for_text(self, text:str): async def _get_kimoji_for_text(self, text:str):
try: try:
- prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
+ prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。'
content, _ = await self.lm.generate_response_async(prompt) content, _ = await self.lm.generate_response_async(prompt)
logger.info(f"输出描述: {content}") logger.info(f"输出描述: {content}")
@@ -217,76 +197,7 @@ class EmojiManager:
except Exception as e: except Exception as e:
logger.error(f"获取标签失败: {str(e)}") logger.error(f"获取标签失败: {str(e)}")
return None return None
async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]:
"""压缩图片并返回base64编码
Args:
image_path: 图片文件路径
target_size: 目标文件大小字节默认0.8MB
Returns:
Optional[str]: 成功返回base64编码的图片数据失败返回None
"""
try:
file_size = os.path.getsize(image_path)
if file_size <= target_size:
# 如果文件已经小于目标大小直接读取并返回base64
with open(image_path, 'rb') as f:
return base64.b64encode(f.read()).decode('utf-8')
# 打开图片
with Image.open(image_path) as img:
# 获取原始尺寸
original_width, original_height = img.size
# 计算缩放比例
scale = min(1.0, (target_size / file_size) ** 0.5)
# 计算新的尺寸
new_width = int(original_width * scale)
new_height = int(original_height * scale)
# 创建内存缓冲区
output_buffer = io.BytesIO()
# 如果是GIF处理所有帧
if getattr(img, "is_animated", False):
frames = []
for frame_idx in range(img.n_frames):
img.seek(frame_idx)
new_frame = img.copy()
new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
frames.append(new_frame)
# 保存到缓冲区
frames[0].save(
output_buffer,
format='GIF',
save_all=True,
append_images=frames[1:],
optimize=True,
duration=img.info.get('duration', 100),
loop=img.info.get('loop', 0)
)
else:
# 处理静态图片
resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
# 保存到缓冲区,保持原始格式
if img.format == 'PNG' and img.mode in ('RGBA', 'LA'):
resized_img.save(output_buffer, format='PNG', optimize=True)
else:
resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True)
# 获取压缩后的数据并转换为base64
compressed_data = output_buffer.getvalue()
logger.success(f"压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})")
return base64.b64encode(compressed_data).decode('utf-8')
except Exception as e:
logger.error(f"压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}")
return None
async def scan_new_emojis(self): async def scan_new_emojis(self):
"""扫描新的表情包""" """扫描新的表情包"""
try: try:
@@ -305,22 +216,22 @@ class EmojiManager:
continue continue
# 压缩图片并获取base64编码 # 压缩图片并获取base64编码
- image_base64 = await self._compress_image(image_path)
+ image_base64 = image_path_to_base64(image_path)
if image_base64 is None:
os.remove(image_path)
continue
# 获取表情包的描述
discription = await self._get_emoji_discription(image_base64)
- check = await self._check_emoji(image_base64)
- if '' not in check:
- os.remove(image_path)
- logger.info(f"描述: {discription}")
- logger.info(f"其不满足过滤规则,被剔除 {check}")
- continue
- logger.info(f"check通过 {check}")
- tag = await self._get_emoji_tag(image_base64)
- embedding = get_embedding(discription)
+ if global_config.EMOJI_CHECK:
+ check = await self._check_emoji(image_base64)
+ if '' not in check:
+ os.remove(image_path)
+ logger.info(f"描述: {discription}")
+ logger.info(f"其不满足过滤规则,被剔除 {check}")
+ continue
+ logger.info(f"check通过 {check}")
+ embedding = await get_embedding(discription)
if discription is not None: if discription is not None:
# 准备数据库记录 # 准备数据库记录
emoji_record = { emoji_record = {
@@ -328,7 +239,6 @@ class EmojiManager:
'path': image_path, 'path': image_path,
'embedding':embedding, 'embedding':embedding,
'discription': discription, 'discription': discription,
'tag':tag,
'timestamp': int(time.time()) 'timestamp': int(time.time())
} }

View File

@@ -27,58 +27,60 @@ class Message:
"""消息数据类""" """消息数据类"""
message_id: int = None message_id: int = None
time: float = None time: float = None
group_id: int = None group_id: int = None
group_name: str = None # 群名称 group_name: str = None # 群名称
user_id: int = None user_id: int = None
user_nickname: str = None # 用户昵称 user_nickname: str = None # 用户昵称
user_cardname: str=None # 用户群昵称 user_cardname: str = None # 用户群昵称
raw_message: str = None # 原始消息包含未解析的cq码 raw_message: str = None # 原始消息包含未解析的cq码
plain_text: str = None # 纯文本 plain_text: str = None # 纯文本
reply_message: Dict = None # 存储 回复的 源消息
# 延迟初始化字段
_initialized: bool = False
message_segments: List[Dict] = None # 存储解析后的消息片段 message_segments: List[Dict] = None # 存储解析后的消息片段
processed_plain_text: str = None # 用于存储处理后的plain_text processed_plain_text: str = None # 用于存储处理后的plain_text
detailed_plain_text: str = None # 用于存储详细可读文本 detailed_plain_text: str = None # 用于存储详细可读文本
reply_message: Dict = None # 存储 回复的 源消息 # 状态标志
is_emoji: bool = False
is_emoji: bool = False # 是否是表情包 has_emoji: bool = False
has_emoji: bool = False # 是否包含表情包 translate_cq: bool = True
translate_cq: bool = True # 是否翻译cq码 async def initialize(self):
"""显式异步初始化方法(必须调用)"""
def __post_init__(self): if self._initialized:
if self.time is None: return
self.time = int(time.time())
# 异步获取补充信息
if not self.group_name: self.group_name = self.group_name or get_groupname(self.group_id)
self.group_name = get_groupname(self.group_id) self.user_nickname = self.user_nickname or get_user_nickname(self.user_id)
self.user_cardname = self.user_cardname or get_user_cardname(self.user_id)
if not self.user_nickname:
self.user_nickname = get_user_nickname(self.user_id) # 消息解析
if self.raw_message:
if not self.user_cardname: self.message_segments = await self.parse_message_segments(self.raw_message)
self.user_cardname=get_user_cardname(self.user_id) self.processed_plain_text = ' '.join(
seg.translated_plain_text
if not self.processed_plain_text: for seg in self.message_segments
if self.raw_message: )
self.message_segments = self.parse_message_segments(str(self.raw_message))
self.processed_plain_text = ' '.join( # 构建详细文本
seg.translated_plain_text
for seg in self.message_segments
)
#将详细翻译为详细可读文本
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time))
try: name = (
name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})"
except: if self.user_cardname
name = self.user_nickname or f"用户{self.user_id}" else f"{self.user_nickname or f'用户{self.user_id}'}"
content = self.processed_plain_text )
self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n"
self._initialized = True
def parse_message_segments(self, message: str) -> List[CQCode]: async def parse_message_segments(self, message: str) -> List[CQCode]:
""" """
将消息解析为片段列表包括纯文本和CQ码 将消息解析为片段列表包括纯文本和CQ码
返回的列表中每个元素都是字典,包含: 返回的列表中每个元素都是字典,包含:
@@ -136,7 +138,7 @@ class Message:
#翻译作为字典的CQ码 #翻译作为字典的CQ码
for _code_item in cq_code_dict_list: for _code_item in cq_code_dict_list:
- message_obj = cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
+ message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message)
trans_list.append(message_obj) trans_list.append(message_obj)
return trans_list return trans_list
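Construction of `Message` is now two steps: build the dataclass, then await `initialize()` before reading `processed_plain_text` or `detailed_plain_text`. A minimal sketch, with illustrative field values and assuming it runs inside an async context:

```python
# Hypothetical two-step construction mirroring the calls added throughout this commit.
msg = Message(user_id=123456, group_id=654321, raw_message="[CQ:face,id=21] 你好")
await msg.initialize()           # fills nicknames, parses CQ codes, builds detailed text
print(msg.processed_plain_text)  # e.g. "[表情] 你好"
```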

View File

@@ -2,7 +2,7 @@ import time
import random import random
from ..schedule.schedule_generator import bot_schedule from ..schedule.schedule_generator import bot_schedule
import os import os
- from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text,find_similar_topics
+ from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text
from ...common.database import Database from ...common.database import Database
from .config import global_config from .config import global_config
from .topic_identifier import topic_identifier from .topic_identifier import topic_identifier
@@ -60,7 +60,7 @@ class PromptBuilder:
prompt_info = ''
promt_info_prompt = ''
- prompt_info = self.get_prompt_info(message_txt,threshold=0.5)
+ prompt_info = await self.get_prompt_info(message_txt,threshold=0.5)
if prompt_info:
prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n'''
@@ -215,10 +215,10 @@ class PromptBuilder:
return prompt_for_initiative
- def get_prompt_info(self,message:str,threshold:float):
+ async def get_prompt_info(self,message:str,threshold:float):
related_info = ''
print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}")
- embedding = get_embedding(message)
+ embedding = await get_embedding(message)
related_info += self.get_info_from_db(embedding,threshold=threshold)
return related_info

View File

@@ -33,16 +33,18 @@ def combine_messages(messages: List[Message]) -> str:
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time)) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message.time))
name = message.user_nickname or f"用户{message.user_id}" name = message.user_nickname or f"用户{message.user_id}"
content = message.processed_plain_text or message.plain_text content = message.processed_plain_text or message.plain_text
result += f"[{time_str}] {name}: {content}\n" result += f"[{time_str}] {name}: {content}\n"
return result return result
def db_message_to_str(message_dict: Dict) -> str:
print(f"message_dict: {message_dict}") print(f"message_dict: {message_dict}")
time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"]))
try: try:
name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", "")) name = "[(%s)%s]%s" % (
message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", ""))
except: except:
name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}"
content = message_dict.get("processed_plain_text", "") content = message_dict.get("processed_plain_text", "")
@@ -59,6 +61,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool:
return True return True
return False return False
def is_mentioned_bot_in_txt(message: str) -> bool: def is_mentioned_bot_in_txt(message: str) -> bool:
"""检查消息是否提到了机器人""" """检查消息是否提到了机器人"""
keywords = [global_config.BOT_NICKNAME] keywords = [global_config.BOT_NICKNAME]
@@ -67,10 +70,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool:
return True return True
return False return False
- def get_embedding(text):
+ async def get_embedding(text):
"""获取文本的embedding向量"""
llm = LLM_request(model=global_config.embedding)
- return llm.get_embedding_sync(text)
+ # return llm.get_embedding_sync(text)
+ return await llm.get_embedding(text)
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
@@ -78,52 +84,55 @@ def cosine_similarity(v1, v2):
norm2 = np.linalg.norm(v2) norm2 = np.linalg.norm(v2)
return dot_product / (norm1 * norm2) return dot_product / (norm1 * norm2)
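A quick numeric check of the formula above, with illustrative vectors:

```python
import numpy as np

# cos(45°) between (1, 0) and (1, 1): dot = 1, norms = 1 and sqrt(2)
v1, v2 = np.array([1.0, 0.0]), np.array([1.0, 1.0])
print(round(float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))), 4))  # 0.7071
```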
def calculate_information_content(text): def calculate_information_content(text):
"""计算文本的信息量(熵)""" """计算文本的信息量(熵)"""
char_count = Counter(text) char_count = Counter(text)
total_chars = len(text) total_chars = len(text)
entropy = 0 entropy = 0
for count in char_count.values(): for count in char_count.values():
probability = count / total_chars probability = count / total_chars
entropy -= probability * math.log2(probability) entropy -= probability * math.log2(probability)
return entropy return entropy
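A worked example of the character entropy computed above: "aab" has p(a)=2/3 and p(b)=1/3, so H = -(2/3·log2(2/3) + 1/3·log2(1/3)) ≈ 0.918 bits per character. The same calculation inlined:

```python
import math
from collections import Counter

text = "aab"
counts = Counter(text)
h = -sum((c / len(text)) * math.log2(c / len(text)) for c in counts.values())
print(round(h, 3))  # 0.918
```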
def get_cloest_chat_from_db(db, length: int, timestamp: str): def get_cloest_chat_from_db(db, length: int, timestamp: str):
"""从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数""" """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数"""
chat_text = '' chat_text = ''
closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)])
if closest_record and closest_record.get('memorized', 0) < 4: if closest_record and closest_record.get('memorized', 0) < 4:
closest_time = closest_record['time'] closest_time = closest_record['time']
group_id = closest_record['group_id'] # 获取groupid group_id = closest_record['group_id'] # 获取groupid
# 获取该时间戳之后的length条消息且groupid相同 # 获取该时间戳之后的length条消息且groupid相同
chat_records = list(db.db.messages.find( chat_records = list(db.db.messages.find(
{"time": {"$gt": closest_time}, "group_id": group_id} {"time": {"$gt": closest_time}, "group_id": group_id}
).sort('time', 1).limit(length)) ).sort('time', 1).limit(length))
# 更新每条消息的memorized属性 # 更新每条消息的memorized属性
for record in chat_records: for record in chat_records:
# 检查当前记录的memorized值 # 检查当前记录的memorized值
current_memorized = record.get('memorized', 0) current_memorized = record.get('memorized', 0)
if current_memorized > 3: if current_memorized > 3:
# print(f"消息已读取3次跳过") # print(f"消息已读取3次跳过")
return '' return ''
# 更新memorized值 # 更新memorized值
db.db.messages.update_one( db.db.messages.update_one(
{"_id": record["_id"]}, {"_id": record["_id"]},
{"$set": {"memorized": current_memorized + 1}} {"$set": {"memorized": current_memorized + 1}}
) )
chat_text += record["detailed_plain_text"] chat_text += record["detailed_plain_text"]
return chat_text return chat_text
# print(f"消息已读取3次跳过") # print(f"消息已读取3次跳过")
return '' return ''
- def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
+ async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
"""从数据库获取群组最近的消息记录 """从数据库获取群组最近的消息记录
Args: Args:
@@ -135,7 +144,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
list: Message对象列表按时间正序排列 list: Message对象列表按时间正序排列
""" """
# 从数据库获取最近消息 # 从数据库获取最近消息
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"group_id": group_id},
# { # {
@@ -150,7 +159,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
if not recent_messages: if not recent_messages:
return [] return []
# 转换为 Message对象列表 # 转换为 Message对象列表
from .message import Message from .message import Message
message_objects = [] message_objects = []
@@ -165,16 +174,18 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list:
processed_plain_text=msg_data.get("processed_text", ""), processed_plain_text=msg_data.get("processed_text", ""),
group_id=group_id group_id=group_id
) )
await msg.initialize()
message_objects.append(msg) message_objects.append(msg)
except KeyError: except KeyError:
print("[WARNING] 数据库中存在无效的消息") print("[WARNING] 数据库中存在无效的消息")
continue continue
# 按时间正序排列 # 按时间正序排列
message_objects.reverse() message_objects.reverse()
return message_objects return message_objects
def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False):
recent_messages = list(db.db.messages.find( recent_messages = list(db.db.messages.find(
{"group_id": group_id}, {"group_id": group_id},
{ {
@@ -188,16 +199,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
if not recent_messages: if not recent_messages:
return [] return []
message_detailed_plain_text = '' message_detailed_plain_text = ''
message_detailed_plain_text_list = [] message_detailed_plain_text_list = []
# 反转消息列表,使最新的消息在最后 # 反转消息列表,使最新的消息在最后
recent_messages.reverse() recent_messages.reverse()
if combine: if combine:
for msg_db_data in recent_messages: for msg_db_data in recent_messages:
message_detailed_plain_text += str(msg_db_data["detailed_plain_text"])
return message_detailed_plain_text return message_detailed_plain_text
else: else:
for msg_db_data in recent_messages: for msg_db_data in recent_messages:
@@ -205,7 +216,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb
return message_detailed_plain_text_list return message_detailed_plain_text_list
def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
"""将文本分割成句子,但保持书名号中的内容完整 """将文本分割成句子,但保持书名号中的内容完整
Args: Args:
@@ -225,30 +235,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
split_strength = 0.7 split_strength = 0.7
else: else:
split_strength = 0.9 split_strength = 0.9
# 先移除换行符
# print(f"split_strength: {split_strength}") # print(f"split_strength: {split_strength}")
# print(f"处理前的文本: {text}") # print(f"处理前的文本: {text}")
# 统一将英文逗号转换为中文逗号 # 统一将英文逗号转换为中文逗号
text = text.replace(',', '') text = text.replace(',', '')
text = text.replace('\n', ' ') text = text.replace('\n', ' ')
# print(f"处理前的文本: {text}") # print(f"处理前的文本: {text}")
text_no_1 = '' text_no_1 = ''
for letter in text: for letter in text:
# print(f"当前字符: {letter}") # print(f"当前字符: {letter}")
if letter in ['!', '', '?', '']:
# print(f"当前字符: {letter}, 随机数: {random.random()}") # print(f"当前字符: {letter}, 随机数: {random.random()}")
if random.random() < split_strength: if random.random() < split_strength:
letter = '' letter = ''
if letter in ['', '']:
# print(f"当前字符: {letter}, 随机数: {random.random()}") # print(f"当前字符: {letter}, 随机数: {random.random()}")
if random.random() < 1 - split_strength: if random.random() < 1 - split_strength:
letter = '' letter = ''
text_no_1 += letter text_no_1 += letter
# 对每个逗号单独判断是否分割 # 对每个逗号单独判断是否分割
sentences = [text_no_1] sentences = [text_no_1]
new_sentences = [] new_sentences = []
@@ -277,16 +287,17 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]:
sentences_done = [] sentences_done = []
for sentence in sentences: for sentence in sentences:
sentence = sentence.rstrip(',') sentence = sentence.rstrip(',')
if random.random() < split_strength * 0.5:
sentence = sentence.replace('', '').replace(',', '') sentence = sentence.replace('', '').replace(',', '')
elif random.random() < split_strength: elif random.random() < split_strength:
sentence = sentence.replace('', ' ').replace(',', ' ') sentence = sentence.replace('', ' ').replace(',', ' ')
sentences_done.append(sentence) sentences_done.append(sentence)
print(f"处理后的句子: {sentences_done}") print(f"处理后的句子: {sentences_done}")
return sentences_done return sentences_done
def random_remove_punctuation(text: str) -> str: def random_remove_punctuation(text: str) -> str:
"""随机处理标点符号,模拟人类打字习惯 """随机处理标点符号,模拟人类打字习惯
@@ -298,7 +309,7 @@ def random_remove_punctuation(text: str) -> str:
""" """
result = '' result = ''
text_len = len(text) text_len = len(text)
for i, char in enumerate(text): for i, char in enumerate(text):
if char == '' and i == text_len - 1: # 结尾的句号 if char == '' and i == text_len - 1: # 结尾的句号
if random.random() > 0.4: # 80%概率删除结尾句号 if random.random() > 0.4: # 80%概率删除结尾句号
@@ -314,11 +325,12 @@ def random_remove_punctuation(text: str) -> str:
return result return result
def process_llm_response(text: str) -> List[str]: def process_llm_response(text: str) -> List[str]:
# processed_response = process_text_with_typos(content) # processed_response = process_text_with_typos(content)
if len(text) > 300: if len(text) > 300:
print(f"回复过长 ({len(text)} 字符),返回默认回复") print(f"回复过长 ({len(text)} 字符),返回默认回复")
return ['懒得说'] return ['懒得说']
# 处理长消息 # 处理长消息
typo_generator = ChineseTypoGenerator( typo_generator = ChineseTypoGenerator(
error_rate=0.03, error_rate=0.03,
@@ -332,9 +344,10 @@ def process_llm_response(text: str) -> List[str]:
if len(sentences) > 4: if len(sentences) > 4:
print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复")
return [f'{global_config.BOT_NICKNAME}不知道哦'] return [f'{global_config.BOT_NICKNAME}不知道哦']
return sentences return sentences
def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float: def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float:
""" """
计算输入字符串所需的时间,中文和英文字符有不同的输入时间 计算输入字符串所需的时间,中文和英文字符有不同的输入时间
@@ -347,32 +360,10 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_
if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符 if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符
total_time += chinese_time total_time += chinese_time
else: # 其他字符(如英文) else: # 其他字符(如英文)
total_time += english_time total_time += english_time
return total_time return total_time
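A worked example for the defaults above (0.2 s per Chinese character, 0.1 s otherwise): "你好ok" takes 2×0.2 + 2×0.1 ≈ 0.6 s. The same per-character rule inlined:

```python
chinese_time, english_time = 0.2, 0.1
total = sum(chinese_time if '\u4e00' <= ch <= '\u9fff' else english_time for ch in "你好ok")
print(round(total, 1))  # 0.6
```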
def find_similar_topics(message_txt: str, all_memory_topic: list, top_k: int = 5) -> list:
"""使用重排序API找出与输入文本最相似的话题
Args:
message_txt: 输入文本
all_memory_topic: 所有记忆主题列表
top_k: 返回最相似的话题数量
Returns:
list: 最相似话题列表及其相似度分数
"""
if not all_memory_topic:
return []
try:
llm = LLM_request(model=global_config.rerank)
return llm.rerank_sync(message_txt, all_memory_topic, top_k)
except Exception as e:
print(f"重排序API调用出错: {str(e)}")
return []
def cosine_similarity(v1, v2): def cosine_similarity(v1, v2):
"""计算余弦相似度""" """计算余弦相似度"""
dot_product = np.dot(v1, v2) dot_product = np.dot(v1, v2)
@@ -382,6 +373,7 @@ def cosine_similarity(v1, v2):
return 0 return 0
return dot_product / (norm1 * norm2) return dot_product / (norm1 * norm2)
def text_to_vector(text): def text_to_vector(text):
"""将文本转换为词频向量""" """将文本转换为词频向量"""
# 分词 # 分词
@@ -390,11 +382,12 @@ def text_to_vector(text):
word_freq = Counter(words) word_freq = Counter(words)
return word_freq return word_freq
def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
"""使用简单的余弦相似度计算文本相似度""" """使用简单的余弦相似度计算文本相似度"""
# 将输入文本转换为词频向量 # 将输入文本转换为词频向量
text_vector = text_to_vector(text) text_vector = text_to_vector(text)
# 计算每个主题的相似度 # 计算每个主题的相似度
similarities = [] similarities = []
for topic in topics: for topic in topics:
@@ -407,6 +400,6 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list:
# 计算相似度 # 计算相似度
similarity = cosine_similarity(v1, v2) similarity = cosine_similarity(v1, v2)
similarities.append((topic, similarity)) similarities.append((topic, similarity))
# 按相似度降序排序并返回前k个 # 按相似度降序排序并返回前k个
return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k]

View File

@@ -4,6 +4,7 @@ import hashlib
import time import time
import os import os
from ...common.database import Database from ...common.database import Database
from ..chat.config import global_config
import zlib # 用于 CRC32 import zlib # 用于 CRC32
import base64 import base64
from nonebot import get_driver from nonebot import get_driver
@@ -143,6 +144,8 @@ def storage_emoji(image_data: bytes) -> bytes:
Returns: Returns:
bytes: 原始图片数据 bytes: 原始图片数据
""" """
if not global_config.EMOJI_SAVE:
return image_data
try: try:
# 使用 CRC32 计算哈希值 # 使用 CRC32 计算哈希值
hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x')
@@ -227,7 +230,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
image_data = base64.b64decode(base64_data) image_data = base64.b64decode(base64_data)
# 如果已经小于目标大小,直接返回原图 # 如果已经小于目标大小,直接返回原图
- if len(image_data) <= target_size:
+ if len(image_data) <= 2*1024*1024:
return base64_data return base64_data
# 将字节数据转换为图片对象 # 将字节数据转换为图片对象
@@ -252,7 +255,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
for frame_idx in range(img.n_frames): for frame_idx in range(img.n_frames):
img.seek(frame_idx) img.seek(frame_idx)
new_frame = img.copy() new_frame = img.copy()
- new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS)
+ new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折
frames.append(new_frame) frames.append(new_frame)
# 保存到缓冲区 # 保存到缓冲区
@@ -286,4 +289,19 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10
logger.error(f"压缩图片失败: {str(e)}") logger.error(f"压缩图片失败: {str(e)}")
import traceback import traceback
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
return base64_data return base64_data
def image_path_to_base64(image_path: str) -> str:
"""将图片路径转换为base64编码
Args:
image_path: 图片文件路径
Returns:
str: base64编码的图片数据
"""
try:
with open(image_path, 'rb') as f:
image_data = f.read()
return base64.b64encode(image_data).decode('utf-8')
except Exception as e:
logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}")
return None
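A minimal round-trip sketch for the new helper; the path is hypothetical, and `base64` is already imported in this module:

```python
b64 = image_path_to_base64("data/emoji/example.png")
if b64 is not None:
    raw = base64.b64decode(b64)  # back to the original file bytes
    print(len(raw), "bytes decoded")
```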

View File

@@ -11,7 +11,7 @@ from ..chat.config import global_config
from ...common.database import Database # 使用正确的导入语法 from ...common.database import Database # 使用正确的导入语法
from ..models.utils_model import LLM_request from ..models.utils_model import LLM_request
import math import math
from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,find_similar_topics,text_to_vector,cosine_similarity from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,text_to_vector,cosine_similarity
@@ -673,7 +673,7 @@ class Hippocampus:
if first_layer: if first_layer:
# 如果记忆条数超过限制,随机选择指定数量的记忆 # 如果记忆条数超过限制,随机选择指定数量的记忆
if len(first_layer) > max_memory_num/2: if len(first_layer) > max_memory_num/2:
- first_layer = random.sample(first_layer, max_memory_num)
+ first_layer = random.sample(first_layer, max_memory_num//2)
# 为每条记忆添加来源主题和相似度信息 # 为每条记忆添加来源主题和相似度信息
for memory in first_layer: for memory in first_layer:
relevant_memories.append({ relevant_memories.append({
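One practical effect of sampling `max_memory_num//2` instead of `max_memory_num`: `random.sample(population, k)` raises `ValueError` when k exceeds the population size, which the old call could hit whenever the layer held fewer than `max_memory_num` entries. A tiny illustration with assumed numbers:

```python
import random

first_layer = list(range(8))        # 8 > max_memory_num/2 when max_memory_num = 10
# random.sample(first_layer, 10)    # would raise ValueError: sample larger than population
print(len(random.sample(first_layer, 10 // 2)))  # 5
```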

View File

@@ -25,603 +25,221 @@ class LLM_request:
self.model_name = model["name"] self.model_name = model["name"]
self.params = kwargs self.params = kwargs
async def generate_response(self, prompt: str) -> Tuple[str, str]: async def _execute_request(
"""根据输入的提示生成模型的异步响应""" self,
headers = { endpoint: str,
"Authorization": f"Bearer {self.api_key}", prompt: str = None,
"Content-Type": "application/json" image_base64: str = None,
payload: dict = None,
retry_policy: dict = None,
response_handler: callable = None,
):
"""统一请求执行入口
Args:
endpoint: API端点路径 (如 "chat/completions")
prompt: prompt文本
image_base64: 图片的base64编码
payload: 请求体数据
is_async: 是否异步
retry_policy: 自定义重试策略
(示例: {"max_retries":3, "base_wait":15, "retry_codes":[429,500]})
response_handler: 自定义响应处理器
"""
# 合并重试策略
default_retry = {
"max_retries": 3, "base_wait": 15,
"retry_codes": [429, 413, 500, 503],
"abort_codes": [400, 401, 402, 403]}
policy = {**default_retry, **(retry_policy or {})}
# 常见Error Code Mapping
error_code_mapping = {
400: "参数不正确",
401: "API key 错误,认证失败",
402: "账号余额不足",
403: "需要实名,或余额不足",
404: "Not Found",
429: "请求过于频繁,请稍后再试",
500: "服务器内部故障",
503: "服务器负载过高"
} }
api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}"
logger.info(f"发送请求到URL: {api_url}")
logger.info(f"使用模型: {self.model_name}")
# 构建请求体 # 构建请求体
data = { if image_base64:
"model": self.model_name, payload = await self._build_payload(prompt, image_base64)
"messages": [{"role": "user", "content": prompt}], elif payload is None:
**self.params payload = await self._build_payload(prompt)
}
# 发送请求到完整的chat/completions端点 for retry in range(policy["max_retries"]):
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"发送请求到URL: {api_url}/{self.model_name}") # 记录请求的URL
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try: try:
# 使用上下文管理器处理会话
headers = await self._build_headers()
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response: async with session.post(api_url, headers=headers, json=payload) as response:
if response.status == 429: # 处理需要重试的状态码
wait_time = base_wait_time * (2 ** retry) # 指数退避 if response.status in policy["retry_codes"]:
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") wait_time = policy["base_wait"] * (2 ** retry)
logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试")
if response.status == 413:
logger.warning("请求体过大,尝试压缩...")
image_base64 = compress_base64_image_by_scale(image_base64)
payload = await self._build_payload(prompt, image_base64)
elif response.status in [500, 503]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
else:
logger.warning(f"请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
continue continue
elif response.status in policy["abort_codes"]:
logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}")
raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}")
if response.status in [500, 503]: response.raise_for_status()
logger.error(f"服务器错误: {response.status}")
raise RuntimeError("服务器负载过高模型恢复失败QAQ")
response.raise_for_status() # 检查其他响应状态
result = await response.json() result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"] # 使用自定义处理器或默认处理
content = message.get("content", "") return response_handler(result) if response_handler else self._default_response_handler(result)
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e: except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会 if retry < policy["max_retries"] - 1:
wait_time = base_wait_time * (2 ** retry) wait_time = policy["base_wait"] * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) logger.error(f"请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time) await asyncio.sleep(wait_time)
else: else:
logger.critical(f"请求失败: {str(e)}", exc_info=True) logger.critical(f"请求失败: {str(e)}")
logger.critical(f"请求头: {headers} 请求体: {data}") logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}")
raise RuntimeError(f"API请求失败: {str(e)}") raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败") logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败") raise RuntimeError("达到最大重试次数API请求仍然失败")
- async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
- """根据输入的提示和图片生成模型的异步响应"""
- headers = {
- "Authorization": f"Bearer {self.api_key}",
- "Content-Type": "application/json"
- }
- # 构建请求体
- def build_request_data(img_base64: str):
- return {
- "model": self.model_name,
- "messages": [
- {
- "role": "user",
- "content": [
- {
- "type": "text",
- "text": prompt
- },
- {
- "type": "image_url",
- "image_url": {
- "url": f"data:image/jpeg;base64,{img_base64}"
- }
- }
- ]
- }
- ],
- **self.params
- }
+ async def _build_payload(self, prompt: str, image_base64: str = None) -> dict:
+ """构建请求体"""
+ if image_base64:
+ return {
+ "model": self.model_name,
+ "messages": [
+ {
+ "role": "user",
+ "content": [
+ {"type": "text", "text": prompt},
+ {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
+ ]
+ }
+ ],
+ "max_tokens": global_config.max_response_length,
+ **self.params
+ }
+ else:
+ return {
+ "model": self.model_name,
+ "messages": [{"role": "user", "content": prompt}],
+ "max_tokens": global_config.max_response_length,
+ **self.params
+ }
def _default_response_handler(self, result: dict) -> Tuple:
"""默认响应解析"""
if "choices" in result and result["choices"]:
message = result["choices"][0]["message"]
content = message.get("content", "")
content, reasoning = self._extract_reasoning(content)
reasoning_content = message.get("model_extra", {}).get("reasoning_content", "")
if not reasoning_content:
reasoning_content = reasoning
+ return content, reasoning_content
+ return "没有返回结果", ""
- # 发送请求到完整的chat/completions端点
- api_url = f"{self.base_url.rstrip('/')}/chat/completions"
- logger.info(f"发送请求到URL: {api_url}/{self.model_name}") # 记录请求的URL
- max_retries = 3
- base_wait_time = 15
- current_image_base64 = image_base64
- current_image_base64 = compress_base64_image_by_scale(current_image_base64)
+ def _extract_reasoning(self, content: str) -> tuple[str, str]:
+ """CoT思维链提取"""
match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
if match:
reasoning = match.group(1).strip()
else:
reasoning = ""
return content, reasoning
- for retry in range(max_retries):
- try:
- data = build_request_data(current_image_base64)
+ async def _build_headers(self) -> dict:
+ """构建请求头"""
+ return {
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
elif response.status == 413:
logger.warning("图片太大(413),尝试压缩...")
current_image_base64 = compress_base64_image_by_scale(current_image_base64)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}", exc_info=True)
logger.critical(f"请求头: {headers} 请求体: {data}")
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}", "Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json" "Content-Type": "application/json"
} }
async def generate_response(self, prompt: str) -> Tuple[str, str]:
"""根据输入的提示生成模型的异步响应"""
content, reasoning_content = await self._execute_request(
endpoint="/chat/completions",
prompt=prompt
)
return content, reasoning_content
async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""根据输入的提示和图片生成模型的异步响应"""
content, reasoning_content = await self._execute_request(
endpoint="/chat/completions",
prompt=prompt,
image_base64=image_base64
)
return content, reasoning_content
async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]:
"""异步方式根据输入的提示生成模型的响应"""
# 构建请求体 # 构建请求体
data = { data = {
"model": self.model_name, "model": self.model_name,
"messages": [{"role": "user", "content": prompt}], "messages": [{"role": "user", "content": prompt}],
"temperature": 0.5, "temperature": 0.5,
"max_tokens": global_config.max_response_length,
**self.params **self.params
} }
- # 发送请求到完整的 chat/completions 端点
- api_url = f"{self.base_url.rstrip('/')}/chat/completions"
- logger.info(f"Request URL: {api_url}") # 记录请求的 URL
+ content, reasoning_content = await self._execute_request(
+ endpoint="/chat/completions",
+ payload=data,
prompt=prompt
)
return content, reasoning_content
- max_retries = 3
+ async def get_embedding(self, text: str) -> Union[list, None]:
base_wait_time = 15
async with aiohttp.ClientSession() as session:
for retry in range(max_retries):
try:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry) # 指数退避
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = await response.json()
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}")
await asyncio.sleep(wait_time)
else:
logger.error(f"请求失败: {str(e)}")
logger.critical(f"请求头: {headers} 请求体: {data}")
return f"请求失败: {str(e)}", ""
logger.error("达到最大重试次数,请求仍然失败")
return "达到最大重试次数,请求仍然失败", ""
def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]:
"""同步方法:根据输入的提示和图片生成模型的响应"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
image_base64=compress_base64_image_by_scale(image_base64)
# 构建请求体
data = {
"model": self.model_name,
"messages": [
{
"role": "user",
"content": [
{
"type": "text",
"text": prompt
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}"
}
}
]
}
],
**self.params
}
# 发送请求到完整的chat/completions端点
api_url = f"{self.base_url.rstrip('/')}/chat/completions"
logger.info(f"发送请求到URL: {api_url}/{self.model_name}") # 记录请求的URL
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status() # 检查其他响应状态
result = response.json()
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"]
content = message.get("content", "")
think_match = None
reasoning_content = message.get("reasoning_content", "")
if not reasoning_content:
think_match = re.search(r'(?:<think>)?(.*?)</think>', content, re.DOTALL)
if think_match:
reasoning_content = think_match.group(1).strip()
content = re.sub(r'(?:<think>)?.*?</think>', '', content, flags=re.DOTALL, count=1).strip()
return content, reasoning_content
return "没有返回结果", ""
except Exception as e:
if retry < max_retries - 1: # 如果还有重试机会
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
time.sleep(wait_time)
else:
logger.critical(f"请求失败: {str(e)}", exc_info=True)
logger.critical(f"请求头: {headers} 请求体: {data}")
raise RuntimeError(f"API请求失败: {str(e)}")
logger.error("达到最大重试次数,请求仍然失败")
raise RuntimeError("达到最大重试次数API请求仍然失败")
def get_embedding_sync(self, text: str) -> Union[list, None]:
"""同步方法获取文本的embedding向量
Args:
text: 需要获取embedding的文本
Returns:
list: embedding向量如果失败则返回None
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model_name,
"input": text,
"encoding_format": "float"
}
api_url = f"{self.base_url.rstrip('/')}/embeddings"
logger.info(f"发送请求到URL: {api_url}/{self.model_name}") # 记录请求的URL
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
response.raise_for_status()
result = response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
time.sleep(wait_time)
else:
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
logger.critical(f"请求头: {headers} 请求体: {data}")
return None
logger.error("达到最大重试次数embedding请求仍然失败")
return None
    async def get_embedding(self, text: str) -> Union[list, None]:
        """异步方法获取文本的embedding向量
        Args:
            text: 需要获取embedding的文本
        Returns:
            list: embedding向量如果失败则返回None
        """
        def embedding_handler(result):
            """处理响应"""
            if "data" in result and len(result["data"]) > 0:
                return result["data"][0].get("embedding", None)
            return None

        embedding = await self._execute_request(
            endpoint="/embeddings",
            prompt=text,
            payload={
                "model": self.model_name,
                "input": text,
                "encoding_format": "float"
            },
            retry_policy={
                "max_retries": 2,
                "base_wait": 6
            },
            response_handler=embedding_handler
        )
        return embedding

    # --- older inline implementation (superseded by the _execute_request version above) ---
    async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]:
        """异步方法获取文本的embedding向量
        Args:
            text: 需要获取embedding的文本
            model: 使用的模型名称,默认为"BAAI/bge-m3"
        Returns:
            list: embedding向量如果失败则返回None
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }
        data = {
            "model": model,
            "input": text,
            "encoding_format": "float"
        }
api_url = f"{self.base_url.rstrip('/')}/embeddings"
logger.info(f"发送请求到URL: {api_url}/{self.model_name}") # 记录请求的URL
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
response.raise_for_status()
result = await response.json()
if 'data' in result and len(result['data']) > 0:
return result['data'][0]['embedding']
return None
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
logger.critical(f"embedding请求失败: {str(e)}", exc_info=True)
logger.critical(f"请求头: {headers} 请求体: {data}")
return None
logger.error("达到最大重试次数embedding请求仍然失败")
return None
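# --- Editor's illustrative sketch (not part of the original file) ----------------
# Example of consuming get_embedding(): cosine similarity between two texts.
# `llm` is assumed to be an instance of this class configured with an embedding model.
async def _similarity_sketch(llm, a: str, b: str) -> float:
    va = await llm.get_embedding(a)
    vb = await llm.get_embedding(b)
    if not va or not vb:
        return 0.0
    dot = sum(x * y for x, y in zip(va, vb))
    norm = (sum(x * x for x in va) ** 0.5) * (sum(x * x for x in vb) ** 0.5)
    return dot / norm if norm else 0.0
# ----------------------------------------------------------------------------------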
def rerank_sync(self, query: str, documents: list, top_k: int = 5) -> list:
"""同步方法使用重排序API对文档进行排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_k: 返回前k个结果
Returns:
list: [(document, score), ...] 格式的结果列表
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model_name,
"query": query,
"documents": documents,
"top_n": top_k,
"return_documents": True,
}
api_url = f"{self.base_url.rstrip('/')}/rerank"
logger.info(f"发送请求到URL: {api_url}")
max_retries = 2
base_wait_time = 6
for retry in range(max_retries):
try:
response = requests.post(api_url, headers=headers, json=data, timeout=30)
if response.status_code == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
time.sleep(wait_time)
continue
if response.status_code in [500, 503]:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"服务器错误({response.status_code}),等待{wait_time}秒后重试...")
if retry < max_retries - 1:
time.sleep(wait_time)
continue
else:
# 如果是最后一次重试尝试使用chat/completions作为备选方案
return self._fallback_rerank_with_chat(query, documents, top_k)
response.raise_for_status()
result = response.json()
if 'results' in result:
return [(item["document"], item["score"]) for item in result["results"]]
return []
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[rerank_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
time.sleep(wait_time)
else:
logger.critical(f"重排序请求失败: {str(e)}", exc_info=True)
logger.error("达到最大重试次数,重排序请求仍然失败")
return []
async def rerank(self, query: str, documents: list, top_k: int = 5) -> list:
"""异步方法使用重排序API对文档进行排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_k: 返回前k个结果
Returns:
list: [(document, score), ...] 格式的结果列表
"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model_name,
"query": query,
"documents": documents,
"top_n": top_k,
"return_documents": True,
}
api_url = f"{self.base_url.rstrip('/')}/v1/rerank"
logger.info(f"发送请求到URL: {api_url}")
max_retries = 3
base_wait_time = 15
for retry in range(max_retries):
try:
async with aiohttp.ClientSession() as session:
async with session.post(api_url, headers=headers, json=data) as response:
if response.status == 429:
wait_time = base_wait_time * (2 ** retry)
logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...")
await asyncio.sleep(wait_time)
continue
if response.status in [500, 503]:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"服务器错误({response.status}),等待{wait_time}秒后重试...")
if retry < max_retries - 1:
await asyncio.sleep(wait_time)
continue
else:
# 如果是最后一次重试尝试使用chat/completions作为备选方案
return await self._fallback_rerank_with_chat_async(query, documents, top_k)
response.raise_for_status()
result = await response.json()
if 'results' in result:
return [(item["document"], item["score"]) for item in result["results"]]
return []
except Exception as e:
if retry < max_retries - 1:
wait_time = base_wait_time * (2 ** retry)
logger.error(f"[rerank]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True)
await asyncio.sleep(wait_time)
else:
logger.critical(f"重排序请求失败: {str(e)}", exc_info=True)
# 作为最后的备选方案尝试使用chat/completions
return await self._fallback_rerank_with_chat_async(query, documents, top_k)
logger.error("达到最大重试次数,重排序请求仍然失败")
return []
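# --- Editor's illustrative sketch (not part of the original file) ----------------
# Example of consuming rerank(): keep only the top-scoring documents for a query.
# `llm` is assumed to be an instance of this class pointed at a rerank-capable API.
async def _rerank_sketch(llm, query: str, candidates: list) -> list:
    ranked = await llm.rerank(query, candidates, top_k=3)  # [(document, score), ...]
    return [doc for doc, _score in ranked]
# ----------------------------------------------------------------------------------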
async def _fallback_rerank_with_chat_async(self, query: str, documents: list, top_k: int = 5) -> list:
"""当rerank API失败时的备选方案使用chat/completions异步实现重排序
Args:
query: 查询文本
documents: 待排序的文档列表
top_k: 返回前k个结果
Returns:
list: [(document, score), ...] 格式的结果列表
"""
try:
logger.info("使用chat/completions作为重排序的备选方案")
# 构建提示词
prompt = f"""请对以下文档列表进行重排序,按照与查询的相关性从高到低排序。
查询: {query}
文档列表:
{documents}
请以JSON格式返回排序结果格式为
[{{"document": "文档内容", "score": 相关性分数}}, ...]
只返回JSON不要其他任何文字。"""
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
data = {
"model": self.model_name, "model": self.model_name,
"messages": [{"role": "user", "content": prompt}], "input": text,
**self.params "encoding_format": "float"
} },
retry_policy={
api_url = f"{self.base_url.rstrip('/')}/v1/chat/completions" "max_retries": 2,
"base_wait": 6
async with aiohttp.ClientSession() as session: },
async with session.post(api_url, headers=headers, json=data) as response: response_handler=embedding_handler
response.raise_for_status() )
result = await response.json() return embedding
if "choices" in result and len(result["choices"]) > 0:
message = result["choices"][0]["message"]
content = message.get("content", "")
try:
import json
parsed_content = json.loads(content)
if isinstance(parsed_content, list):
return [(item["document"], item["score"]) for item in parsed_content]
                        except Exception:
pass
return []
except Exception as e:
logger.error(f"备选方案也失败了: {str(e)}")
return []
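# --- Editor's note (illustrative, not part of the original file) ------------------
# The fallback prompt above asks the chat model to answer with JSON shaped like
#   [{"document": "文档内容", "score": 0.93}, ...]
# json.loads() then yields the same (document, score) tuples that the dedicated
# /rerank endpoint would have returned.
# -----------------------------------------------------------------------------------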
View File

@@ -20,6 +20,9 @@ ban_words = [
[emoji]
check_interval = 120 # 检查表情包的时间间隔
register_interval = 10 # 注册表情包的时间间隔
auto_save = true # 自动偷表情包
enable_check = false # 是否启用表情包过滤
check_prompt = "符合公序良俗" # 表情包过滤要求
[cq_code]
enable_pic_translate = false
@@ -28,6 +31,7 @@ enable_pic_translate = false
model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率
model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率
model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率
max_response_length = 1024 # 麦麦回答的最大token数
[memory]
build_memory_interval = 300 # 记忆构建间隔 单位秒