From 11807fda38f0d341fa19838149a959dee4310606 Mon Sep 17 00:00:00 2001 From: KawaiiYusora Date: Thu, 6 Mar 2025 23:50:14 +0800 Subject: [PATCH 01/11] =?UTF-8?q?refactor(models)=EF=BC=9A=E7=BB=9F?= =?UTF-8?q?=E4=B8=80=E8=AF=B7=E6=B1=82=E5=A4=84=E7=90=86=E5=B9=B6=E4=BC=98?= =?UTF-8?q?=E5=8C=96=E5=93=8D=E5=BA=94=E5=A4=84=E7=90=86=20(refactor/unifi?= =?UTF-8?q?ed=5Frequest)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对 `utils_model.py` 中的请求处理逻辑进行重构,创建统一的请求执行方法 `_execute_request`。该方法集中处理请求构建、重试逻辑和响应处理,替代了 `generate_response`、`generate_response_for_image` 和 `generate_response_async` 中的冗余代码。 关键变更: - 引入 `_execute_request` 作为 API 请求的单一入口 - 新增支持自定义重试策略和响应处理器 - 通过 `_build_payload` 简化图像和文本载荷构建 - 改进错误处理和日志记录 - 移除已弃用的同步方法 - 加入了`max_response_length`以兼容koboldcpp硬编码的默认值500 此次重构在保持现有功能的同时提高了代码可维护性,减少了重复代码 --- config/bot_config_template.toml | 1 + src/plugins/chat/config.py | 3 + src/plugins/chat/cq_code.py | 24 +- src/plugins/chat/prompt_builder.py | 8 +- src/plugins/chat/utils.py | 122 +++-- src/plugins/memory_system/memory.py | 2 +- src/plugins/models/utils_model.py | 706 +++++++--------------------- 7 files changed, 243 insertions(+), 623 deletions(-) diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index 28ffb0ce3..f3582de12 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -28,6 +28,7 @@ enable_pic_translate = false model_r1_probability = 0.8 # 麦麦回答时选择R1模型的概率 model_v3_probability = 0.1 # 麦麦回答时选择V3模型的概率 model_r1_distill_probability = 0.1 # 麦麦回答时选择R1蒸馏模型的概率 +max_response_length = 1024 # 麦麦回答的最大token数 [memory] build_memory_interval = 300 # 记忆构建间隔 单位秒 diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index d5ee364ce..ba1ca0b71 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -32,6 +32,8 @@ class BotConfig: EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) ban_words = set() + + max_response_length: int = 1024 # 最大回复长度 # 模型配置 llm_reasoning: Dict[str, str] = field(default_factory=lambda: {}) @@ -113,6 +115,7 @@ class BotConfig: config.MODEL_R1_DISTILL_PROBABILITY = response_config.get("model_r1_distill_probability", config.MODEL_R1_DISTILL_PROBABILITY) config.API_USING = response_config.get("api_using", config.API_USING) config.API_PAID = response_config.get("api_paid", config.API_PAID) + config.max_response_length = response_config.get("max_response_length", config.max_response_length) # 加载模型配置 if "model" in toml_dict: diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 4d70736cd..df93c6fa2 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -64,15 +64,15 @@ class CQCode: """初始化LLM实例""" self._llm = LLM_request(model=global_config.vlm, temperature=0.4, max_tokens=300) - def translate(self): + async def translate(self): """根据CQ码类型进行相应的翻译处理""" if self.type == 'text': self.translated_plain_text = self.params.get('text', '') elif self.type == 'image': if self.params.get('sub_type') == '0': - self.translated_plain_text = self.translate_image() + self.translated_plain_text = await self.translate_image() else: - self.translated_plain_text = self.translate_emoji() + self.translated_plain_text = await self.translate_emoji() elif self.type == 'at': user_nickname = get_user_nickname(self.params.get('qq', '')) if user_nickname: @@ -158,7 +158,7 @@ class CQCode: return None - def translate_emoji(self) -> str: + async def translate_emoji(self) -> str: """处理表情包类型的CQ码""" if 'url' not in self.params: return '[表情包]' @@ 
-167,12 +167,12 @@ class CQCode: # 将 base64 字符串转换为字节类型 image_bytes = base64.b64decode(base64_str) storage_emoji(image_bytes) - return self.get_emoji_description(base64_str) + return await self.get_emoji_description(base64_str) else: return '[表情包]' - def translate_image(self) -> str: + async def translate_image(self) -> str: """处理图片类型的CQ码,区分普通图片和表情包""" #没有url,直接返回默认文本 if 'url' not in self.params: @@ -181,25 +181,27 @@ class CQCode: if base64_str: image_bytes = base64.b64decode(base64_str) storage_image(image_bytes) - return self.get_image_description(base64_str) + return await self.get_image_description(base64_str) else: return '[图片]' - def get_emoji_description(self, image_base64: str) -> str: + async def get_emoji_description(self, image_base64: str) -> str: """调用AI接口获取表情包描述""" try: prompt = "这是一个表情包,请用简短的中文描述这个表情包传达的情感和含义。最多20个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[表情包:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[表情包]" - def get_image_description(self, image_base64: str) -> str: + async def get_image_description(self, image_base64: str) -> str: """调用AI接口获取普通图片描述""" try: prompt = "请用中文描述这张图片的内容。如果有文字,请把文字都描述出来。并尝试猜测这个图片的含义。最多200个字。" - description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + # description, _ = self._llm.generate_response_for_image_sync(prompt, image_base64) + description, _ = await self._llm.generate_response_for_image(prompt, image_base64) return f"[图片:{description}]" except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") diff --git a/src/plugins/chat/prompt_builder.py b/src/plugins/chat/prompt_builder.py index 1c510e251..1c1431577 100644 --- a/src/plugins/chat/prompt_builder.py +++ b/src/plugins/chat/prompt_builder.py @@ -2,7 +2,7 @@ import time import random from ..schedule.schedule_generator import bot_schedule import os -from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text,find_similar_topics +from .utils import get_embedding, combine_messages, get_recent_group_detailed_plain_text from ...common.database import Database from .config import global_config from .topic_identifier import topic_identifier @@ -60,7 +60,7 @@ class PromptBuilder: prompt_info = '' promt_info_prompt = '' - prompt_info = self.get_prompt_info(message_txt,threshold=0.5) + prompt_info = await self.get_prompt_info(message_txt,threshold=0.5) if prompt_info: prompt_info = f'''\n----------------------------------------------------\n你有以下这些[知识]:\n{prompt_info}\n请你记住上面的[知识],之后可能会用到\n----------------------------------------------------\n''' @@ -214,10 +214,10 @@ class PromptBuilder: return prompt_for_initiative - def get_prompt_info(self,message:str,threshold:float): + async def get_prompt_info(self,message:str,threshold:float): related_info = '' print(f"\033[1;34m[调试]\033[0m 获取知识库内容,元消息:{message[:30]}...,消息长度: {len(message)}") - embedding = get_embedding(message) + embedding = await get_embedding(message) related_info += self.get_info_from_db(embedding,threshold=threshold) return related_info diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 63daf6680..38aeefd21 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -32,16 +32,18 @@ def combine_messages(messages: List[Message]) -> str: time_str = 
time.strftime("%m-%d %H:%M:%S", time.localtime(message.time)) name = message.user_nickname or f"用户{message.user_id}" content = message.processed_plain_text or message.plain_text - + result += f"[{time_str}] {name}: {content}\n" - + return result -def db_message_to_str (message_dict: Dict) -> str: + +def db_message_to_str(message_dict: Dict) -> str: print(f"message_dict: {message_dict}") time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(message_dict["time"])) try: - name="[(%s)%s]%s" % (message_dict['user_id'],message_dict.get("user_nickname", ""),message_dict.get("user_cardname", "")) + name = "[(%s)%s]%s" % ( + message_dict['user_id'], message_dict.get("user_nickname", ""), message_dict.get("user_cardname", "")) except: name = message_dict.get("user_nickname", "") or f"用户{message_dict['user_id']}" content = message_dict.get("processed_plain_text", "") @@ -58,6 +60,7 @@ def is_mentioned_bot_in_message(message: Message) -> bool: return True return False + def is_mentioned_bot_in_txt(message: str) -> bool: """检查消息是否提到了机器人""" keywords = [global_config.BOT_NICKNAME] @@ -66,10 +69,13 @@ def is_mentioned_bot_in_txt(message: str) -> bool: return True return False -def get_embedding(text): + +async def get_embedding(text): """获取文本的embedding向量""" llm = LLM_request(model=global_config.embedding) - return llm.get_embedding_sync(text) + # return llm.get_embedding_sync(text) + return await llm.get_embedding(text) + def cosine_similarity(v1, v2): dot_product = np.dot(v1, v2) @@ -77,51 +83,54 @@ def cosine_similarity(v1, v2): norm2 = np.linalg.norm(v2) return dot_product / (norm1 * norm2) + def calculate_information_content(text): """计算文本的信息量(熵)""" char_count = Counter(text) total_chars = len(text) - + entropy = 0 for count in char_count.values(): probability = count / total_chars entropy -= probability * math.log2(probability) - + return entropy + def get_cloest_chat_from_db(db, length: int, timestamp: str): """从数据库中获取最接近指定时间戳的聊天记录,并记录读取次数""" chat_text = '' closest_record = db.db.messages.find_one({"time": {"$lte": timestamp}}, sort=[('time', -1)]) - - if closest_record and closest_record.get('memorized', 0) < 4: + + if closest_record and closest_record.get('memorized', 0) < 4: closest_time = closest_record['time'] group_id = closest_record['group_id'] # 获取groupid # 获取该时间戳之后的length条消息,且groupid相同 chat_records = list(db.db.messages.find( {"time": {"$gt": closest_time}, "group_id": group_id} ).sort('time', 1).limit(length)) - + # 更新每条消息的memorized属性 for record in chat_records: # 检查当前记录的memorized值 current_memorized = record.get('memorized', 0) - if current_memorized > 3: + if current_memorized > 3: # print(f"消息已读取3次,跳过") return '' - + # 更新memorized值 db.db.messages.update_one( {"_id": record["_id"]}, {"$set": {"memorized": current_memorized + 1}} ) - + chat_text += record["detailed_plain_text"] - + return chat_text # print(f"消息已读取3次,跳过") return '' + def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 @@ -134,7 +143,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: list: Message对象列表,按时间正序排列 """ - # 从数据库获取最近消息 + # 从数据库获取最近消息 recent_messages = list(db.db.messages.find( {"group_id": group_id}, # { @@ -149,7 +158,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: if not recent_messages: return [] - + # 转换为 Message对象列表 from .message import Message message_objects = [] @@ -168,12 +177,13 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: except KeyError: print("[WARNING] 
数据库中存在无效的消息") continue - + # 按时间正序排列 message_objects.reverse() return message_objects -def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,combine = False): + +def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12, combine=False): recent_messages = list(db.db.messages.find( {"group_id": group_id}, { @@ -187,16 +197,16 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb if not recent_messages: return [] - + message_detailed_plain_text = '' message_detailed_plain_text_list = [] - + # 反转消息列表,使最新的消息在最后 recent_messages.reverse() - + if combine: for msg_db_data in recent_messages: - message_detailed_plain_text+=str(msg_db_data["detailed_plain_text"]) + message_detailed_plain_text += str(msg_db_data["detailed_plain_text"]) return message_detailed_plain_text else: for msg_db_data in recent_messages: @@ -204,7 +214,6 @@ def get_recent_group_detailed_plain_text(db, group_id: int, limit: int = 12,comb return message_detailed_plain_text_list - def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: """将文本分割成句子,但保持书名号中的内容完整 Args: @@ -224,30 +233,30 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: split_strength = 0.7 else: split_strength = 0.9 - #先移除换行符 + # 先移除换行符 # print(f"split_strength: {split_strength}") - + # print(f"处理前的文本: {text}") - + # 统一将英文逗号转换为中文逗号 text = text.replace(',', ',') text = text.replace('\n', ' ') - + # print(f"处理前的文本: {text}") - + text_no_1 = '' for letter in text: # print(f"当前字符: {letter}") - if letter in ['!','!','?','?']: + if letter in ['!', '!', '?', '?']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < split_strength: letter = '' - if letter in ['。','…']: + if letter in ['。', '…']: # print(f"当前字符: {letter}, 随机数: {random.random()}") if random.random() < 1 - split_strength: letter = '' text_no_1 += letter - + # 对每个逗号单独判断是否分割 sentences = [text_no_1] new_sentences = [] @@ -276,15 +285,16 @@ def split_into_sentences_w_remove_punctuation(text: str) -> List[str]: sentences_done = [] for sentence in sentences: sentence = sentence.rstrip(',,') - if random.random() < split_strength*0.5: + if random.random() < split_strength * 0.5: sentence = sentence.replace(',', '').replace(',', '') elif random.random() < split_strength: sentence = sentence.replace(',', ' ').replace(',', ' ') sentences_done.append(sentence) - + print(f"处理后的句子: {sentences_done}") return sentences_done + # 常见的错别字映射 TYPO_DICT = { '的': '地得', @@ -355,6 +365,7 @@ TYPO_DICT = { '嘻': '嘻西希' } + def random_remove_punctuation(text: str) -> str: """随机处理标点符号,模拟人类打字习惯 @@ -366,7 +377,7 @@ def random_remove_punctuation(text: str) -> str: """ result = '' text_len = len(text) - + for i, char in enumerate(text): if char == '。' and i == text_len - 1: # 结尾的句号 if random.random() > 0.4: # 80%概率删除结尾句号 @@ -381,6 +392,7 @@ def random_remove_punctuation(text: str) -> str: result += char return result + def add_typos(text: str) -> str: TYPO_RATE = 0.02 # 控制错别字出现的概率(2%) result = "" @@ -393,20 +405,22 @@ def add_typos(text: str) -> str: result += char return result + def process_llm_response(text: str) -> List[str]: # processed_response = process_text_with_typos(content) if len(text) > 300: - print(f"回复过长 ({len(text)} 字符),返回默认回复") - return ['懒得说'] + print(f"回复过长 ({len(text)} 字符),返回默认回复") + return ['懒得说'] # 处理长消息 sentences = split_into_sentences_w_remove_punctuation(add_typos(text)) # 检查分割后的消息数量是否过多(超过3条) if len(sentences) > 4: print(f"分割后消息数量过多 ({len(sentences)} 条),返回默认回复") return 
[f'{global_config.BOT_NICKNAME}不知道哦'] - + return sentences + def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_time: float = 0.1) -> float: """ 计算输入字符串所需的时间,中文和英文字符有不同的输入时间 @@ -419,32 +433,10 @@ def calculate_typing_time(input_string: str, chinese_time: float = 0.2, english_ if '\u4e00' <= char <= '\u9fff': # 判断是否为中文字符 total_time += chinese_time else: # 其他字符(如英文) - total_time += english_time + total_time += english_time return total_time -def find_similar_topics(message_txt: str, all_memory_topic: list, top_k: int = 5) -> list: - """使用重排序API找出与输入文本最相似的话题 - - Args: - message_txt: 输入文本 - all_memory_topic: 所有记忆主题列表 - top_k: 返回最相似的话题数量 - - Returns: - list: 最相似话题列表及其相似度分数 - """ - - if not all_memory_topic: - return [] - - try: - llm = LLM_request(model=global_config.rerank) - return llm.rerank_sync(message_txt, all_memory_topic, top_k) - except Exception as e: - print(f"重排序API调用出错: {str(e)}") - return [] - def cosine_similarity(v1, v2): """计算余弦相似度""" dot_product = np.dot(v1, v2) @@ -454,6 +446,7 @@ def cosine_similarity(v1, v2): return 0 return dot_product / (norm1 * norm2) + def text_to_vector(text): """将文本转换为词频向量""" # 分词 @@ -462,11 +455,12 @@ def text_to_vector(text): word_freq = Counter(words) return word_freq + def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: """使用简单的余弦相似度计算文本相似度""" # 将输入文本转换为词频向量 text_vector = text_to_vector(text) - + # 计算每个主题的相似度 similarities = [] for topic in topics: @@ -479,6 +473,6 @@ def find_similar_topics_simple(text: str, topics: list, top_k: int = 5) -> list: # 计算相似度 similarity = cosine_similarity(v1, v2) similarities.append((topic, similarity)) - + # 按相似度降序排序并返回前k个 - return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] \ No newline at end of file + return sorted(similarities, key=lambda x: x[1], reverse=True)[:top_k] diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index cdb6e6e1b..43db3729d 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -11,7 +11,7 @@ from ..chat.config import global_config from ...common.database import Database # 使用正确的导入语法 from ..models.utils_model import LLM_request import math -from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,find_similar_topics,text_to_vector,cosine_similarity +from ..chat.utils import calculate_information_content, get_cloest_chat_from_db ,text_to_vector,cosine_similarity diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 2801a3553..3e4d7f1a2 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -25,354 +25,195 @@ class LLM_request: self.model_name = model["name"] self.params = kwargs - async def generate_response(self, prompt: str) -> Tuple[str, str]: - """根据输入的提示生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" + async def _execute_request( + self, + endpoint: str, + prompt: str = None, + image_base64: str = None, + payload: dict = None, + retry_policy: dict = None, + response_handler: callable = None, + ): + """统一请求执行入口 + Args: + endpoint: API端点路径 (如 "chat/completions") + prompt: prompt文本 + image_base64: 图片的base64编码 + payload: 请求体数据 + is_async: 是否异步 + retry_policy: 自定义重试策略 + (示例: {"max_retries":3, "base_wait":15, "retry_codes":[429,500]}) + response_handler: 自定义响应处理器 + """ + # 合并重试策略 + default_retry = { + "max_retries": 3, "base_wait": 15, + "retry_codes": [429, 413, 500, 503], + "abort_codes": 
[400, 401, 402, 403]} + policy = {**default_retry, **(retry_policy or {})} + + # 常见Error Code Mapping + error_code_mapping = { + 400: "参数不正确", + 401: "API key 错误,认证失败", + 402: "账号余额不足", + 403: "需要实名,或余额不足", + 404: "Not Found", + 429: "请求过于频繁,请稍后再试", + 500: "服务器内部故障", + 503: "服务器负载过高" } + api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" + logger.info(f"发送请求到URL: {api_url}{self.model_name}") + # 构建请求体 - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params - } + if image_base64: + payload = await self._build_payload(prompt, image_base64) + elif payload is None: + payload = await self._build_payload(prompt) - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL + session_method = aiohttp.ClientSession() - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): + for retry in range(policy["max_retries"]): try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue + # 使用上下文管理器处理会话 + headers = await self._build_headers() - if response.status in [500, 503]: - logger.error(f"服务器错误: {response.status}") - raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + async with session_method as session: + response = await session.post(api_url, headers=headers, json=payload) - response.raise_for_status() # 检查其他响应状态 + # 处理需要重试的状态码 + if response.status in policy["retry_codes"]: + wait_time = policy["base_wait"] * (2 ** retry) + logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") + if response.status == 413: + logger.warning("请求体过大,尝试压缩...") + image_base64 = compress_base64_image_by_scale(image_base64) + payload = await self._build_payload(prompt, image_base64) + elif response.status in [500, 503]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + else: + logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(?:)?(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() - return content, reasoning_content - return "没有返回结果", "" + await asyncio.sleep(wait_time) + continue + elif response.status in policy["abort_codes"]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") + + response.raise_for_status() + result = await response.json() + + # 使用自定义处理器或默认处理 + return response_handler(result) if response_handler else self._default_response_handler(result) except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) + if retry < policy["max_retries"] - 1: + wait_time = policy["base_wait"] * (2 ** retry) + logger.error(f"请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}") await asyncio.sleep(wait_time) else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - logger.critical(f"请求头: {headers} 请求体: {data}") + logger.critical(f"请求失败: {str(e)}") + logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}") raise RuntimeError(f"API请求失败: {str(e)}") logger.error("达到最大重试次数,请求仍然失败") raise RuntimeError("达到最大重试次数,API请求仍然失败") - async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """根据输入的提示和图片生成模型的异步响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - # 构建请求体 - def build_request_data(img_base64: str): + async def _build_payload(self, prompt: str, image_base64: str = None) -> dict: + """构建请求体""" + if image_base64: return { "model": self.model_name, "messages": [ { "role": "user", "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{img_base64}" - } - } + {"type": "text", "text": prompt}, + {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}} ] } ], + "max_tokens": global_config.max_response_length, + **self.params + } + else: + return { + "model": self.model_name, + "messages": [{"role": "user", "content": prompt}], + "max_tokens": global_config.max_response_length, **self.params } + def _default_response_handler(self, result: dict) -> Tuple: + """默认响应解析""" + if "choices" in result and result["choices"]: + message = result["choices"][0]["message"] + content = message.get("content", "") + content, reasoning = self._extract_reasoning(content) + reasoning_content = message.get("model_extra", {}).get("reasoning_content", "") + if not reasoning_content: + reasoning_content = reasoning - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL + return content, reasoning_content - max_retries = 3 - base_wait_time = 15 + return "没有返回结果", "" - current_image_base64 = image_base64 - current_image_base64 = compress_base64_image_by_scale(current_image_base64) + def _extract_reasoning(self, content: str) -> tuple[str, str]: + """CoT思维链提取""" + match = re.search(r'(?:)?(.*?)', content, re.DOTALL) + content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() + if match: + reasoning = match.group(1).strip() + else: + reasoning = "" + return content, reasoning - for retry in range(max_retries): - try: - data = build_request_data(current_image_base64) - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - elif response.status == 413: - logger.warning("图片太大(413),尝试压缩...") - current_image_base64 = compress_base64_image_by_scale(current_image_base64) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(?:)?(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() - return 
content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - logger.critical(f"请求头: {headers} 请求体: {data}") - raise RuntimeError(f"API请求失败: {str(e)}") - - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") - - async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: - """异步方式根据输入的提示生成模型的响应""" - headers = { + async def _build_headers(self) -> dict: + """构建请求头""" + return { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } + async def generate_response(self, prompt: str) -> Tuple[str, str]: + """根据输入的提示生成模型的异步响应""" + + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt + ) + return content, reasoning_content + + async def generate_response_for_image(self, prompt: str, image_base64: str) -> Tuple[str, str]: + """根据输入的提示和图片生成模型的异步响应""" + + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + prompt=prompt, + image_base64=image_base64 + ) + return content, reasoning_content + + async def generate_response_async(self, prompt: str) -> Union[str, Tuple[str, str]]: + """异步方式根据输入的提示生成模型的响应""" # 构建请求体 data = { "model": self.model_name, "messages": [{"role": "user", "content": prompt}], "temperature": 0.5, + "max_tokens": global_config.max_response_length, **self.params } - # 发送请求到完整的 chat/completions 端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"Request URL: {api_url}") # 记录请求的 URL - - max_retries = 3 - base_wait_time = 15 - - async with aiohttp.ClientSession() as session: - for retry in range(max_retries): - try: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) # 指数退避 - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = await response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(?:)?(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[回复]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}") - await asyncio.sleep(wait_time) - else: - logger.error(f"请求失败: {str(e)}") - logger.critical(f"请求头: {headers} 请求体: {data}") - return f"请求失败: {str(e)}", "" - - logger.error("达到最大重试次数,请求仍然失败") - return "达到最大重试次数,请求仍然失败", "" - - - - def generate_response_for_image_sync(self, prompt: str, image_base64: str) -> Tuple[str, str]: - """同步方法:根据输入的提示和图片生成模型的响应""" - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - image_base64=compress_base64_image_by_scale(image_base64) - - # 构建请求体 - data = { - "model": self.model_name, - "messages": [ - { - "role": "user", - "content": [ - { - "type": "text", - "text": prompt - }, - { - "type": "image_url", - "image_url": { - "url": f"data:image/jpeg;base64,{image_base64}" - } - } - ] - } - ], - **self.params - } - - # 发送请求到完整的chat/completions端点 - api_url = f"{self.base_url.rstrip('/')}/chat/completions" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL - - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - response.raise_for_status() # 检查其他响应状态 - - result = response.json() - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - think_match = None - reasoning_content = message.get("reasoning_content", "") - if not reasoning_content: - think_match = re.search(r'(?:)?(.*?)', content, re.DOTALL) - if think_match: - reasoning_content = think_match.group(1).strip() - content = re.sub(r'(?:)?.*?', '', content, flags=re.DOTALL, count=1).strip() - return content, reasoning_content - return "没有返回结果", "" - - except Exception as e: - if retry < max_retries - 1: # 如果还有重试机会 - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[image_sync回复]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"请求失败: {str(e)}", exc_info=True) - logger.critical(f"请求头: {headers} 请求体: {data}") - raise RuntimeError(f"API请求失败: {str(e)}") - - logger.error("达到最大重试次数,请求仍然失败") - raise RuntimeError("达到最大重试次数,API请求仍然失败") - - def get_embedding_sync(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: - """同步方法:获取文本的embedding向量 - - Args: - text: 需要获取embedding的文本 - model: 使用的模型名称,默认为"BAAI/bge-m3" - - Returns: - list: embedding向量,如果失败则返回None - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": model, - "input": text, - "encoding_format": "float" - } - - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL - - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - response.raise_for_status() - - result = response.json() - if 'data' in result and len(result['data']) > 0: - return result['data'][0]['embedding'] - return None - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding_sync]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - logger.critical(f"请求头: {headers} 请求体: {data}") - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None + content, reasoning_content = await self._execute_request( + endpoint="/chat/completions", + payload=data, + prompt=prompt + ) + return content, reasoning_content async def get_embedding(self, text: str, model: str = "BAAI/bge-m3") -> Union[list, None]: """异步方法:获取文本的embedding向量 @@ -384,245 +225,24 @@ class LLM_request: Returns: list: embedding向量,如果失败则返回None """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } + def embedding_handler(result): + """处理响应""" + if "data" in result and len(result["data"]) > 0: + return result["data"][0].get("embedding", None) + return None - data = { - "model": model, - "input": text, - "encoding_format": "float" - } - - api_url = f"{self.base_url.rstrip('/')}/embeddings" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") # 记录请求的URL - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - response.raise_for_status() - - result = await response.json() - if 'data' in result and len(result['data']) > 0: - return result['data'][0]['embedding'] - return None - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[embedding]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"embedding请求失败: {str(e)}", exc_info=True) - logger.critical(f"请求头: {headers} 请求体: {data}") - return None - - logger.error("达到最大重试次数,embedding请求仍然失败") - return None - - def rerank_sync(self, query: str, documents: list, top_k: int = 5) -> list: - """同步方法:使用重排序API对文档进行排序 - - Args: - query: 查询文本 - documents: 待排序的文档列表 - top_k: 返回前k个结果 - - Returns: - list: [(document, score), ...] 
格式的结果列表 - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": self.model_name, - "query": query, - "documents": documents, - "top_n": top_k, - "return_documents": True, - } - - api_url = f"{self.base_url.rstrip('/')}/rerank" - logger.info(f"发送请求到URL: {api_url}") - - max_retries = 2 - base_wait_time = 6 - - for retry in range(max_retries): - try: - response = requests.post(api_url, headers=headers, json=data, timeout=30) - - if response.status_code == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - time.sleep(wait_time) - continue - - if response.status_code in [500, 503]: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"服务器错误({response.status_code}),等待{wait_time}秒后重试...") - if retry < max_retries - 1: - time.sleep(wait_time) - continue - else: - # 如果是最后一次重试,尝试使用chat/completions作为备选方案 - return self._fallback_rerank_with_chat(query, documents, top_k) - - response.raise_for_status() - - result = response.json() - if 'results' in result: - return [(item["document"], item["score"]) for item in result["results"]] - return [] - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[rerank_sync]请求失败,等待{wait_time}秒后重试... 错误: {str(e)}", exc_info=True) - time.sleep(wait_time) - else: - logger.critical(f"重排序请求失败: {str(e)}", exc_info=True) - - logger.error("达到最大重试次数,重排序请求仍然失败") - return [] - - async def rerank(self, query: str, documents: list, top_k: int = 5) -> list: - """异步方法:使用重排序API对文档进行排序 - - Args: - query: 查询文本 - documents: 待排序的文档列表 - top_k: 返回前k个结果 - - Returns: - list: [(document, score), ...] 格式的结果列表 - """ - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": self.model_name, - "query": query, - "documents": documents, - "top_n": top_k, - "return_documents": True, - } - - api_url = f"{self.base_url.rstrip('/')}/v1/rerank" - logger.info(f"发送请求到URL: {api_url}") - - max_retries = 3 - base_wait_time = 15 - - for retry in range(max_retries): - try: - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - if response.status == 429: - wait_time = base_wait_time * (2 ** retry) - logger.warning(f"遇到请求限制(429),等待{wait_time}秒后重试...") - await asyncio.sleep(wait_time) - continue - - if response.status in [500, 503]: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"服务器错误({response.status}),等待{wait_time}秒后重试...") - if retry < max_retries - 1: - await asyncio.sleep(wait_time) - continue - else: - # 如果是最后一次重试,尝试使用chat/completions作为备选方案 - return await self._fallback_rerank_with_chat_async(query, documents, top_k) - - response.raise_for_status() - - result = await response.json() - if 'results' in result: - return [(item["document"], item["score"]) for item in result["results"]] - return [] - - except Exception as e: - if retry < max_retries - 1: - wait_time = base_wait_time * (2 ** retry) - logger.error(f"[rerank]请求失败,等待{wait_time}秒后重试... 
错误: {str(e)}", exc_info=True) - await asyncio.sleep(wait_time) - else: - logger.critical(f"重排序请求失败: {str(e)}", exc_info=True) - # 作为最后的备选方案,尝试使用chat/completions - return await self._fallback_rerank_with_chat_async(query, documents, top_k) - - logger.error("达到最大重试次数,重排序请求仍然失败") - return [] - - async def _fallback_rerank_with_chat_async(self, query: str, documents: list, top_k: int = 5) -> list: - """当rerank API失败时的备选方案,使用chat/completions异步实现重排序 - - Args: - query: 查询文本 - documents: 待排序的文档列表 - top_k: 返回前k个结果 - - Returns: - list: [(document, score), ...] 格式的结果列表 - """ - try: - logger.info("使用chat/completions作为重排序的备选方案") - - # 构建提示词 - prompt = f"""请对以下文档列表进行重排序,按照与查询的相关性从高到低排序。 -查询: {query} - -文档列表: -{documents} - -请以JSON格式返回排序结果,格式为: -[{{"document": "文档内容", "score": 相关性分数}}, ...] -只返回JSON,不要其他任何文字。""" - - headers = { - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json" - } - - data = { - "model": self.model_name, - "messages": [{"role": "user", "content": prompt}], - **self.params - } - - api_url = f"{self.base_url.rstrip('/')}/v1/chat/completions" - - async with aiohttp.ClientSession() as session: - async with session.post(api_url, headers=headers, json=data) as response: - response.raise_for_status() - result = await response.json() - - if "choices" in result and len(result["choices"]) > 0: - message = result["choices"][0]["message"] - content = message.get("content", "") - try: - import json - parsed_content = json.loads(content) - if isinstance(parsed_content, list): - return [(item["document"], item["score"]) for item in parsed_content] - except: - pass - return [] - except Exception as e: - logger.error(f"备选方案也失败了: {str(e)}") - return [] + embedding = await self._execute_request( + endpoint="/embeddings", + prompt=text, + payload={ + "model": model, + "input": text, + "encoding_format": "float" + }, + retry_policy={ + "max_retries": 2, + "base_wait": 6 + }, + response_handler=embedding_handler + ) + return embedding From 26f99664eebe706ed73c8b2f719c322763b91c5c Mon Sep 17 00:00:00 2001 From: KawaiiYusora Date: Fri, 7 Mar 2025 00:04:36 +0800 Subject: [PATCH 02/11] fix: cq_code async --- src/plugins/chat/cq_code.py | 4 ++-- src/plugins/chat/message.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index df93c6fa2..43d9d0862 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -335,7 +335,7 @@ class CQCode: class CQCode_tool: @staticmethod - def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: + async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: """ 将CQ码字典转换为CQCode对象 @@ -364,7 +364,7 @@ class CQCode_tool: ) # 进行翻译处理 - instance.translate() + await instance.translate() return instance @staticmethod diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 539e07989..02f56b975 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -49,7 +49,7 @@ class Message: translate_cq: bool = True # 是否翻译cq码 - def __post_init__(self): + async def __post_init__(self): if self.time is None: self.time = int(time.time()) @@ -64,7 +64,7 @@ class Message: if not self.processed_plain_text: if self.raw_message: - self.message_segments = self.parse_message_segments(str(self.raw_message)) + self.message_segments = await self.parse_message_segments(str(self.raw_message)) self.processed_plain_text = ' '.join( seg.translated_plain_text for seg in self.message_segments @@ 
-78,7 +78,7 @@ class Message: content = self.processed_plain_text self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" - def parse_message_segments(self, message: str) -> List[CQCode]: + async def parse_message_segments(self, message: str) -> List[CQCode]: """ 将消息解析为片段列表,包括纯文本和CQ码 返回的列表中每个元素都是字典,包含: @@ -136,7 +136,7 @@ class Message: #翻译作为字典的CQ码 for _code_item in cq_code_dict_list: - message_obj = cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) + message_obj = await cq_code_tool.cq_from_dict_to_class(_code_item,reply = self.reply_message) trans_list.append(message_obj) return trans_list From 0ebd24107750aa29fcd05dfa3c70ffe82e857633 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 01:06:36 +0800 Subject: [PATCH 03/11] =?UTF-8?q?fix:=20=E5=A2=9E=E5=8A=A0=E8=AE=BE?= =?UTF-8?q?=E7=BD=AE=E6=A8=A1=E6=9D=BF=EF=BC=8C=E4=BC=98=E5=8C=96emotion?= =?UTF-8?q?=EF=BC=8C=E4=BC=98=E5=8C=96=E5=8E=8B=E7=BC=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/bot_config_template.toml | 1 + src/plugins/chat/bot.py | 12 ++++++------ src/plugins/chat/emoji_manager.py | 25 +------------------------ src/plugins/chat/utils_image.py | 2 +- 4 files changed, 9 insertions(+), 31 deletions(-) diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index bc4ac18e3..afc2b5079 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -20,6 +20,7 @@ ban_words = [ [emoji] check_interval = 120 # 检查表情包的时间间隔 register_interval = 10 # 注册表情包的时间间隔 +check_prompt = "不要包含违反公序良俗的内容" # 表情包过滤要求 [cq_code] enable_pic_translate = false diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 89c15b388..add9bf978 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -163,12 +163,6 @@ class ChatBot: message_manager.add_message(message_set) bot_response_time = tinking_time_point - emotion = await self.gpt._get_emotion_tags(raw_content) - print(f"为 '{response}' 获取到的情感标签为:{emotion}") - valuedict={ - 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25 - } - await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) if random() < global_config.emoji_chance: emoji_path = await emoji_manager.get_emoji_for_text(response) @@ -196,6 +190,12 @@ class ChatBot: # reply_message_id=message.message_id ) message_manager.add_message(bot_message) + emotion = await self.gpt._get_emotion_tags(raw_content) + print(f"为 '{response}' 获取到的情感标签为:{emotion}") + valuedict={ + 'happy':0.5,'angry':-1,'sad':-0.5,'surprised':0.5,'disgusted':-1.5,'fearful':-0.25,'neutral':0.25 + } + await relationship_manager.update_relationship_value(message.user_id, relationship_value=valuedict[emotion[0]]) # willing_manager.change_reply_willing_after_sent(event.group_id) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index ede0d7135..cec454e4d 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -160,27 +160,6 @@ class EmojiManager: logger.error(f"获取表情包失败: {str(e)}") return None - async def _get_emoji_tag(self, image_base64: str) -> str: - """获取表情包的标签""" - try: - prompt = '这是一个表情包,请从"happy", "angry", "sad", "surprised", "disgusted", "fearful", "neutral"中选出1个情感标签。只输出标签,不要输出其他任何内容,只输出情感标签就好' - - content, _ = await self.llm.generate_response_for_image(prompt, image_base64) - tag_result = content.strip().lower() - - valid_tags = ["happy", "angry", 
"sad", "surprised", "disgusted", "fearful", "neutral"] - for tag_match in valid_tags: - if tag_match in tag_result or tag_match == tag_result: - return tag_match - print(f"\033[1;33m[警告]\033[0m 无效的标签: {tag_result}, 跳过") - - except Exception as e: - print(f"\033[1;31m[错误]\033[0m 获取标签失败: {str(e)}") - return "neutral" - - print(f"\033[1;32m[调试信息]\033[0m 使用默认标签: neutral") - return "neutral" # 默认标签 - async def _get_emoji_discription(self, image_base64: str) -> str: """获取表情包的标签""" try: @@ -208,7 +187,7 @@ class EmojiManager: async def _get_kimoji_for_text(self, text:str): try: - prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' + prompt = f'这是{global_config.BOT_NICKNAME}将要发送的消息内容:\n{text}\n若要为其配上表情包,请你输出这个表情包应该表达怎样的情感,应该给人什么样的感觉,不要太简洁也不要太长,注意不要输出任何对消息内容的分析内容,只输出\"一种什么样的感觉\"中间的形容词部分。' content, _ = await self.lm.generate_response_async(prompt) logger.info(f"输出描述: {content}") @@ -319,7 +298,6 @@ class EmojiManager: logger.info(f"其不满足过滤规则,被剔除 {check}") continue logger.info(f"check通过 {check}") - tag = await self._get_emoji_tag(image_base64) embedding = get_embedding(discription) if discription is not None: # 准备数据库记录 @@ -328,7 +306,6 @@ class EmojiManager: 'path': image_path, 'embedding':embedding, 'discription': discription, - 'tag':tag, 'timestamp': int(time.time()) } diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 922ab5228..9a7ef789a 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -252,7 +252,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 for frame_idx in range(img.n_frames): img.seek(frame_idx) new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) + new_frame = new_frame.resize((new_width//4, new_height//4), Image.Resampling.LANCZOS) # 动图折上折 frames.append(new_frame) # 保存到缓冲区 From e0e3ee417794a1294f51d476c7ad7f8514226b4d Mon Sep 17 00:00:00 2001 From: KawaiiYusora Date: Fri, 7 Mar 2025 01:31:03 +0800 Subject: [PATCH 04/11] fix: update CQCode and Message classes for async initialization and processing --- src/plugins/chat/bot.py | 1 + src/plugins/chat/cq_code.py | 71 +++++++++++++------------ src/plugins/chat/message.py | 88 ++++++++++++++++--------------- src/plugins/chat/utils.py | 3 +- src/plugins/models/utils_model.py | 52 +++++++++--------- 5 files changed, 112 insertions(+), 103 deletions(-) diff --git a/src/plugins/chat/bot.py b/src/plugins/chat/bot.py index 89c15b388..3a2d43f0e 100644 --- a/src/plugins/chat/bot.py +++ b/src/plugins/chat/bot.py @@ -58,6 +58,7 @@ class ChatBot: plain_text=event.get_plaintext(), reply_message=event.reply, ) + await message.initialize() # 过滤词 for word in global_config.ban_words: diff --git a/src/plugins/chat/cq_code.py b/src/plugins/chat/cq_code.py index 43d9d0862..0077f7295 100644 --- a/src/plugins/chat/cq_code.py +++ b/src/plugins/chat/cq_code.py @@ -10,11 +10,11 @@ from nonebot.adapters.onebot.v11 import Bot from .config import global_config import time import asyncio -from .utils_image import storage_image,storage_emoji +from .utils_image import storage_image, storage_emoji from .utils_user import get_user_nickname from ..models.utils_model import LLM_request -#解析各种CQ码 -#包含CQ码类 +# 解析各种CQ码 +# 包含CQ码类 import urllib3 from urllib3.util import create_urllib3_context from nonebot import get_driver @@ -27,6 +27,7 @@ ctx = create_urllib3_context() ctx.load_default_certs() 
ctx.set_ciphers("AES128-GCM-SHA256") + class TencentSSLAdapter(requests.adapters.HTTPAdapter): def __init__(self, ssl_context=None, **kwargs): self.ssl_context = ssl_context @@ -37,6 +38,7 @@ class TencentSSLAdapter(requests.adapters.HTTPAdapter): num_pools=connections, maxsize=maxsize, block=block, ssl_context=self.ssl_context) + @dataclass class CQCode: """ @@ -80,13 +82,13 @@ class CQCode: else: self.translated_plain_text = f"@某人" elif self.type == 'reply': - self.translated_plain_text = self.translate_reply() + self.translated_plain_text = await self.translate_reply() elif self.type == 'face': face_id = self.params.get('id', '') # self.translated_plain_text = f"[表情{face_id}]" self.translated_plain_text = f"[表情]" elif self.type == 'forward': - self.translated_plain_text = self.translate_forward() + self.translated_plain_text = await self.translate_forward() else: self.translated_plain_text = f"[{self.type}]" @@ -133,7 +135,7 @@ class CQCode: # 腾讯服务器特殊状态码处理 if response.status_code == 400 and 'multimedia.nt.qq.com.cn' in url: return None - + if response.status_code != 200: raise requests.exceptions.HTTPError(f"HTTP {response.status_code}") @@ -157,7 +159,7 @@ class CQCode: return None return None - + async def translate_emoji(self) -> str: """处理表情包类型的CQ码""" if 'url' not in self.params: @@ -170,11 +172,10 @@ class CQCode: return await self.get_emoji_description(base64_str) else: return '[表情包]' - - + async def translate_image(self) -> str: """处理图片类型的CQ码,区分普通图片和表情包""" - #没有url,直接返回默认文本 + # 没有url,直接返回默认文本 if 'url' not in self.params: return '[图片]' base64_str = self.get_img() @@ -206,13 +207,13 @@ class CQCode: except Exception as e: print(f"\033[1;31m[错误]\033[0m AI接口调用失败: {str(e)}") return "[图片]" - - def translate_forward(self) -> str: + + async def translate_forward(self) -> str: """处理转发消息""" try: if 'content' not in self.params: return '[转发消息]' - + # 解析content内容(需要先反转义) content = self.unescape(self.params['content']) # print(f"\033[1;34m[调试信息]\033[0m 转发消息内容: {content}") @@ -223,17 +224,17 @@ class CQCode: except ValueError as e: print(f"\033[1;31m[错误]\033[0m 解析转发消息内容失败: {str(e)}") return '[转发消息]' - + # 处理每条消息 formatted_messages = [] for msg in messages: sender = msg.get('sender', {}) nickname = sender.get('card') or sender.get('nickname', '未知用户') - + # 获取消息内容并使用Message类处理 raw_message = msg.get('raw_message', '') message_array = msg.get('message', []) - + if message_array and isinstance(message_array, list): # 检查是否包含嵌套的转发消息 for message_part in message_array: @@ -251,6 +252,7 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' @@ -265,23 +267,24 @@ class CQCode: plain_text=raw_message, group_id=msg.get('group_id', 0) ) + await message_obj.initialize() content = message_obj.processed_plain_text else: content = '[空消息]' - + formatted_msg = f"{nickname}: {content}" formatted_messages.append(formatted_msg) - + # 合并所有消息 combined_messages = '\n'.join(formatted_messages) print(f"\033[1;34m[调试信息]\033[0m 合并后的转发消息: {combined_messages}") return f"[转发消息:\n{combined_messages}]" - + except Exception as e: print(f"\033[1;31m[错误]\033[0m 处理转发消息失败: {str(e)}") return '[转发消息]' - def translate_reply(self) -> str: + async def translate_reply(self) -> str: """处理回复类型的CQ码""" # 创建Message对象 @@ -289,7 +292,7 @@ class CQCode: if self.reply_message == None: # print(f"\033[1;31m[错误]\033[0m 回复消息为空") return '[回复某人消息]' - + if self.reply_message.sender.user_id: message_obj = Message( 
user_id=self.reply_message.sender.user_id, @@ -297,6 +300,7 @@ class CQCode: raw_message=str(self.reply_message.message), group_id=self.group_id ) + await message_obj.initialize() if message_obj.user_id == global_config.BOT_QQ: return f"[回复 {global_config.BOT_NICKNAME} 的消息: {message_obj.processed_plain_text}]" else: @@ -310,9 +314,9 @@ class CQCode: def unescape(text: str) -> str: """反转义CQ码中的特殊字符""" return text.replace(',', ',') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace('&', '&') + .replace('[', '[') \ + .replace(']', ']') \ + .replace('&', '&') @staticmethod def create_emoji_cq(file_path: str) -> str: @@ -327,12 +331,13 @@ class CQCode: abs_path = os.path.abspath(file_path) # 转义特殊字符 escaped_path = abs_path.replace('&', '&') \ - .replace('[', '[') \ - .replace(']', ']') \ - .replace(',', ',') + .replace('[', '[') \ + .replace(']', ']') \ + .replace(',', ',') # 生成CQ码,设置sub_type=1表示这是表情包 return f"[CQ:image,file=file:///{escaped_path},sub_type=1]" - + + class CQCode_tool: @staticmethod async def cq_from_dict_to_class(cq_code: Dict, reply: Optional[Dict] = None) -> CQCode: @@ -354,7 +359,7 @@ class CQCode_tool: params['text'] = cq_code.get('data', {}).get('text', '') else: params = cq_code.get('data', {}) - + instance = CQCode( type=cq_type, params=params, @@ -362,11 +367,11 @@ class CQCode_tool: user_id=0, reply_message=reply ) - + # 进行翻译处理 await instance.translate() return instance - + @staticmethod def create_reply_cq(message_id: int) -> str: """ @@ -377,6 +382,6 @@ class CQCode_tool: 回复CQ码字符串 """ return f"[CQ:reply,id={message_id}]" - - + + cq_code_tool = CQCode_tool() diff --git a/src/plugins/chat/message.py b/src/plugins/chat/message.py index 02f56b975..e1d36568c 100644 --- a/src/plugins/chat/message.py +++ b/src/plugins/chat/message.py @@ -27,56 +27,58 @@ class Message: """消息数据类""" message_id: int = None time: float = None - + group_id: int = None - group_name: str = None # 群名称 - + group_name: str = None # 群名称 + user_id: int = None user_nickname: str = None # 用户昵称 - user_cardname: str=None # 用户群昵称 - - raw_message: str = None # 原始消息,包含未解析的cq码 - plain_text: str = None # 纯文本 - + user_cardname: str = None # 用户群昵称 + + raw_message: str = None # 原始消息,包含未解析的cq码 + plain_text: str = None # 纯文本 + + reply_message: Dict = None # 存储 回复的 源消息 + + # 延迟初始化字段 + _initialized: bool = False message_segments: List[Dict] = None # 存储解析后的消息片段 processed_plain_text: str = None # 用于存储处理后的plain_text detailed_plain_text: str = None # 用于存储详细可读文本 - - reply_message: Dict = None # 存储 回复的 源消息 - - is_emoji: bool = False # 是否是表情包 - has_emoji: bool = False # 是否包含表情包 - - translate_cq: bool = True # 是否翻译cq码 - - async def __post_init__(self): - if self.time is None: - self.time = int(time.time()) - - if not self.group_name: - self.group_name = get_groupname(self.group_id) - - if not self.user_nickname: - self.user_nickname = get_user_nickname(self.user_id) - - if not self.user_cardname: - self.user_cardname=get_user_cardname(self.user_id) - - if not self.processed_plain_text: - if self.raw_message: - self.message_segments = await self.parse_message_segments(str(self.raw_message)) - self.processed_plain_text = ' '.join( - seg.translated_plain_text - for seg in self.message_segments - ) - #将详细翻译为详细可读文本 + + # 状态标志 + is_emoji: bool = False + has_emoji: bool = False + translate_cq: bool = True + + async def initialize(self): + """显式异步初始化方法(必须调用)""" + if self._initialized: + return + + # 异步获取补充信息 + self.group_name = self.group_name or get_groupname(self.group_id) + self.user_nickname = self.user_nickname or 
get_user_nickname(self.user_id) + self.user_cardname = self.user_cardname or get_user_cardname(self.user_id) + + # 消息解析 + if self.raw_message: + self.message_segments = await self.parse_message_segments(self.raw_message) + self.processed_plain_text = ' '.join( + seg.translated_plain_text + for seg in self.message_segments + ) + + # 构建详细文本 time_str = time.strftime("%m-%d %H:%M:%S", time.localtime(self.time)) - try: - name = f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" - except: - name = self.user_nickname or f"用户{self.user_id}" - content = self.processed_plain_text - self.detailed_plain_text = f"[{time_str}] {name}: {content}\n" + name = ( + f"{self.user_nickname}(ta的昵称:{self.user_cardname},ta的id:{self.user_id})" + if self.user_cardname + else f"{self.user_nickname or f'用户{self.user_id}'}" + ) + self.detailed_plain_text = f"[{time_str}] {name}: {self.processed_plain_text}\n" + + self._initialized = True async def parse_message_segments(self, message: str) -> List[CQCode]: """ diff --git a/src/plugins/chat/utils.py b/src/plugins/chat/utils.py index 38aeefd21..42c91b93c 100644 --- a/src/plugins/chat/utils.py +++ b/src/plugins/chat/utils.py @@ -131,7 +131,7 @@ def get_cloest_chat_from_db(db, length: int, timestamp: str): return '' -def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: +async def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: """从数据库获取群组最近的消息记录 Args: @@ -173,6 +173,7 @@ def get_recent_group_messages(db, group_id: int, limit: int = 12) -> list: processed_plain_text=msg_data.get("processed_text", ""), group_id=group_id ) + await msg.initialize() message_objects.append(msg) except KeyError: print("[WARNING] 数据库中存在无效的消息") diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 3e4d7f1a2..8addf6a46 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -65,7 +65,8 @@ class LLM_request: } api_url = f"{self.base_url.rstrip('/')}/{endpoint.lstrip('/')}" - logger.info(f"发送请求到URL: {api_url}{self.model_name}") + logger.info(f"发送请求到URL: {api_url}") + logger.info(f"使用模型: {self.model_name}") # 构建请求体 if image_base64: @@ -81,33 +82,32 @@ class LLM_request: headers = await self._build_headers() async with session_method as session: - response = await session.post(api_url, headers=headers, json=payload) + async with session.post(api_url, headers=headers, json=payload) as response: + # 处理需要重试的状态码 + if response.status in policy["retry_codes"]: + wait_time = policy["base_wait"] * (2 ** retry) + logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") + if response.status == 413: + logger.warning("请求体过大,尝试压缩...") + image_base64 = compress_base64_image_by_scale(image_base64) + payload = await self._build_payload(prompt, image_base64) + elif response.status in [500, 503]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError("服务器负载过高,模型恢复失败QAQ") + else: + logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") - # 处理需要重试的状态码 - if response.status in policy["retry_codes"]: - wait_time = policy["base_wait"] * (2 ** retry) - logger.warning(f"错误码: {response.status}, 等待 {wait_time}秒后重试") - if response.status == 413: - logger.warning("请求体过大,尝试压缩...") - image_base64 = compress_base64_image_by_scale(image_base64) - payload = await self._build_payload(prompt, image_base64) - elif response.status in [500, 503]: - logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") - raise 
RuntimeError("服务器负载过高,模型恢复失败QAQ") - else: - logger.warning(f"请求限制(429),等待{wait_time}秒后重试...") + await asyncio.sleep(wait_time) + continue + elif response.status in policy["abort_codes"]: + logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") + raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") - await asyncio.sleep(wait_time) - continue - elif response.status in policy["abort_codes"]: - logger.error(f"错误码: {response.status} - {error_code_mapping.get(response.status)}") - raise RuntimeError(f"请求被拒绝: {error_code_mapping.get(response.status)}") + response.raise_for_status() + result = await response.json() - response.raise_for_status() - result = await response.json() - - # 使用自定义处理器或默认处理 - return response_handler(result) if response_handler else self._default_response_handler(result) + # 使用自定义处理器或默认处理 + return response_handler(result) if response_handler else self._default_response_handler(result) except Exception as e: if retry < policy["max_retries"] - 1: @@ -116,7 +116,7 @@ class LLM_request: await asyncio.sleep(wait_time) else: logger.critical(f"请求失败: {str(e)}") - logger.critical(f"请求头: {self._build_headers()} 请求体: {payload}") + logger.critical(f"请求头: {await self._build_headers()} 请求体: {payload}") raise RuntimeError(f"API请求失败: {str(e)}") logger.error("达到最大重试次数,请求仍然失败") From a463f3a1a47dcbddfde38df725b57b75952313ef Mon Sep 17 00:00:00 2001 From: KawaiiYusora Date: Fri, 7 Mar 2025 01:37:17 +0800 Subject: [PATCH 05/11] fix: issue (bug_risk): Reusing ClientSession across retries may lead to closed session issues. --- src/plugins/models/utils_model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/plugins/models/utils_model.py b/src/plugins/models/utils_model.py index 8addf6a46..5d1f90ebb 100644 --- a/src/plugins/models/utils_model.py +++ b/src/plugins/models/utils_model.py @@ -74,14 +74,12 @@ class LLM_request: elif payload is None: payload = await self._build_payload(prompt) - session_method = aiohttp.ClientSession() - for retry in range(policy["max_retries"]): try: # 使用上下文管理器处理会话 headers = await self._build_headers() - async with session_method as session: + async with aiohttp.ClientSession() as session: async with session.post(api_url, headers=headers, json=payload) as response: # 处理需要重试的状态码 if response.status in policy["retry_codes"]: From b77d73ddc7b3215cb7cfeda3fbd5b0147c486709 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 01:49:42 +0800 Subject: [PATCH 06/11] =?UTF-8?q?feat:=20=E7=8E=B0=E5=9C=A8=E5=8F=AF?= =?UTF-8?q?=E4=BB=A5=E8=AE=BE=E7=BD=AE=E6=98=AF=E5=90=A6=E8=87=AA=E5=8A=A8?= =?UTF-8?q?=E4=BF=9D=E5=AD=98=E8=A1=A8=E6=83=85=E5=8C=85=E4=BA=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/bot_config_template.toml | 1 + src/plugins/chat/config.py | 2 ++ src/plugins/chat/utils_image.py | 3 +++ 3 files changed, 6 insertions(+) diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index afc2b5079..4428e1512 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -20,6 +20,7 @@ ban_words = [ [emoji] check_interval = 120 # 检查表情包的时间间隔 register_interval = 10 # 注册表情包的时间间隔 +auto_save = true # 自动偷表情包 check_prompt = "不要包含违反公序良俗的内容" # 表情包过滤要求 [cq_code] diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index e044edc5e..6fb6045da 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -30,6 +30,7 @@ class BotConfig: forget_memory_interval: int = 300 # 
记忆遗忘间隔(秒) EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) + EMOJI_SAVE: bool = True # 偷表情包 EMOJI_CHECK_PROMPT: str = "不要包含违反公序良俗的内容" # 表情包过滤要求 ban_words = set() @@ -96,6 +97,7 @@ class BotConfig: config.EMOJI_CHECK_INTERVAL = emoji_config.get("check_interval", config.EMOJI_CHECK_INTERVAL) config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT) + config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE) if "cq_code" in toml_dict: cq_code_config = toml_dict["cq_code"] diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 9a7ef789a..503c2fa85 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -4,6 +4,7 @@ import hashlib import time import os from ...common.database import Database +from ..chat.config import global_config import zlib # 用于 CRC32 import base64 from nonebot import get_driver @@ -143,6 +144,8 @@ def storage_emoji(image_data: bytes) -> bytes: Returns: bytes: 原始图片数据 """ + if not global_config.EMOJI_SAVE: + return image_data try: # 使用 CRC32 计算哈希值 hash_value = format(zlib.crc32(image_data) & 0xFFFFFFFF, 'x') From 94fd4f5ddd1ac99e1484ba7dd5479594bc33294c Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 02:47:52 +0800 Subject: [PATCH 07/11] =?UTF-8?q?fix:=20=E5=AF=B92MB=E4=BB=A5=E4=B8=8B?= =?UTF-8?q?=E7=9A=84=E5=9B=BE=E7=89=87=E4=BA=88=E4=BB=A5=E6=94=BE=E8=A1=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/utils_image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index 503c2fa85..d79c0a913 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -230,7 +230,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 image_data = base64.b64decode(base64_data) # 如果已经小于目标大小,直接返回原图 - if len(image_data) <= target_size: + if len(image_data) <= 2*1024*1024: return base64_data # 将字节数据转换为图片对象 @@ -255,7 +255,7 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 for frame_idx in range(img.n_frames): img.seek(frame_idx) new_frame = img.copy() - new_frame = new_frame.resize((new_width//4, new_height//4), Image.Resampling.LANCZOS) # 动图折上折 + new_frame = new_frame.resize((new_width//2, new_height//2), Image.Resampling.LANCZOS) # 动图折上折 frames.append(new_frame) # 保存到缓冲区 From a3b8a545afa65596bc7a3963fd987e8a19fc9c6e Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 03:12:35 +0800 Subject: [PATCH 08/11] =?UTF-8?q?fix:=20=E7=B4=A7=E6=80=A5=E4=B8=BAcheck?= =?UTF-8?q?=E5=8A=9F=E8=83=BD=E5=8A=A0=E5=85=A5=E5=BC=80=E5=85=B3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/bot_config_template.toml | 1 + src/plugins/chat/config.py | 2 ++ src/plugins/chat/emoji_manager.py | 15 ++++++++------- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index 9a9fa2ebc..3287b3d20 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -21,6 +21,7 @@ ban_words = [ check_interval = 120 # 检查表情包的时间间隔 register_interval = 10 # 注册表情包的时间间隔 auto_save = true # 自动偷表情包 +enable_check = false # 是否启用表情包过滤 check_prompt = "不要包含违反公序良俗的内容" # 表情包过滤要求 [cq_code] diff 
--git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index dbb6d7a6a..6cb8b9fee 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -31,6 +31,7 @@ class BotConfig: EMOJI_CHECK_INTERVAL: int = 120 # 表情包检查间隔(分钟) EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) EMOJI_SAVE: bool = True # 偷表情包 + EMOJI_CHECK: bool = False #是否开启过滤 EMOJI_CHECK_PROMPT: str = "不要包含违反公序良俗的内容" # 表情包过滤要求 ban_words = set() @@ -100,6 +101,7 @@ class BotConfig: config.EMOJI_REGISTER_INTERVAL = emoji_config.get("register_interval", config.EMOJI_REGISTER_INTERVAL) config.EMOJI_CHECK_PROMPT = emoji_config.get('check_prompt',config.EMOJI_CHECK_PROMPT) config.EMOJI_SAVE = emoji_config.get('auto_save',config.EMOJI_SAVE) + config.EMOJI_CHECK = emoji_config.get('enable_check',config.EMOJI_CHECK) if "cq_code" in toml_dict: cq_code_config = toml_dict["cq_code"] diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index cec454e4d..3592bd09b 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -291,13 +291,14 @@ class EmojiManager: # 获取表情包的描述 discription = await self._get_emoji_discription(image_base64) - check = await self._check_emoji(image_base64) - if '是' not in check: - os.remove(image_path) - logger.info(f"描述: {discription}") - logger.info(f"其不满足过滤规则,被剔除 {check}") - continue - logger.info(f"check通过 {check}") + if global_config.EMOJI_CHECK: + check = await self._check_emoji(image_base64) + if '是' not in check: + os.remove(image_path) + logger.info(f"描述: {discription}") + logger.info(f"其不满足过滤规则,被剔除 {check}") + continue + logger.info(f"check通过 {check}") embedding = get_embedding(discription) if discription is not None: # 准备数据库记录 From a7fedba79effd1d7e9ee75ba59f7574361ef5e7d Mon Sep 17 00:00:00 2001 From: KawaiiYusora Date: Fri, 7 Mar 2025 03:36:37 +0800 Subject: [PATCH 09/11] =?UTF-8?q?fix:=20=E9=82=A3=E6=88=91=E9=97=AE?= =?UTF-8?q?=E4=BD=A0=E9=82=A3=E6=88=91=E9=97=AE=E4=BD=A0=20get=5Fembedding?= =?UTF-8?q?=E6=B2=A1=E6=9C=89=E4=BD=BF=E7=94=A8=E5=8D=8F=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/emoji_manager.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 3592bd09b..9bd71ddd8 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -299,7 +299,7 @@ class EmojiManager: logger.info(f"其不满足过滤规则,被剔除 {check}") continue logger.info(f"check通过 {check}") - embedding = get_embedding(discription) + embedding = await get_embedding(discription) if discription is not None: # 准备数据库记录 emoji_record = { From 0ced4939ec20f4185106f83d221273f7678d4f94 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 03:40:14 +0800 Subject: [PATCH 10/11] =?UTF-8?q?fix:=20=E4=BF=AE=E6=94=B9embedding?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/bot_config_template.toml | 2 +- src/plugins/chat/config.py | 2 +- src/plugins/chat/emoji_manager.py | 4 ++-- src/plugins/memory_system/memory.py | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/config/bot_config_template.toml b/config/bot_config_template.toml index 3287b3d20..507c6d2d6 100644 --- a/config/bot_config_template.toml +++ b/config/bot_config_template.toml @@ -22,7 +22,7 @@ check_interval = 120 # 检查表情包的时间间隔 register_interval = 10 # 注册表情包的时间间隔 auto_save = true # 自动偷表情包 enable_check = false # 是否启用表情包过滤 -check_prompt = 
"不要包含违反公序良俗的内容" # 表情包过滤要求 +check_prompt = "符合公序良俗" # 表情包过滤要求 [cq_code] enable_pic_translate = false diff --git a/src/plugins/chat/config.py b/src/plugins/chat/config.py index 6cb8b9fee..a2adc9e30 100644 --- a/src/plugins/chat/config.py +++ b/src/plugins/chat/config.py @@ -32,7 +32,7 @@ class BotConfig: EMOJI_REGISTER_INTERVAL: int = 10 # 表情包注册间隔(分钟) EMOJI_SAVE: bool = True # 偷表情包 EMOJI_CHECK: bool = False #是否开启过滤 - EMOJI_CHECK_PROMPT: str = "不要包含违反公序良俗的内容" # 表情包过滤要求 + EMOJI_CHECK_PROMPT: str = "符合公序良俗" # 表情包过滤要求 ban_words = set() diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 3592bd09b..1cdb62c07 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -98,7 +98,7 @@ class EmojiManager: # 获取文本的embedding text_for_search= await self._get_kimoji_for_text(text) - text_embedding = get_embedding(text_for_search) + text_embedding = await get_embedding(text_for_search) if not text_embedding: logger.error("无法获取文本的embedding") return None @@ -299,7 +299,7 @@ class EmojiManager: logger.info(f"其不满足过滤规则,被剔除 {check}") continue logger.info(f"check通过 {check}") - embedding = get_embedding(discription) + embedding = await get_embedding(discription) if discription is not None: # 准备数据库记录 emoji_record = { diff --git a/src/plugins/memory_system/memory.py b/src/plugins/memory_system/memory.py index 49d19c253..a25e15bdf 100644 --- a/src/plugins/memory_system/memory.py +++ b/src/plugins/memory_system/memory.py @@ -673,7 +673,7 @@ class Hippocampus: if first_layer: # 如果记忆条数超过限制,随机选择指定数量的记忆 if len(first_layer) > max_memory_num/2: - first_layer = random.sample(first_layer, max_memory_num) + first_layer = random.sample(first_layer, max_memory_num//2) # 为每条记忆添加来源主题和相似度信息 for memory in first_layer: relevant_memories.append({ From d0047e82bf75386b2845776d7b0cc7aac1646021 Mon Sep 17 00:00:00 2001 From: tcmofashi Date: Fri, 7 Mar 2025 04:01:09 +0800 Subject: [PATCH 11/11] =?UTF-8?q?fix:=20=E5=8E=BB=E9=99=A4emoji=5Fmanager?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=E5=9B=BE=E7=89=87=E5=8E=8B=E7=BC=A9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/plugins/chat/emoji_manager.py | 72 +------------------------------ src/plugins/chat/utils_image.py | 17 +++++++- 2 files changed, 18 insertions(+), 71 deletions(-) diff --git a/src/plugins/chat/emoji_manager.py b/src/plugins/chat/emoji_manager.py index 1cdb62c07..4784c0a3d 100644 --- a/src/plugins/chat/emoji_manager.py +++ b/src/plugins/chat/emoji_manager.py @@ -20,6 +20,7 @@ import traceback from nonebot import get_driver from ..chat.config import global_config from ..models.utils_model import LLM_request +from ..chat.utils_image import image_path_to_base64 from ..chat.utils import get_embedding driver = get_driver() @@ -196,76 +197,7 @@ class EmojiManager: except Exception as e: logger.error(f"获取标签失败: {str(e)}") return None - - async def _compress_image(self, image_path: str, target_size: int = 0.8 * 1024 * 1024) -> Optional[str]: - """压缩图片并返回base64编码 - Args: - image_path: 图片文件路径 - target_size: 目标文件大小(字节),默认0.8MB - Returns: - Optional[str]: 成功返回base64编码的图片数据,失败返回None - """ - try: - file_size = os.path.getsize(image_path) - if file_size <= target_size: - # 如果文件已经小于目标大小,直接读取并返回base64 - with open(image_path, 'rb') as f: - return base64.b64encode(f.read()).decode('utf-8') - - # 打开图片 - with Image.open(image_path) as img: - # 获取原始尺寸 - original_width, original_height = img.size - - # 计算缩放比例 - scale = min(1.0, (target_size / file_size) ** 0.5) - - # 计算新的尺寸 - new_width = 
int(original_width * scale) - new_height = int(original_height * scale) - - # 创建内存缓冲区 - output_buffer = io.BytesIO() - - # 如果是GIF,处理所有帧 - if getattr(img, "is_animated", False): - frames = [] - for frame_idx in range(img.n_frames): - img.seek(frame_idx) - new_frame = img.copy() - new_frame = new_frame.resize((new_width, new_height), Image.Resampling.LANCZOS) - frames.append(new_frame) - # 保存到缓冲区 - frames[0].save( - output_buffer, - format='GIF', - save_all=True, - append_images=frames[1:], - optimize=True, - duration=img.info.get('duration', 100), - loop=img.info.get('loop', 0) - ) - else: - # 处理静态图片 - resized_img = img.resize((new_width, new_height), Image.Resampling.LANCZOS) - - # 保存到缓冲区,保持原始格式 - if img.format == 'PNG' and img.mode in ('RGBA', 'LA'): - resized_img.save(output_buffer, format='PNG', optimize=True) - else: - resized_img.save(output_buffer, format='JPEG', quality=95, optimize=True) - - # 获取压缩后的数据并转换为base64 - compressed_data = output_buffer.getvalue() - logger.success(f"压缩图片: {os.path.basename(image_path)} ({original_width}x{original_height} -> {new_width}x{new_height})") - - return base64.b64encode(compressed_data).decode('utf-8') - - except Exception as e: - logger.error(f"压缩图片失败: {os.path.basename(image_path)}, 错误: {str(e)}") - return None - async def scan_new_emojis(self): """扫描新的表情包""" try: @@ -284,7 +216,7 @@ class EmojiManager: continue # 压缩图片并获取base64编码 - image_base64 = await self._compress_image(image_path) + image_base64 = image_path_to_base64(image_path) if image_base64 is None: os.remove(image_path) continue diff --git a/src/plugins/chat/utils_image.py b/src/plugins/chat/utils_image.py index d79c0a913..eff788868 100644 --- a/src/plugins/chat/utils_image.py +++ b/src/plugins/chat/utils_image.py @@ -289,4 +289,19 @@ def compress_base64_image_by_scale(base64_data: str, target_size: int = 0.8 * 10 logger.error(f"压缩图片失败: {str(e)}") import traceback logger.error(traceback.format_exc()) - return base64_data \ No newline at end of file + return base64_data + +def image_path_to_base64(image_path: str) -> str: + """将图片路径转换为base64编码 + Args: + image_path: 图片文件路径 + Returns: + str: base64编码的图片数据 + """ + try: + with open(image_path, 'rb') as f: + image_data = f.read() + return base64.b64encode(image_data).decode('utf-8') + except Exception as e: + logger.error(f"读取图片失败: {image_path}, 错误: {str(e)}") + return None \ No newline at end of file
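
Taken together, PATCH 04 and PATCH 05 converge on one request pattern for _execute_request: build the payload once, then open a fresh aiohttp.ClientSession inside every retry iteration and back off exponentially on retryable status codes. The sketch below only illustrates that pattern and is not the project's actual method; the function name fetch_json, the status-code set and the wait constants are assumptions made for the example.

import asyncio
import aiohttp

RETRY_CODES = {413, 429, 503}   # assumed retryable statuses for this sketch
MAX_RETRIES = 3                 # assumed policy values
BASE_WAIT = 15                  # seconds

async def fetch_json(api_url: str, headers: dict, payload: dict) -> dict:
    """Hypothetical helper: POST a JSON payload with retry and exponential backoff."""
    for attempt in range(MAX_RETRIES):
        try:
            # A fresh session per attempt; reusing one across retries risks
            # posting on an already-closed session (the bug_risk noted in PATCH 05).
            async with aiohttp.ClientSession() as session:
                async with session.post(api_url, headers=headers, json=payload) as response:
                    if response.status in RETRY_CODES:
                        wait = BASE_WAIT * (2 ** attempt)  # exponential backoff
                        await asyncio.sleep(wait)
                        continue
                    response.raise_for_status()
                    return await response.json()
        except aiohttp.ClientError as e:
            if attempt < MAX_RETRIES - 1:
                await asyncio.sleep(BASE_WAIT * (2 ** attempt))
            else:
                raise RuntimeError(f"API请求失败: {e}") from e
    raise RuntimeError("达到最大重试次数,请求仍然失败")

Creating the session inside the loop is exactly what the PATCH 05 fix enforces: a session opened with async with is closed when its block exits, so a single session created before the loop would already be unusable on the second attempt.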